Пример #1
0
    def configure_optimizers(self):
        """Build the Adam optimizer and a ReduceLROnPlateau scheduler from ``self.config``.

        Config entries read:

        * ``config["optimizer"]["lr"]`` - Adam learning rate, default ``1e-4``.
        * ``config["scheduler"]["patience"]``, ``["factor"]``, ``["min_lr"]`` - required
          ReduceLROnPlateau settings (no defaults).
        * ``config["metric_mode"]`` (top level, not under ``"scheduler"``) - ``"min"`` or ``"max"``,
          default ``"min"``.
        * ``config["scheduler"]["metric_to_monitor"]`` - name of the metric the scheduler steps on,
          default ``"train/loss"``.

        Returns:
            dict with ``"optimizer"``, ``"lr_scheduler"`` and ``"monitor"`` keys. This appears to be the
            dictionary form of a Lightning ``configure_optimizers`` result — confirm against the installed
            version.
        """
        # Optimizer: only the parameters selected by get_optimizable_parameters are handed to Adam.
        optimizer_config = self.config["optimizer"]
        params = get_optimizable_parameters(self.model)
        optimizer = torch.optim.Adam(params, lr=optimizer_config.get("lr", 1e-4))

        # Scheduler: spelled via torch.optim like the Adam line above, so this method does not depend on a
        # separate ``optim`` alias being imported at module level.
        scheduler_config = self.config["scheduler"]
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            # NOTE(review): mode comes from the top-level config while the monitored metric comes from the
            # "scheduler" section — keep both in sync when editing configs.
            mode=self.config.get("metric_mode", "min"),
            patience=scheduler_config["patience"],
            factor=scheduler_config["factor"],
            min_lr=scheduler_config["min_lr"],
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": scheduler_config.get("metric_to_monitor", "train/loss"),
        }
Пример #2
0
def main():
    """Parse CLI arguments and run up to three training stages on one model.

    Each stage runs only when its epoch count is positive:

    * warmup (``--warmup`` epochs): fixed "Ranger" optimizer at lr 3e-4, no scheduler, obliteration disabled,
      ``--warmup-batch-size`` batches. With ``--freeze-encoder`` the encoder is frozen for this stage only.
      ``checkpoints/best.pth`` of the stage is passed to ``clean_checkpoint`` as ``<prefix>_warmup.pth``.
    * main (``--epochs``): optimizer, scheduler and learning rate from the command line; the training set is
      optionally extended with images from ``--negative-image-dir``. Saved as ``<prefix>.pth``.
    * fine-tune (``--fine-tune`` epochs): "light" augmentations, "SGD" with the "cos" scheduler and no
      mixup/cutmix/TSA. Saved as ``<prefix>_finetune.pth`` and then loaded back into the model.

    All artifacts go to ``runs/<prefix>/``; ``os.makedirs(..., exist_ok=False)`` makes the run fail if that
    directory already exists. The model is moved to CUDA, so a GPU is required.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to accumulate gradients over"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--obliterate", type=float, default=0, help="Chance of obliteration")
    parser.add_argument(
        "-nid",
        "--negative-image-dir",
        type=str,
        default=None,
        help="Directory with negative images to append to the main-stage training set",
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-m", "--model", type=str, default="resnet34", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch Size during training, e.g. -b 64")
    parser.add_argument(
        "-wbs",
        "--warmup-batch-size",
        type=int,
        default=None,
        help="Batch size during the warmup stage (defaults to --batch-size), e.g. -wbs 64",
    )
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    # NOTE(review): parsed and saved to the run config, but not used by any stage below.
    parser.add_argument(
        "-es", "--early-stopping", type=int, default=None, help="Maximum number of epochs without improvement"
    )
    parser.add_argument(
        "-fe", "--freeze-encoder", action="store_true", help="Freeze encoder parameters during the warmup stage"
    )
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")

    # Loss options are repeatable; each occurrence appends a list of strings,
    # e.g. ``-l ce 1.0`` yields [["ce", "1.0"]].
    parser.add_argument("-l", "--modification-flag-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--modification-type-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--embedding-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--feature-maps-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--mask-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--bits-loss", type=str, default=None, action="append", nargs="+")

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--tsa", action="store_true")
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("-s", "--scheduler", default=None, type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument(
        "--fine-tune",
        default=0,
        type=int,
        help='Number of fine-tuning epochs ("light" augmentations, SGD optimizer, "cos" scheduler)',
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    # parser.error() prints usage and exits; unlike ``assert`` it is not stripped under ``python -O``.
    if not (args.modification_flag_loss or args.modification_type_loss or args.embedding_loss):
        parser.error("At least one of losses must be set")

    modification_flag_loss = args.modification_flag_loss
    modification_type_loss = args.modification_type_loss
    embedding_loss = args.embedding_loss
    feature_maps_loss = args.feature_maps_loss
    mask_loss = args.mask_loss
    bits_loss = args.bits_loss

    freeze_encoder = args.freeze_encoder
    data_dir = args.data_dir
    cache = args.cache
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name: str = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    freeze_bn = args.freeze_bn
    train_batch_size = args.batch_size
    mixup = args.mixup
    cutmix = args.cutmix
    tsa = args.tsa
    fine_tune = args.fine_tune
    obliterate_p = args.obliterate
    negative_image_dir = args.negative_image_dir
    warmup_batch_size = args.warmup_batch_size or args.batch_size

    # Validation uses the same batch size as training
    valid_batch_size = train_batch_size
    run_train = num_epochs > 0

    custom_model_kwargs = {}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    if embedding_loss is not None:
        custom_model_kwargs["need_embedding"] = True

    model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda()
    # Copy so that appending extra input keys below never mutates a list owned by the model (or its class).
    required_features = list(model.required_features)

    if mask_loss is not None:
        required_features.append(INPUT_TRUE_MODIFICATION_MASK)

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transferring weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    if freeze_bn:
        from pytorch_toolbelt.optimization.functional import freeze_model

        # Pass freeze_parameters=None explicitly so requires_grad is left untouched and only batch norm is
        # affected. freeze_model's default for freeze_parameters freezes every parameter, which would leave
        # the optimizer nothing to train.
        freeze_model(model, freeze_parameters=None, freeze_bn=True)
        print("Freezing bn params")

    main_metric = "loss"
    main_metric_minimize = True

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_fold{fold}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if mixup:
        checkpoint_prefix += "_mixup"

    if cutmix:
        checkpoint_prefix += "_cutmix"

    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = []

    if show:
        default_callbacks += [ShowPolarBatchesCallback(draw_predictions, metric="loss", minimize=True)]

    def make_stage_loaders(train_ds, valid_ds, train_sampler, stage_train_batch_size, stage_valid_batch_size):
        """Build the ordered train/valid loaders; a sampler, when given, replaces shuffling."""
        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=stage_train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )
        loaders["valid"] = DataLoader(
            valid_ds, batch_size=stage_valid_batch_size, num_workers=num_workers, pin_memory=True
        )
        return loaders

    def make_stage_callbacks(loss_callbacks, **hparam_overrides):
        """Return shared + loss callbacks plus the optimizer and hyper-parameter callbacks for one stage.

        ``hparam_overrides`` replaces logged hyper-parameters where a stage deviates from the command line.
        """
        hparam_dict = {
            "model": model_name,
            "scheduler": scheduler_name,
            "optimizer": optimizer_name,
            "augmentations": augmentations,
            "size": image_size[0],
            "weight_decay": weight_decay,
        }
        hparam_dict.update(hparam_overrides)
        return (
            default_callbacks
            + loss_callbacks
            + [
                OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
                HyperParametersCallback(hparam_dict=hparam_dict),
            ]
        )

    def print_stage_summary(
        *,
        epochs,
        augmentations,
        obliterate_p,
        negative_image_dir,
        loaders,
        train_ds,
        valid_ds,
        mixup,
        cutmix,
        tsa,
        optimizer_name,
        learning_rate,
        scheduler_name,
        batch_sizes,
    ):
        """Print the settings a stage actually trains with; stage-specific values are passed in."""
        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", fast)
        print("  Epochs         :", epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("  Cache          :", cache)
        print("Data              ")
        print("  Augmentations  :", augmentations)
        print("  Obliterate (%) :", obliterate_p)
        print("  Negative images:", negative_image_dir)
        print("  Train size     :", len(loaders["train"]), "batches", len(train_ds), "samples")
        print("  Valid size     :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
        print("  Image size     :", image_size)
        print("  Balance        :", balance)
        print("  Mixup          :", mixup)
        print("  CutMix         :", cutmix)
        print("  TSA            :", tsa)
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Dropout        :", dropout)
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Weight decay   :", weight_decay)
        print("  Scheduler      :", scheduler_name)
        print("  Batch sizes    :", *batch_sizes)
        print("Losses            ")
        print("  Flag           :", modification_flag_loss)
        print("  Type           :", modification_type_loss)
        print("  Embedding      :", embedding_loss)
        print("  Feature maps   :", feature_maps_loss)
        print("  Mask           :", mask_loss)
        print("  Bits           :", bits_loss)

    # Pretrain/warmup: fixed optimizer and learning rate, no scheduler, no obliteration.
    if warmup:
        warmup_optimizer_name = "Ranger"
        warmup_learning_rate = 3e-4

        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=0,
        )

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            feature_maps_loss=feature_maps_loss,
            num_epochs=warmup,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = make_stage_callbacks(loss_callbacks, optimizer=warmup_optimizer_name, scheduler=None)
        loaders = make_stage_loaders(train_ds, valid_ds, train_sampler, warmup_batch_size, warmup_batch_size)

        if freeze_encoder:
            from pytorch_toolbelt.optimization.functional import freeze_model

            freeze_model(model.encoder, freeze_parameters=True, freeze_bn=None)

        optimizer = get_optimizer(
            warmup_optimizer_name,
            get_optimizable_parameters(model),
            weight_decay=weight_decay,
            learning_rate=warmup_learning_rate,
        )
        scheduler = None

        print_stage_summary(
            epochs=warmup,
            augmentations=augmentations,
            obliterate_p=0,
            negative_image_dir=None,
            loaders=loaders,
            train_ds=train_ds,
            valid_ds=valid_ds,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
            optimizer_name=warmup_optimizer_name,
            learning_rate=warmup_learning_rate,
            scheduler_name=None,
            batch_sizes=(warmup_batch_size, warmup_batch_size),
        )

        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "warmup"),
            num_epochs=warmup,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        if freeze_encoder:
            from pytorch_toolbelt.optimization.functional import freeze_model

            # --freeze-encoder only applies to warmup: make the encoder trainable again, otherwise the later
            # stages (which build optimizers from get_optimizable_parameters(model)) would never update it.
            freeze_model(model.encoder, freeze_parameters=False, freeze_bn=None)

        del optimizer, loaders, runner, callbacks

        best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_warmup.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        # NOTE: the best warmup weights are not loaded back here; the next stage continues from whatever
        # weights the runner left in ``model``.
        torch.cuda.empty_cache()
        gc.collect()

    # Main training with the optimizer, scheduler and learning rate given on the command line.
    if run_train:
        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )

        if negative_image_dir:
            negatives_ds = get_negatives_ds(
                negative_image_dir, fold=fold, features=required_features, max_images=16536
            )
            train_ds = train_ds + negatives_ds
            train_sampler = None  # TODO: Add proper support of sampler
            print("Adding", len(negatives_ds), "negative samples to training set")

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            num_epochs=num_epochs,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
        )

        callbacks = make_stage_callbacks(loss_callbacks)
        loaders = make_stage_loaders(train_ds, valid_ds, train_sampler, train_batch_size, valid_batch_size)

        print_stage_summary(
            epochs=num_epochs,
            augmentations=augmentations,
            obliterate_p=obliterate_p,
            negative_image_dir=negative_image_dir,
            loaders=loaders,
            train_ds=train_ds,
            valid_ds=valid_ds,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
            optimizer_name=optimizer_name,
            learning_rate=learning_rate,
            scheduler_name=scheduler_name,
            batch_sizes=(train_batch_size, valid_batch_size),
        )

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, CyclicLR):
            # CyclicLR is stepped per batch rather than per epoch.
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        # NOTE: as after warmup, the best weights are not loaded back before fine-tuning.
        torch.cuda.empty_cache()
        gc.collect()

    # Fine-tuning with light augmentations, plain SGD and a "cos" schedule; no mixup/cutmix/TSA.
    if fine_tune:
        fine_tune_augmentations = "light"
        fine_tune_optimizer_name = "SGD"
        fine_tune_scheduler_name = "cos"

        train_ds, valid_ds, train_sampler = get_datasets(
            data_dir=data_dir,
            augmentation=fine_tune_augmentations,
            balance=balance,
            fast=fast,
            fold=fold,
            features=required_features,
            obliterate_p=obliterate_p,
        )

        criterions_dict, loss_callbacks = get_criterions(
            modification_flag=modification_flag_loss,
            modification_type=modification_type_loss,
            embedding_loss=embedding_loss,
            feature_maps_loss=feature_maps_loss,
            mask_loss=mask_loss,
            bits_loss=bits_loss,
            num_epochs=fine_tune,
            mixup=False,
            cutmix=False,
            tsa=False,
        )

        callbacks = make_stage_callbacks(
            loss_callbacks,
            optimizer=fine_tune_optimizer_name,
            scheduler=fine_tune_scheduler_name,
            augmentations=fine_tune_augmentations,
        )
        loaders = make_stage_loaders(train_ds, valid_ds, train_sampler, train_batch_size, valid_batch_size)

        print_stage_summary(
            epochs=fine_tune,
            augmentations=fine_tune_augmentations,
            obliterate_p=obliterate_p,
            negative_image_dir=None,
            loaders=loaders,
            train_ds=train_ds,
            valid_ds=valid_ds,
            mixup=False,
            cutmix=False,
            tsa=False,
            optimizer_name=fine_tune_optimizer_name,
            learning_rate=learning_rate,
            scheduler_name=fine_tune_scheduler_name,
            batch_sizes=(train_batch_size, valid_batch_size),
        )

        optimizer = get_optimizer(
            fine_tune_optimizer_name,
            get_optimizable_parameters(model),
            learning_rate=learning_rate,
            weight_decay=weight_decay,
        )
        scheduler = get_scheduler(
            fine_tune_scheduler_name,
            optimizer,
            lr=learning_rate,
            num_epochs=fine_tune,
            batches_in_epoch=len(loaders["train"]),
        )
        if isinstance(scheduler, CyclicLR):
            # CyclicLR is stepped per batch rather than per epoch.
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "finetune"),
            num_epochs=fine_tune,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        best_checkpoint = os.path.join(log_dir, "finetune", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_finetune.pth")

        clean_checkpoint(best_checkpoint, model_checkpoint)
        # Unlike the earlier stages, the best fine-tuned weights are restored into the model.
        unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        del optimizer, loaders, runner, callbacks
def main():
    parser = argparse.ArgumentParser()

    ###########################################################################################
    # Distributed-training related stuff
    parser.add_argument("--local_rank", type=int, default=0)
    ###########################################################################################

    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd",
        "--data-dir",
        type=str,
        help="Data directory for INRIA sattelite dataset",
        default=os.environ.get("INRIA_DATA_DIR"),
    )
    parser.add_argument(
        "-dd-xview2", "--data-dir-xview2", type=str, required=False, help="Data directory for external xView2 dataset"
    )
    parser.add_argument("-m", "--model", type=str, default="b6_unet32_s2", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion")
    parser.add_argument(
        "-l2",
        "--criterion2",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 2 mask",
    )
    parser.add_argument(
        "-l4",
        "--criterion4",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 4 mask",
    )
    parser.add_argument(
        "-l8",
        "--criterion8",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 8 mask",
    )
    parser.add_argument(
        "-l16",
        "--criterion16",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 16 mask",
    )

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="hard", type=str, help="")
    parser.add_argument("-tm", "--train-mode", default="random", type=str, help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()

    args.is_master = args.local_rank == 0
    args.distributed = False
    fp16 = args.fp16

    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.world_size = int(os.environ["WORLD_SIZE"])
        # args.world_size = torch.distributed.get_world_size()

        print("Initializing init_process_group", args.local_rank)

        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        print("Initialized init_process_group", args.local_rank)

    is_master = args.is_master | (not args.distributed)

    if args.distributed:
        distributed_params = {"rank": args.local_rank, "syncbn": True}
        if args.fp16:
            distributed_params["amp"] = True
    else:
        if args.fp16:
            distributed_params = {}
            distributed_params["amp"] = True
        else:
            distributed_params = False

    set_manual_seed(args.seed + args.local_rank)
    catalyst.utils.set_global_seed(args.seed + args.local_rank)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    data_dir = args.data_dir
    if data_dir is None:
        raise ValueError("--data-dir must be set")

    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    criterions2 = args.criterion2
    criterions4 = args.criterion4
    criterions8 = args.criterion8
    criterions16 = args.criterion16

    verbose = args.verbose
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    custom_model_kwargs = {"full_size_mask": False}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    if any([criterions2, criterions4, criterions8, criterions16]):
        custom_model_kwargs["need_supervision_masks"] = True
        print("Enabling supervision masks")

    model: nn.Module = get_model(model_name, num_classes=16, **custom_model_kwargs).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    main_metric = "jaccard"

    current_time = datetime.now().strftime("%y%m%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if online_pseudolabeling:
        checkpoint_prefix += "_opl"

    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"

    if experiment is not None:
        checkpoint_prefix = experiment

    default_callbacks = [
        JaccardMetricPerImage(
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            prefix="jaccard",
            inputs_to_labels=depth2mask,
            outputs_to_labels=decode_depth_mask,
        ),
    ]

    if is_master:

        default_callbacks += [
            BestMetricCheckpointCallback(target_metric="jaccard", target_metric_minimize=False),
            HyperParametersCallback(
                hparam_dict={
                    "model": model_name,
                    "scheduler": scheduler_name,
                    "optimizer": optimizer_name,
                    "augmentations": augmentations,
                    "size": args.size,
                    "weight_decay": weight_decay,
                    "epochs": num_epochs,
                    "dropout": None if dropout is None else float(dropout),
                }
            ),
        ]

        if show:
            visualize_inria_predictions = partial(
                draw_inria_predictions,
                image_key=INPUT_IMAGE_KEY,
                image_id_key=INPUT_IMAGE_ID_KEY,
                targets_key=INPUT_MASK_KEY,
                outputs_key=OUTPUT_MASK_KEY,
                inputs_to_labels=depth2mask,
                outputs_to_labels=decode_depth_mask,
                max_images=16,
            )
            default_callbacks += [
                ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False),
                ShowPolarBatchesCallback(visualize_inria_predictions, metric="loss", minimize=True),
            ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        buildings_only=(train_mode == "tiles"),
        fast=fast,
        need_weight_mask=need_weight_mask,
        make_mask_target_fn=mask_to_ce_target,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights, train_sampler.num_samples * 2)

        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds), "samples")

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        ignore_index = None
        if online_pseudolabeling:
            ignore_index = UNLABELED_SAMPLE
            unlabeled_label = get_pseudolabeling_dataset(
                data_dir, include_masks=False, augmentation=None, image_size=image_size
            )

            unlabeled_train = get_pseudolabeling_dataset(
                data_dir, include_masks=True, augmentation=augmentations, image_size=image_size
            )

            if args.distributed:
                label_sampler = DistributedSampler(unlabeled_label, args.world_size, args.local_rank, shuffle=False)
            else:
                label_sampler = None

            loaders["infer"] = DataLoader(
                unlabeled_label,
                batch_size=batch_size // 2,
                num_workers=num_workers,
                pin_memory=True,
                sampler=label_sampler,
                drop_last=False,
            )

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)
            weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_label))

            train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="infer",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]

            print("Using online pseudolabeling with ", len(unlabeled_label), "samples")

        valid_sampler = None
        if args.distributed:
            if train_sampler is not None:
                train_sampler = DistributedSamplerWrapper(
                    train_sampler, args.world_size, args.local_rank, shuffle=True
                )
            else:
                train_sampler = DistributedSampler(train_ds, args.world_size, args.local_rank, shuffle=True)
            valid_sampler = DistributedSampler(valid_ds, args.world_size, args.local_rank, shuffle=False)

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(
            valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, sampler=valid_sampler
        )

        loss_callbacks, loss_criterions = get_criterions(
            criterions, criterions2, criterions4, criterions8, criterions16
        )
        callbacks += loss_callbacks

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        log_dir = os.path.join("runs", checkpoint_prefix)

        if is_master:
            os.makedirs(log_dir, exist_ok=False)
            config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
            with open(config_fname, "w") as f:
                train_session_args = vars(args)
                f.write(json.dumps(train_session_args, indent=2))

            print("Train session    :", checkpoint_prefix)
            print("  FP16 mode      :", fp16)
            print("  Fast mode      :", args.fast)
            print("  Train mode     :", train_mode)
            print("  Epochs         :", num_epochs)
            print("  Workers        :", num_workers)
            print("  Data dir       :", data_dir)
            print("  Log dir        :", log_dir)
            print("  Augmentations  :", augmentations)
            print("  Train size     :", "batches", len(loaders["train"]), "dataset", len(train_ds))
            print("  Valid size     :", "batches", len(loaders["valid"]), "dataset", len(valid_ds))
            print("Model            :", model_name)
            print("  Parameters     :", count_parameters(model))
            print("  Image size     :", image_size)
            print("Optimizer        :", optimizer_name)
            print("  Learning rate  :", learning_rate)
            print("  Batch size     :", batch_size)
            print("  Criterion      :", criterions)
            print("  Use weight mask:", need_weight_mask)
            if args.distributed:
                print("Distributed")
                print("  World size     :", args.world_size)
                print("  Local rank     :", args.local_rank)
                print("  Is master      :", args.is_master)

        # model training
        runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda")
        runner.train(
            fp16=distributed_params,
            model=model,
            criterion=loss_criterions,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        if is_master:
            best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")

            model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
            clean_checkpoint(best_checkpoint, model_checkpoint)

            unpack_checkpoint(torch.load(model_checkpoint), model=model)

            mask = predict(
                model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size
            )
            mask = ((mask > 0) * 255).astype(np.uint8)
            name = os.path.join(log_dir, "sample_color.jpg")
            cv2.imwrite(name, mask)
Пример #4
0
def run_stage_training(model: Union[TimmRgbModel,
                                    YCrCbModel], config: StageConfig,
                       exp_config: ExperimenetConfig, experiment_dir: str):
    """Train ``model`` for a single stage of an experiment.

    The stage builds its own datasets, data loaders, losses, optimizer and
    scheduler from ``config`` and trains with ``SupervisedRunner``. It then
    passes the stage's ``best.pth`` through ``clean_checkpoint`` into
    ``<experiment_dir>/<exp_config.checkpoint_prefix>.pth``.

    Args:
        model: Model to train; it is frozen via ``freeze_model`` and then
            updated in place. Its ``required_features`` attribute selects the
            dataset features and is used as the runner's input keys.
        config: Per-stage settings (epochs, batch sizes, optimizer, schedule,
            losses, augmentations, ...).
        exp_config: Experiment-wide settings (data dir, fold, number of
            workers, model name, checkpoint prefix, dropout).
        experiment_dir: Root directory of the experiment; this stage logs to
            ``<experiment_dir>/<config.stage_name>``.

    If ``config.restore_best`` is set, the cleaned best checkpoint is loaded
    back into ``model`` before returning.
    """
    # Preparing model
    freeze_model(model, freeze_bn=config.freeze_bn)

    # NOTE(review): config.negative_image_dir is reported below but is not
    # forwarded to get_datasets - confirm whether that is intended.
    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=exp_config.data_dir,
        image_size=config.image_size,
        augmentation=config.augmentations,
        balance=config.balance,
        fast=config.fast,
        fold=exp_config.fold,
        features=model.required_features,
        obliterate_p=config.obliterate_p,
    )

    criterions_dict, loss_callbacks = get_criterions(
        modification_flag=config.modification_flag_loss,
        modification_type=config.modification_type_loss,
        embedding_loss=config.embedding_loss,
        feature_maps_loss=config.feature_maps_loss,
        num_epochs=config.epochs,
        mixup=config.mixup,
        cutmix=config.cutmix,
        tsa=config.tsa,
    )

    callbacks = loss_callbacks + [
        OptimizerCallback(accumulation_steps=config.accumulation_steps,
                          decouple_weight_decay=False),
        HyperParametersCallback(
            hparam_dict={
                "model": exp_config.model_name,
                "scheduler": config.schedule,
                "optimizer": config.optimizer,
                "augmentations": config.augmentations,
                "size": config.image_size[0],
                "weight_decay": config.weight_decay,
            }),
    ]

    if config.show:
        callbacks += [
            ShowPolarBatchesCallback(draw_predictions,
                                     metric="loss",
                                     minimize=True)
        ]

    loaders = collections.OrderedDict()
    # The explicit sampler (if any) takes over shuffling; DataLoader forbids
    # shuffle=True together with a sampler.
    loaders["train"] = DataLoader(
        train_ds,
        batch_size=config.train_batch_size,
        num_workers=exp_config.num_workers,
        pin_memory=True,
        drop_last=True,
        shuffle=train_sampler is None,
        sampler=train_sampler,
    )

    loaders["valid"] = DataLoader(valid_ds,
                                  batch_size=config.valid_batch_size,
                                  num_workers=exp_config.num_workers,
                                  pin_memory=True)

    print("Stage            :", config.stage_name)
    print("  FP16 mode      :", config.fp16)
    print("  Fast mode      :", config.fast)
    print("  Epochs         :", config.epochs)
    print("  Workers        :", exp_config.num_workers)
    print("  Data dir       :", exp_config.data_dir)
    print("  Experiment dir :", experiment_dir)
    print("Data              ")
    print("  Augmentations  :", config.augmentations)
    print("  Obliterate (%) :", config.obliterate_p)
    print("  Negative images:", config.negative_image_dir)
    print("  Train size     :", len(loaders["train"]), "batches",
          len(train_ds), "samples")
    print("  Valid size     :", len(loaders["valid"]), "batches",
          len(valid_ds), "samples")
    print("  Image size     :", config.image_size)
    print("  Balance        :", config.balance)
    print("  Mixup          :", config.mixup)
    print("  CutMix         :", config.cutmix)
    print("  TSA            :", config.tsa)
    print("Model            :", exp_config.model_name)
    print("  Parameters     :", count_parameters(model))
    print("  Dropout        :", exp_config.dropout)
    print("Optimizer        :", config.optimizer)
    print("  Learning rate  :", config.learning_rate)
    print("  Weight decay   :", config.weight_decay)
    print("  Scheduler      :", config.schedule)
    print("  Batch sizes    :", config.train_batch_size,
          config.valid_batch_size)
    print("Losses            ")
    print("  Flag           :", config.modification_flag_loss)
    print("  Type           :", config.modification_type_loss)
    print("  Embedding      :", config.embedding_loss)
    print("  Feature maps   :", config.feature_maps_loss)

    optimizer = get_optimizer(
        config.optimizer,
        get_optimizable_parameters(model),
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    scheduler = get_scheduler(
        config.schedule,
        optimizer,
        lr=config.learning_rate,
        num_epochs=config.epochs,
        batches_in_epoch=len(loaders["train"]),
    )
    # Cyclic schedules are stepped per batch rather than per epoch.
    if isinstance(scheduler, CyclicLR):
        callbacks += [SchedulerCallback(mode="batch")]

    # model training
    runner = SupervisedRunner(input_key=model.required_features,
                              output_key=None)
    runner.train(
        fp16=config.fp16,
        model=model,
        criterion=criterions_dict,
        optimizer=optimizer,
        scheduler=scheduler,
        callbacks=callbacks,
        loaders=loaders,
        logdir=os.path.join(experiment_dir, config.stage_name),
        num_epochs=config.epochs,
        verbose=config.verbose,
        main_metric=config.main_metric,
        minimize_metric=config.main_metric_minimize,
        checkpoint_data={"config": config},
    )

    # Drop training objects before the cache flush below. The scheduler must
    # go too: it holds a reference to the optimizer, so deleting only the
    # optimizer name would keep the optimizer state alive.
    del optimizer, scheduler, loaders, callbacks, runner

    best_checkpoint = os.path.join(experiment_dir, config.stage_name,
                                   "checkpoints", "best.pth")
    model_checkpoint = os.path.join(experiment_dir,
                                    f"{exp_config.checkpoint_prefix}.pth")
    clean_checkpoint(best_checkpoint, model_checkpoint)

    # Restore state of best model
    if config.restore_best:
        unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

    # Some memory cleanup
    torch.cuda.empty_cache()
    gc.collect()
def main():
    """Train a stacking model on precomputed embeddings.

    Parses command-line arguments and loads the train/holdout embedding arrays
    ``embeddings_{x,y}_{train,holdout}_<suffix>.npy``, resolved relative to the
    current working directory. It then trains ``StackingModel`` with
    ``SupervisedRunner`` and writes the run (JSON config, logs, cleaned best
    checkpoint) to ``runs/<checkpoint_prefix>``.

    Only ``--modification-flag-loss`` and ``--modification-type-loss`` are
    forwarded to ``get_criterions``. Several other options are parsed (and some
    are printed) but are not used by this script.

    Raises:
        SystemExit: if neither of the two losses above is given.
        FileExistsError: if ``runs/<checkpoint_prefix>`` already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to accumulate gradients over"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--obliterate", type=float, default=0, help="Chance of obliteration")
    parser.add_argument("-nid", "--negative-image-dir", type=str, default=None, help="Directory with negative images")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-wbs", "--warmup-batch-size", type=int, default=None, help="Batch size during warmup")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument(
        "-es", "--early-stopping", type=int, default=None, help="Maximum number of epochs without improvement"
    )
    parser.add_argument("-fe", "--freeze-encoder", action="store_true", help="Freeze encoder parameters")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")

    parser.add_argument(
        "-l", "--modification-flag-loss", type=str, default=None, action="append", nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument(
        "--modification-type-loss", type=str, default=None, action="append", nargs="+"  # [["ce", 1.0]],
    )
    parser.add_argument("--embedding-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--feature-maps-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--mask-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],
    parser.add_argument("--bits-loss", type=str, default=None, action="append", nargs="+")  # [["ce", 1.0]],

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--tsa", action="store_true")
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("-s", "--scheduler", default=None, type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=0, type=float, help="Dropout before head layer")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument("--fine-tune", default=0, type=int, help="Number of fine-tuning epochs")
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    # Only the flag/type losses are forwarded to get_criterions below (the
    # embedding loss is passed as None), so at least one of them is required.
    # parser.error is used instead of `assert` so the check survives `python -O`.
    if not (args.modification_flag_loss or args.modification_type_loss):
        parser.error("At least one of --modification-flag-loss or --modification-type-loss must be set")

    modification_flag_loss = args.modification_flag_loss
    modification_type_loss = args.modification_type_loss
    embedding_loss = args.embedding_loss
    feature_maps_loss = args.feature_maps_loss
    mask_loss = args.mask_loss
    bits_loss = args.bits_loss

    data_dir = args.data_dir
    cache = args.cache
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    optimizer_name = args.optimizer
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    verbose = args.verbose
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    balance = args.balance
    freeze_bn = args.freeze_bn
    train_batch_size = args.batch_size
    mixup = args.mixup
    cutmix = args.cutmix
    tsa = args.tsa
    obliterate_p = args.obliterate
    negative_image_dir = args.negative_image_dir

    # Compute batch size for validation
    valid_batch_size = train_batch_size

    current_time = datetime.now().strftime("%b%d_%H_%M")

    main_metric = "loss"
    main_metric_minimize = True

    # Suffix shared by every precomputed embedding file loaded below.
    embeddings_suffix = "Gf3_Hnrmishf2_Hnrmishf1_Kmishf0"
    x_train = np.load(f"embeddings_x_train_{embeddings_suffix}.npy")
    y_train = np.load(f"embeddings_y_train_{embeddings_suffix}.npy")

    x_valid = np.load(f"embeddings_x_holdout_{embeddings_suffix}.npy")
    y_valid = np.load(f"embeddings_y_holdout_{embeddings_suffix}.npy")

    print(x_train.shape, x_valid.shape)
    print(np.bincount(y_train), np.bincount(y_valid))

    train_ds = StackerDataset(x_train, y_train)
    valid_ds = StackerDataset(x_valid, y_valid)

    # --cutmix is deliberately not forwarded here; in this script it only
    # affects the checkpoint prefix and the printed summary.
    criterions_dict, loss_callbacks = get_criterions(
        modification_flag=modification_flag_loss,
        modification_type=modification_type_loss,
        embedding_loss=None,
        feature_maps_loss=None,
        mask_loss=None,
        bits_loss=None,
        num_epochs=num_epochs,
        mixup=mixup,
        cutmix=None,
        tsa=tsa,
    )

    callbacks = loss_callbacks + [
        OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
        HyperParametersCallback(
            hparam_dict={
                "scheduler": scheduler_name,
                "optimizer": optimizer_name,
                "augmentations": augmentations,
                "weight_decay": weight_decay,
            }
        ),
    ]

    loaders = collections.OrderedDict()
    loaders["train"] = DataLoader(
        train_ds, batch_size=train_batch_size, num_workers=num_workers, pin_memory=True, drop_last=True, shuffle=True
    )

    loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)

    # Input width of the stacker equals the embedding dimension.
    model: nn.Module = StackingModel(x_train.shape[1], dropout=dropout).cuda()

    optimizer = get_optimizer(
        optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
    )
    scheduler = get_scheduler(
        scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
    )
    # Cyclic schedules are stepped per batch rather than per epoch.
    if isinstance(scheduler, CyclicLR):
        callbacks += [SchedulerCallback(mode="batch")]

    checkpoint_prefix = f"{current_time}_stacking"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if mixup:
        checkpoint_prefix += "_mixup"

    if cutmix:
        checkpoint_prefix += "_cutmix"

    # An explicit experiment name replaces the auto-generated prefix.
    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    # exist_ok=False: refuse to overwrite a previous run with the same name.
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    print("Train session    :", checkpoint_prefix)
    print("  Train size     :", len(loaders["train"]), "batches", len(train_ds), "samples")
    print("  Valid size     :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
    print("  FP16 mode      :", fp16)
    print("  Fast mode      :", args.fast)
    print("  Epochs         :", num_epochs)
    print("  Workers        :", num_workers)
    print("  Data dir       :", data_dir)
    print("  Log dir        :", log_dir)
    print("  Cache          :", cache)
    print("Data              ")
    print("  Augmentations  :", augmentations)
    print("  Obliterate (%) :", obliterate_p)
    print("  Negative images:", negative_image_dir)
    print("  Balance        :", balance)
    print("  Mixup          :", mixup)
    print("  CutMix         :", cutmix)
    print("  TSA            :", tsa)
    # print("Model            :", model_name)
    print("  Parameters     :", count_parameters(model))
    print("  Dropout        :", dropout)
    print("Optimizer        :", optimizer_name)
    print("  Learning rate  :", learning_rate)
    print("  Weight decay   :", weight_decay)
    print("  Scheduler      :", scheduler_name)
    print("  Batch sizes    :", train_batch_size, valid_batch_size)
    print("Losses            ")
    print("  Flag           :", modification_flag_loss)
    print("  Type           :", modification_type_loss)
    print("  Embedding      :", embedding_loss)
    print("  Feature maps   :", feature_maps_loss)
    print("  Mask           :", mask_loss)
    print("  Bits           :", bits_loss)

    # model training
    runner = SupervisedRunner(input_key=[INPUT_EMBEDDING_KEY], output_key=None)
    runner.train(
        fp16=fp16,
        model=model,
        criterion=criterions_dict,
        optimizer=optimizer,
        scheduler=scheduler,
        callbacks=callbacks,
        loaders=loaders,
        logdir=os.path.join(log_dir, "main"),
        num_epochs=num_epochs,
        verbose=verbose,
        main_metric=main_metric,
        minimize_metric=main_metric_minimize,
        checkpoint_data={"cmd_args": vars(args)},
    )

    del optimizer, loaders, runner, callbacks

    best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
    model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")

    # Write the best checkpoint of this run to a standalone model file via
    # clean_checkpoint; the in-memory model is not reloaded here.
    clean_checkpoint(best_checkpoint, model_checkpoint)
Пример #6
0
def _run(
    model: nn.Module,
    prefix: str,
    data_dir: str,
    fold: int,
    epochs: int,
    batch_size: int,
    optimizer_name: str,
    augmentations="light",
    learning_rate=1e-4,
    weight_decay=0,
    fast=False,
):
    """Train and validate ``model`` on an XLA device, one call per replica.

    Each epoch trains on this replica's shard of the training set. It then
    computes the weighted AUC (``alaska_weighted_auc``) of the validation
    predictions gathered across all replicas. ``ReduceLROnPlateau`` halves the
    learning rate when wAUC stops improving (patience 5, min LR 1e-6). Weights
    are saved with ``xm.save`` to ``{prefix}_last.pth`` after every epoch and
    to ``{prefix}_best.pth`` whenever wAUC improves.

    Args:
        model: Network to train; must expose ``required_features`` and return a
            dict containing ``OUTPUT_PRED_MODIFICATION_TYPE``.
        prefix: Path prefix for the saved checkpoints.
        data_dir: Dataset root, forwarded to ``get_datasets``.
        fold: Validation fold, forwarded to ``get_datasets``.
        epochs: Number of training epochs.
        batch_size: Per-replica batch size for both loaders.
        optimizer_name: Optimizer name, forwarded to ``get_optimizer``.
        augmentations: Augmentation level, forwarded to ``get_datasets``.
        learning_rate: Initial learning rate.
        weight_decay: Weight decay passed to the optimizer.
        fast: Forwarded to ``get_datasets``.
    """

    def train_fn(epoch, train_dataloader, optimizer, criterion, scheduler,
                 device):
        """Run one training epoch; ``scheduler`` (if any) is stepped per batch."""
        model.train()

        for batch_idx, batch_data in enumerate(train_dataloader):
            optimizer.zero_grad()

            batch_data = any2device(batch_data, device)
            outputs = model(**batch_data)

            y_pred = outputs[OUTPUT_PRED_MODIFICATION_TYPE]
            y_true = batch_data[INPUT_TRUE_MODIFICATION_TYPE]

            loss = criterion(y_pred, y_true)

            # Log only every 100th batch: loss.item() forces a device-to-host
            # transfer, so doing it for every batch is costly.
            if batch_idx % 100 == 0:
                xm.master_print(f"Batch: {batch_idx}, loss: {loss.item()}")

            loss.backward()
            # xm.optimizer_step reduces gradients across replicas before stepping.
            xm.optimizer_step(optimizer)

            if scheduler is not None:
                scheduler.step()

    def valid_fn(epoch, valid_dataloader, criterion, device):
        """Evaluate on ``valid_dataloader`` and return the global wAUC."""
        model.eval()

        pred_scores = []
        true_scores = []

        for batch_idx, batch_data in enumerate(valid_dataloader):
            batch_data = any2device(batch_data, device)
            outputs = model(**batch_data)

            y_pred = outputs[OUTPUT_PRED_MODIFICATION_TYPE]
            y_true = batch_data[INPUT_TRUE_MODIFICATION_TYPE]

            loss = criterion(y_pred, y_true)

            pred_scores.extend(to_numpy(parse_classifier_probas(y_pred)))
            true_scores.extend(to_numpy(y_true))

            xm.master_print(f"Batch: {batch_idx}, loss: {loss.item()}")

        # Gather every replica's shard so the metric covers the whole split.
        val_wauc = alaska_weighted_auc(xla_all_gather(true_scores, device),
                                       xla_all_gather(pred_scores, device))
        xm.master_print(f"Valid epoch: {epoch}, wAUC: {val_wauc}")
        return val_wauc

    train_dataset, valid_dataset, _ = get_datasets(
        data_dir,
        fold=fold,
        fast=fast,
        augmentation=augmentations,
        features=model.required_features)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  sampler=train_sampler,
                                  num_workers=1)
    valid_dataloader = DataLoader(valid_dataset,
                                  batch_size=batch_size,
                                  sampler=valid_sampler,
                                  num_workers=1,
                                  drop_last=False)

    device = xm.xla_device()
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()

    optimizer = get_optimizer(optimizer_name,
                              get_optimizable_parameters(model),
                              learning_rate=learning_rate,
                              weight_decay=weight_decay)
    num_train_steps = int(
        len(train_dataset) / batch_size / xm.xrt_world_size() * epochs)
    xm.master_print(
        f"num_train_steps = {num_train_steps}, world_size={xm.xrt_world_size()}"
    )

    lr_scheduler = ReduceLROnPlateau(optimizer,
                                     mode="max",
                                     factor=0.5,
                                     patience=5,
                                     verbose=True,
                                     min_lr=1e-6)

    best_wauc = 0

    train_begin = time.time()
    for epoch in range(epochs):
        # DistributedSampler derives its shuffle order from the epoch number;
        # without set_epoch every epoch would iterate in the same order.
        train_sampler.set_epoch(epoch)
        para_loader = pl.ParallelLoader(train_dataloader, [device])

        start = time.time()
        # master_print (as used everywhere else here) avoids one copy of each
        # banner per replica.
        xm.master_print("*" * 15)
        xm.master_print(f"EPOCH: {epoch + 1}")
        xm.master_print("*" * 15)

        xm.master_print("Training.....")
        train_fn(
            epoch=epoch + 1,
            train_dataloader=para_loader.per_device_loader(device),
            optimizer=optimizer,
            criterion=criterion,
            scheduler=None,
            device=device,
        )

        with torch.no_grad():
            para_loader = pl.ParallelLoader(valid_dataloader, [device])

            xm.master_print("Validating....")
            val_wauc = valid_fn(
                epoch=epoch + 1,
                valid_dataloader=para_loader.per_device_loader(device),
                criterion=criterion,
                device=device,
            )

            # Plateau schedulers are stepped once per epoch on the metric.
            if isinstance(lr_scheduler, ReduceLROnPlateau):
                lr_scheduler.step(val_wauc)

            xm.save(model.state_dict(), f"{prefix}_last.pth")
            if val_wauc > best_wauc:
                best_wauc = val_wauc
                xm.save(model.state_dict(), f"{prefix}_best.pth")
                xm.master_print(f"Saved best checkpoint with wAUC {best_wauc}")

        xm.master_print(f"Epoch completed in {(time.time() - start) / 60} minutes")
    xm.master_print(f"Training completed in {(time.time() - train_begin) / 60} minutes")