Esempio n. 1
0
def run_whole_training(experiment_name: str,
                       exp_config: ExperimenetConfig,
                       runs_dir="runs"):
    """Create a model, optionally restore weights, and train all stages.

    Args:
        experiment_name: Name of the sub-directory created under ``runs_dir``.
        exp_config: Experiment configuration; this function reads its
            ``model_name``, ``dropout``, ``transfer_from_checkpoint``,
            ``resume_from_checkpoint`` and ``stages`` attributes.
        runs_dir: Parent directory for experiment outputs.

    Raises:
        FileExistsError: If the experiment directory already exists
            (``os.makedirs`` is called with ``exist_ok=False``).
    """
    network = get_model(exp_config.model_name, dropout=exp_config.dropout)
    network = network.cuda()

    # Optional step 1: copy weights from another checkpoint via transfer_weights().
    if exp_config.transfer_from_checkpoint:
        transfer_path = fs.auto_file(exp_config.transfer_from_checkpoint)
        print("Transferring weights from model checkpoint", transfer_path)
        source_state = load_checkpoint(transfer_path)["model_state_dict"]
        transfer_weights(network, source_state)

    # Optional step 2: restore the model from a checkpoint to resume from.
    if exp_config.resume_from_checkpoint:
        resume_path = fs.auto_file(exp_config.resume_from_checkpoint)
        resumed = load_checkpoint(resume_path)
        unpack_checkpoint(resumed, model=network)

        print("Loaded model weights from:", exp_config.resume_from_checkpoint)
        report_checkpoint(resumed)

    # The directory is created only after the weights loaded successfully.
    output_dir = os.path.join(runs_dir, experiment_name)
    os.makedirs(output_dir, exist_ok=False)

    # NOTE(review): jsonpickle.encode() already returns a JSON string, so
    # json.dumps() wraps it into a single quoted JSON string literal and
    # ``indent`` has no visible effect. Kept as-is to preserve the file format.
    with open(os.path.join(output_dir, "config.json"), "w") as config_file:
        encoded_config = jsonpickle.encode(exp_config)
        config_file.write(json.dumps(encoded_config, indent=2))

    # Stages run sequentially on the same model instance.
    for training_stage in exp_config.stages:
        run_stage_training(network,
                           training_stage,
                           exp_config,
                           experiment_dir=output_dir)
def evaluate(checkpoints, fast=False):
    """Print per-TTA mean/std of the out-of-fold quadratic kappa score.

    For each TTA mode (``None``, ``'flip'``, ``'d4'``) every fold's checkpoint
    is run on that fold's validation dataset and scored against
    ``valid_ds.targets`` with quadratic-weighted Cohen's kappa.

    Args:
        checkpoints: Indexable collection of checkpoint names, one per fold;
            ``checkpoints[i]`` is used for fold ``i``.
        fast: When true, only the aptos2019 dataset is requested
            (aptos2015, messidor and idrid are disabled).
    """
    num_folds = len(checkpoints)

    for tta in [None, 'flip', 'd4']:
        oof_scores = []

        # Iterate over as many folds as there are checkpoints. The previous
        # hard-coded range(4) disagreed with ``folds=num_folds`` below and
        # either raised IndexError or skipped checkpoints.
        for fold in range(num_folds):
            _, valid_ds = get_datasets(use_aptos2015=not fast,
                                       use_aptos2019=True,
                                       use_messidor=not fast,
                                       use_idrid=not fast,
                                       fold=fold,
                                       folds=num_folds)

            checkpoint_file = fs.auto_file(checkpoints[fold])
            p = run_model_inference_via_dataset(
                model_checkpoint=checkpoint_file,
                dataset=valid_ds,
                tta=tta,
                batch_size=16,
                apply_softmax=True,
                workers=6)

            oof_diagnosis = cls_predictions_to_submission(
                p)['diagnosis'].values
            oof_score = cohen_kappa_score(oof_diagnosis,
                                          valid_ds.targets,
                                          weights='quadratic')
            oof_scores.append(oof_score)

        print(tta, np.mean(oof_scores), np.std(oof_scores))
def weighted_model(checkpoint_fname: str, weights, activation: str):
    """Load a model from a checkpoint and wrap it in ``ApplyWeights``.

    The checkpoint path is resolved with
    ``fs.auto_file(checkpoint_fname, where="ensemble")``.

    Args:
        checkpoint_fname: Checkpoint file name to resolve and load.
        weights: Passed as-is to ``ApplyWeights``.
        activation: Forwarded to ``model_from_checkpoint`` as
            ``activation_after``.

    Returns:
        Tuple ``(model, info)``: the wrapped model and the second value
        returned by ``model_from_checkpoint``.
    """
    resolved_path = fs.auto_file(checkpoint_fname, where="ensemble")
    base_model, info = model_from_checkpoint(
        resolved_path,
        activation_after=activation,
        report=False,
        classifiers=False,
    )
    return ApplyWeights(base_model, weights), info
Esempio n. 4
0
def ela_wider_resnet38(num_classes=4, pretrained=True, dropout=0):
    """Build a ``TimmRgbElaModel`` on top of a 6-channel wider_resnet38 encoder.

    Args:
        num_classes: Forwarded to ``TimmRgbElaModel``.
        pretrained: When true, weights from
            ``wide_resnet38_ipabn_lr_256.pth.tar`` are transferred into the
            encoder via ``transfer_weights``.
        dropout: Forwarded to ``TimmRgbElaModel``.

    Returns:
        The constructed ``TimmRgbElaModel``.
    """
    # Imported lazily, inside the function, as in the original code.
    from alaska2.models.backbones.wider_resnet import wider_resnet38

    backbone = wider_resnet38(in_chans=6)

    if pretrained:
        weights_path = fs.auto_file("wide_resnet38_ipabn_lr_256.pth.tar")
        # Load on CPU regardless of the device the file was saved from.
        saved = torch.load(weights_path, map_location="cpu")
        transfer_weights(backbone, saved["state_dict"])
        print("Loaded weights from Mapilary")

    net = TimmRgbElaModel(backbone, num_classes=num_classes, dropout=dropout)
    return net
def evaluate_averaging(checkpoints, use_aptos2015, use_aptos2019, use_messidor,
                       use_idrid):
    """Score each checkpoint and their mean-averaged predictions on train+valid.

    Every checkpoint is run over the concatenation of the train and validation
    datasets. The quadratic-weighted Cohen's kappa is computed per checkpoint
    and for the ``'mean'`` average of all predictions, then printed.

    Args:
        checkpoints: Indexable collection of checkpoint names; every entry is
            evaluated.
        use_aptos2015, use_aptos2019, use_messidor, use_idrid: Dataset
            selection flags forwarded to ``get_datasets``; they also build the
            label printed with the scores.
    """
    # Dataset loading and the label do not depend on the TTA mode, so they are
    # computed once up front.
    train_ds, valid_ds, _ = get_datasets(use_aptos2015=use_aptos2015,
                                         use_aptos2019=use_aptos2019,
                                         use_messidor=use_messidor,
                                         use_idrid=use_idrid)
    full_ds = train_ds + valid_ds
    targets = np.concatenate((train_ds.targets, valid_ds.targets))
    assert len(targets) == len(full_ds)

    name = ''
    if use_aptos2015:
        name += 'aptos2015'
    if use_aptos2019:
        name += 'aptos2019'
    if use_messidor:
        name += 'messidor'
    if use_idrid:
        name += 'idrid'

    for tta in [None]:
        predictions = []
        scores = []

        # Evaluate every supplied checkpoint. The previous hard-coded range(4)
        # raised IndexError for fewer than four checkpoints and ignored extras.
        for fold in range(len(checkpoints)):
            checkpoint_file = fs.auto_file(checkpoints[fold])
            p = run_model_inference_via_dataset(
                model_checkpoint=checkpoint_file,
                dataset=full_ds,
                tta=tta,
                batch_size=16,
                apply_softmax=False,
                workers=6)

            oof_diagnosis = reg_predictions_to_submission(
                p)['diagnosis'].values
            oof_score = cohen_kappa_score(oof_diagnosis,
                                          targets,
                                          weights='quadratic')
            scores.append(oof_score)
            predictions.append(p)

        mean_predictions = average_predictions(predictions, 'mean')
        mean_diagnosis = reg_predictions_to_submission(
            mean_predictions)['diagnosis'].values
        mean_score = cohen_kappa_score(mean_diagnosis,
                                       targets,
                                       weights='quadratic')

        print(name, 'TTA', tta, 'Mean', np.mean(scores), 'std', np.std(scores),
              'MeanAvg', mean_score)
Esempio n. 6
0
def test_crop_black_regions():
    """Visual check: show a sample image next to its ``crop_black`` result."""
    for image_fname in [
            # Other sample images that were used with this check:
            # 'data/train_images/0a4e1a29ffff.png',
            # 'data/train_images/0a61bddab956.png',
            # 'tests/19150_right.jpeg',
            fs.auto_file('3704_left.jpeg', where='..')
    ]:
        image = cv2.imread(image_fname)
        # cv2.imread returns None instead of raising when the file cannot be
        # read; fail here with a clear message rather than inside cvtColor.
        assert image is not None, f'Failed to read image: {image_fname}'

        # imread yields BGR channel order; matplotlib's imshow expects RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        cropped = crop_black(image)

        f, ax = plt.subplots(1, 2)
        ax[0].imshow(image)
        ax[1].imshow(cropped)
        f.show()
Esempio n. 7
0
def main():
    """Assign round-robin fold indices to the training set and save it.

    Reads ``data/train.csv``, inner-merges it on ``id_code`` with the CSV given
    by ``-df``, sorts the rows by ``is_test``, assigns fold numbers
    ``0, 1, ..., num_folds - 1, 0, 1, ...`` in that order and writes the result
    to ``data/train_with_folds.csv`` without the index column.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-df', type=str)
    # type=int is required: without it a value given on the command line
    # arrives as a string and np.arange() fails.
    parser.add_argument('-nf', '--num-folds', type=int, default=4)
    args = parser.parse_args()

    num_folds = args.num_folds
    if num_folds < 1:
        parser.error('--num-folds must be a positive integer')

    train_set = pd.read_csv('data/train.csv')
    df = pd.read_csv(auto_file(args.df))
    train_set = train_set.merge(df, on='id_code')
    train_set = train_set.sort_values(by='is_test')

    # Cyclic fold ids by row position; same sequence as repeating
    # [0..num_folds-1] and truncating, without the oversized temporary list.
    train_set['fold'] = np.arange(len(train_set)) % num_folds

    train_set.to_csv('data/train_with_folds.csv', index=False)
def test_inference_pd():
    """Run inference on the APTOS-2019 test CSV and print the submission frame.

    For each listed checkpoint the ``diagnosis`` column of the predictions is
    mapped through ``regression_to_class`` and then ``int`` before the
    ``id_code`` / ``diagnosis`` columns are printed.
    """
    checkpoints = [
        'runs/Aug09_16_47/seresnext50_gap/512/medium/hopeful_easley/fold0/checkpoints/seresnext50_gap_512_medium_aptos2019_idrid_fold0_hopeful_easley.pth',
    ]
    test_csv = pd.read_csv('data/aptos-2019/test.csv')

    for checkpoint_name in checkpoints:
        inference_kwargs = dict(
            model_checkpoint=fs.auto_file(checkpoint_name),
            test_csv=test_csv,
            images_dir='test_images_768',
            data_dir='data/aptos-2019',
            apply_softmax=True,
            need_features=False,
            batch_size=8,
        )
        predictions = run_model_inference(**inference_kwargs)

        # Two-step conversion: mapped value first, then a plain int.
        as_classes = predictions['diagnosis'].apply(regression_to_class)
        predictions['diagnosis'] = as_classes.apply(int)

        print(predictions[['id_code', 'diagnosis']])
def main():
    """Predict masks for the test images with a trained checkpoint.

    Steps:
      1. Parse CLI arguments and resolve the checkpoint with ``auto_file``.
      2. Load the checkpoint, print its epoch and train/valid metrics, and
         restore the weights into a model built by ``get_model``.
      3. Wrap the model so it outputs sigmoid of ``OUTPUT_MASK_KEY``, optionally
         add test-time augmentation (``fliplr``, ``flipscale`` or ``d4``), and
         move it to CUDA (with ``DataParallel`` when several GPUs are visible).
      4. Predict ``sample_color.jpg`` and every image found under
         ``<data-dir>/test/images``, binarise at a 0.5 threshold, write the
         masks, and write a ``gdal_translate``-compressed copy of each test
         mask.

    Outputs go under ``<run_dir>/submit`` where ``run_dir`` is two directory
    levels above the checkpoint file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="unet", help="")
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        default=None,
                        required=True,
                        help="Data dir")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        required=True,
        help="Checkpoint filename to use as initial model weights",
    )
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=16,
                        help="Batch size for inference")
    parser.add_argument("-tta",
                        "--tta",
                        default=None,
                        type=str,
                        help="Type of TTA to use [fliplr, d4]")
    args = parser.parse_args()

    data_dir = args.data_dir
    checkpoint_file = auto_file(args.checkpoint)
    # run_dir is the grandparent directory of the checkpoint file.
    run_dir = os.path.dirname(os.path.dirname(checkpoint_file))
    out_dir = os.path.join(run_dir, "submit")
    os.makedirs(out_dir, exist_ok=True)

    checkpoint = load_checkpoint(checkpoint_file)
    checkpoint_epoch = checkpoint["epoch"]
    print("Loaded model weights from", args.checkpoint)
    print("Epoch   :", checkpoint_epoch)
    print(
        "Metrics (Train):",
        "IoU:",
        checkpoint["epoch_metrics"]["train"]["jaccard"],
        "Acc:",
        checkpoint["epoch_metrics"]["train"]["accuracy"],
    )
    print(
        "Metrics (Valid):",
        "IoU:",
        checkpoint["epoch_metrics"]["valid"]["jaccard"],
        "Acc:",
        checkpoint["epoch_metrics"]["valid"]["accuracy"],
    )

    model = get_model(args.model)
    unpack_checkpoint(checkpoint, model=model)
    # Fixed binarisation threshold (not read from the checkpoint metrics).
    threshold = 0.5
    print("Using threshold", threshold)

    model = nn.Sequential(PickModelOutput(model, OUTPUT_MASK_KEY),
                          nn.Sigmoid())

    # Any other --tta value (including None) leaves the model unwrapped.
    if args.tta == "fliplr":
        model = TTAWrapper(model, fliplr_image2mask)
    elif args.tta == "flipscale":
        model = TTAWrapper(model, fliplr_image2mask)
        model = MultiscaleTTAWrapper(model, size_offsets=[-128, -64, 64, 128])
    elif args.tta == "d4":
        model = TTAWrapper(model, d4_image2mask)

    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    model = model.eval()

    # Single sample prediction, written next to the run directory.
    mask = predict(model,
                   read_inria_image("sample_color.jpg"),
                   image_size=(512, 512),
                   batch_size=args.batch_size)
    mask = ((mask > threshold) * 255).astype(np.uint8)
    name = os.path.join(run_dir, "sample_color.jpg")
    cv2.imwrite(name, mask)

    test_predictions_dir = os.path.join(out_dir, "test_predictions")
    test_predictions_dir_compressed = os.path.join(
        out_dir, "test_predictions_compressed")

    # Keep outputs of different TTA modes in separate directories.
    if args.tta is not None:
        test_predictions_dir += f"_{args.tta}"
        test_predictions_dir_compressed += f"_{args.tta}"

    os.makedirs(test_predictions_dir, exist_ok=True)
    os.makedirs(test_predictions_dir_compressed, exist_ok=True)

    test_images = find_in_dir(os.path.join(data_dir, "test", "images"))
    for fname in tqdm(test_images, total=len(test_images)):
        image = read_inria_image(fname)
        mask = predict(model,
                       image,
                       image_size=(512, 512),
                       batch_size=args.batch_size)
        mask = ((mask > threshold) * 255).astype(np.uint8)
        name = os.path.join(test_predictions_dir, os.path.basename(fname))
        cv2.imwrite(name, mask)

        name_compressed = os.path.join(test_predictions_dir_compressed,
                                       os.path.basename(fname))
        # Pass an argument list instead of a shell string so paths containing
        # spaces or shell metacharacters are handed to gdal_translate intact.
        command = [
            "gdal_translate",
            "--config", "GDAL_PAM_ENABLED", "NO",
            "-co", "COMPRESS=CCITTFAX4",
            "-co", "NBITS=1",
            name,
            name_compressed,
        ]
        # The return code is intentionally not checked, as before.
        subprocess.call(command)
def main():
    """Command-line entry point: train a classification model fold by fold.

    For every requested fold (or once with ``fold=None`` when no ``--fold`` is
    given) the function:

      1. builds a checkpoint prefix / log directory under ``runs/<time>/`` and
         dumps the parsed arguments there as JSON;
      2. seeds the RNGs, builds the model, and optionally transfers weights
         (``--transfer``) and/or restores a checkpoint (``--checkpoint``);
      3. builds datasets and loaders, prints the session summary;
      4. assembles criteria and callbacks;
      5. runs up to three stages with ``SupervisedRunner.train``: warmup
         (``--warmup`` epochs), main (``--epochs``) and fine-tune
         (``--fine-tune`` epochs), cleaning the best checkpoint of the main and
         fine-tune stages with ``clean_checkpoint``.

    Raises:
        AssertionError: If none of the ``--use-*`` dataset flags is set.
        FileExistsError: If the log directory already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--balance', action='store_true')
    parser.add_argument('--balance-datasets', action='store_true')
    parser.add_argument('--swa', action='store_true')
    parser.add_argument('--show', action='store_true')
    parser.add_argument('--use-idrid', action='store_true')
    parser.add_argument('--use-messidor', action='store_true')
    parser.add_argument('--use-aptos2015', action='store_true')
    parser.add_argument('--use-aptos2019', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('--coarse', action='store_true')
    parser.add_argument('-acc', '--accumulation-steps', type=int, default=1,
                        help='Number of batches to process')
    parser.add_argument('-dd', '--data-dir', type=str, default='data',
                        help='Data directory')
    parser.add_argument('-m', '--model', type=str, default='resnet18_gap',
                        help='')
    parser.add_argument('-b', '--batch-size', type=int, default=8,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e', '--epochs', type=int, default=100,
                        help='Epoch to run')
    parser.add_argument('-es', '--early-stopping', type=int, default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-f', '--fold', action='append', type=int,
                        default=None)
    parser.add_argument('-ft', '--fine-tune', default=0, type=int)
    parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('--criterion-reg', type=str, default=None, nargs='+',
                        help='Criterion')
    parser.add_argument('--criterion-ord', type=str, default=None, nargs='+',
                        help='Criterion')
    parser.add_argument('--criterion-cls', type=str, default=['ce'],
                        nargs='+', help='Criterion')
    parser.add_argument('-l1', type=float, default=0,
                        help='L1 regularization loss')
    parser.add_argument('-l2', type=float, default=0,
                        help='L2 regularization loss')
    parser.add_argument('-o', '--optimizer', default='Adam',
                        help='Name of the optimizer')
    parser.add_argument('-p', '--preprocessing', default=None,
                        help='Preprocessing method')
    parser.add_argument(
        '-c', '--checkpoint', type=str, default=None,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w', '--workers',
                        default=multiprocessing.cpu_count(), type=int,
                        help='Num workers')
    parser.add_argument('-a', '--augmentations', default='medium', type=str,
                        help='')
    parser.add_argument('-tta', '--tta', default=None, type=str,
                        help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-t', '--transfer', default=None, type=str, help='')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('-s', '--scheduler', default='multistep', type=str,
                        help='')
    parser.add_argument('--size', default=512, type=int,
                        help='Image size for training & inference')
    parser.add_argument('-wd', '--weight-decay', default=0, type=float,
                        help='L2 weight decay')
    parser.add_argument('-wds', '--weight-decay-step', default=None,
                        type=float,
                        help='L2 weight decay step to add after each epoch')
    parser.add_argument('-d', '--dropout', default=0.0, type=float,
                        help='Dropout before head layer')
    parser.add_argument(
        '--warmup', default=0, type=int,
        help=
        'Number of warmup epochs with 0.1 of the initial LR and frozed encoder'
    )
    # The help text used to be a copy of the --dropout help.
    parser.add_argument(
        '-x', '--experiment', default=None, type=str,
        help='Experiment name; replaces the generated checkpoint prefix')

    args = parser.parse_args()

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    l1 = args.l1
    l2 = args.l2
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (args.size, args.size)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    fine_tune = args.fine_tune
    criterion_reg_name = args.criterion_reg
    criterion_cls_name = args.criterion_cls
    criterion_ord_name = args.criterion_ord
    folds = args.fold
    mixup = args.mixup
    balance = args.balance
    balance_datasets = args.balance_datasets
    use_swa = args.swa
    show_batches = args.show
    scheduler_name = args.scheduler
    verbose = args.verbose
    weight_decay = args.weight_decay
    use_idrid = args.use_idrid
    use_messidor = args.use_messidor
    use_aptos2015 = args.use_aptos2015
    use_aptos2019 = args.use_aptos2019
    warmup = args.warmup
    dropout = args.dropout
    experiment = args.experiment
    preprocessing = args.preprocessing
    weight_decay_step = args.weight_decay_step
    coarse_grading = args.coarse
    class_names = get_class_names(coarse_grading)

    # At least one dataset must be selected.
    assert use_aptos2015 or use_aptos2019 or use_idrid or use_messidor

    # Timestamp and random name are computed once and shared by all folds.
    current_time = datetime.now().strftime('%b%d_%H_%M')
    random_name = get_random_name()

    # No --fold given: run a single session with fold=None.
    if folds is None or len(folds) == 0:
        folds = [None]

    for fold in folds:
        torch.cuda.empty_cache()
        checkpoint_prefix = f'{model_name}_{args.size}_{augmentations}'

        if preprocessing is not None:
            checkpoint_prefix += f'_{preprocessing}'
        if use_aptos2019:
            checkpoint_prefix += '_aptos2019'
        if use_aptos2015:
            checkpoint_prefix += '_aptos2015'
        if use_messidor:
            checkpoint_prefix += '_messidor'
        if use_idrid:
            checkpoint_prefix += '_idrid'
        if coarse_grading:
            checkpoint_prefix += '_coarse'

        if fold is not None:
            checkpoint_prefix += f'_fold{fold}'

        checkpoint_prefix += f'_{random_name}'

        # NOTE(review): an explicit experiment name replaces the whole prefix,
        # including the fold suffix, so with several --fold values every fold
        # maps to the same log_dir and the second os.makedirs() below raises
        # FileExistsError. Confirm whether this is intended.
        if experiment is not None:
            checkpoint_prefix = experiment

        directory_prefix = f'{current_time}/{checkpoint_prefix}'
        log_dir = os.path.join('runs', directory_prefix)
        os.makedirs(log_dir, exist_ok=False)

        # Persist the parsed command-line arguments next to the logs.
        config_fname = os.path.join(log_dir, f'{checkpoint_prefix}.json')
        with open(config_fname, 'w') as f:
            train_session_args = vars(args)
            f.write(json.dumps(train_session_args, indent=2))

        set_manual_seed(args.seed)
        num_classes = len(class_names)
        model = get_model(model_name, num_classes=num_classes,
                          dropout=dropout).cuda()

        if args.transfer:
            transfer_checkpoint = fs.auto_file(args.transfer)
            print("Transfering weights from model checkpoint",
                  transfer_checkpoint)
            checkpoint = load_checkpoint(transfer_checkpoint)
            pretrained_dict = checkpoint['model_state_dict']

            # Best-effort transfer: load one tensor at a time so that an entry
            # which fails to load is reported and skipped instead of aborting
            # the whole transfer.
            for name, value in pretrained_dict.items():
                try:
                    model.load_state_dict(
                        collections.OrderedDict([(name, value)]),
                        strict=False)
                except Exception as e:
                    print(e)

            report_checkpoint(checkpoint)

        if args.checkpoint:
            checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        train_ds, valid_ds, train_sizes = get_datasets(
            data_dir=data_dir,
            use_aptos2019=use_aptos2019,
            use_aptos2015=use_aptos2015,
            use_idrid=use_idrid,
            use_messidor=use_messidor,
            use_unsupervised=False,
            coarse_grading=coarse_grading,
            image_size=image_size,
            augmentation=augmentations,
            preprocessing=preprocessing,
            target_dtype=int,
            fold=fold,
            folds=4)

        train_loader, valid_loader = get_dataloaders(
            train_ds,
            valid_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            train_sizes=train_sizes,
            balance=balance,
            balance_datasets=balance_datasets,
            balance_unlabeled=False)

        loaders = collections.OrderedDict()
        loaders["train"] = train_loader
        loaders["valid"] = valid_loader

        print('Datasets         :', data_dir)
        print('  Train size     :', len(train_loader),
              len(train_loader.dataset))
        print('  Valid size     :', len(valid_loader),
              len(valid_loader.dataset))
        print('  Aptos 2019     :', use_aptos2019)
        print('  Aptos 2015     :', use_aptos2015)
        print('  IDRID          :', use_idrid)
        print('  Messidor       :', use_messidor)
        print('Train session    :', directory_prefix)
        print('  FP16 mode      :', fp16)
        print('  Fast mode      :', fast)
        print('  Mixup          :', mixup)
        print('  Balance cls.   :', balance)
        print('  Balance ds.    :', balance_datasets)
        print('  Warmup epoch   :', warmup)
        print('  Train epochs   :', num_epochs)
        print('  Fine-tune ephs :', fine_tune)
        print('  Workers        :', num_workers)
        print('  Fold           :', fold)
        print('  Log dir        :', log_dir)
        print('  Augmentations  :', augmentations)
        print('Model            :', model_name)
        print('  Parameters     :', count_parameters(model))
        print('  Image size     :', image_size)
        print('  Dropout        :', dropout)
        print('  Classes        :', class_names, num_classes)
        print('Optimizer        :', optimizer_name)
        print('  Learning rate  :', learning_rate)
        print('  Batch size     :', batch_size)
        print('  Criterion (cls):', criterion_cls_name)
        print('  Criterion (reg):', criterion_reg_name)
        print('  Criterion (ord):', criterion_ord_name)
        print('  Scheduler      :', scheduler_name)
        print('  Weight decay   :', weight_decay, weight_decay_step)
        print('  L1 reg.        :', l1)
        print('  L2 reg.        :', l2)
        print('  Early stopping :', early_stopping)

        # Callbacks and criteria shared by all training stages.
        callbacks = []
        criterions = {}

        main_metric = 'cls/kappa'
        if criterion_reg_name is not None:
            cb, crits = get_reg_callbacks(criterion_reg_name,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_ord_name is not None:
            cb, crits = get_ord_callbacks(criterion_ord_name,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if criterion_cls_name is not None:
            cb, crits = get_cls_callbacks(criterion_cls_name,
                                          num_classes=num_classes,
                                          num_epochs=num_epochs,
                                          class_names=class_names,
                                          show=show_batches)
            callbacks += cb
            criterions.update(crits)

        if l1 > 0:
            callbacks += [
                LPRegularizationCallback(start_wd=l1,
                                         end_wd=l1,
                                         schedule=None,
                                         prefix='l1',
                                         p=1)
            ]

        if l2 > 0:
            callbacks += [
                LPRegularizationCallback(start_wd=l2,
                                         end_wd=l2,
                                         schedule=None,
                                         prefix='l2',
                                         p=2)
            ]

        callbacks += [CustomOptimizerCallback()]

        runner = SupervisedRunner(input_key='image')

        # Stage 1 - warmup: Adam at 0.1x the learning rate, no scheduler.
        # set_trainable(encoder, False, False) presumably freezes the encoder
        # (see the --warmup help text) - confirm against its definition.
        if warmup:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer('Adam',
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate * 0.1)

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=None,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'warmup'),
                         num_epochs=warmup,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            del optimizer

        # Stage 2 - main training with the requested optimizer and scheduler.
        if num_epochs:
            set_trainable(model.encoder, True, False)

            optimizer = get_optimizer(optimizer_name,
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate,
                                      weight_decay=weight_decay)

            if use_swa:
                from torchcontrib.optim import SWA
                optimizer = SWA(optimizer,
                                swa_start=len(train_loader),
                                swa_freq=512)

            scheduler = get_scheduler(scheduler_name,
                                      optimizer,
                                      lr=learning_rate,
                                      num_epochs=num_epochs,
                                      batches_in_epoch=len(train_loader))

            # Callbacks specific to the main stage go into a copy so the
            # shared list used by the other stages is left untouched.
            main_stage_callbacks = list(callbacks)
            if early_stopping:
                es_callback = EarlyStoppingCallback(early_stopping,
                                                    min_delta=1e-4,
                                                    metric=main_metric,
                                                    minimize=False)
                main_stage_callbacks.append(es_callback)

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=main_stage_callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'main'),
                         num_epochs=num_epochs,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            del optimizer, scheduler

            best_checkpoint = os.path.join(log_dir, 'main', 'checkpoints',
                                           'best.pth')
            model_checkpoint = os.path.join(log_dir, 'main', 'checkpoints',
                                            f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)

            # Restore the best main-stage weights before any fine-tuning.
            checkpoint = load_checkpoint(best_checkpoint)
            unpack_checkpoint(checkpoint, model=model)
            report_checkpoint(checkpoint)

        # Stage 3 - fine-tuning with a 'multistep' scheduler.
        if fine_tune:
            set_trainable(model.encoder, False, False)
            optimizer = get_optimizer(optimizer_name,
                                      get_optimizable_parameters(model),
                                      learning_rate=learning_rate)
            scheduler = get_scheduler('multistep',
                                      optimizer,
                                      lr=learning_rate,
                                      num_epochs=fine_tune,
                                      batches_in_epoch=len(train_loader))

            runner.train(fp16=fp16,
                         model=model,
                         criterion=criterions,
                         optimizer=optimizer,
                         scheduler=scheduler,
                         callbacks=callbacks,
                         loaders=loaders,
                         logdir=os.path.join(log_dir, 'finetune'),
                         num_epochs=fine_tune,
                         verbose=verbose,
                         main_metric=main_metric,
                         minimize_metric=False,
                         checkpoint_data={"cmd_args": vars(args)})

            best_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints',
                                           'best.pth')
            model_checkpoint = os.path.join(log_dir, 'finetune', 'checkpoints',
                                            f'{checkpoint_prefix}.pth')
            clean_checkpoint(best_checkpoint, model_checkpoint)
Esempio n. 11
0
def main():
    """Train a segmentation model on the INRIA dataset from the command line.

    The function parses the CLI arguments, seeds the RNGs, builds the data
    loaders, model, loss, optimizer and a fixed ``MultiStepLR`` schedule,
    optionally restores model weights from ``--checkpoint`` and then runs
    training through ``SupervisedRunner``.  All training artefacts are
    written to ``runs/<timestamp>_<model>_<criterion>``.

    Raises:
        FileExistsError: if the log directory of this session already exists
            (``os.makedirs`` is deliberately called with ``exist_ok=False``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    # NOTE: --fp16 and --early-stopping are accepted on the command line but
    # are not read anywhere in this function; they only end up in the
    # "cmd_args" dictionary handed to the runner.
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        required=True,
                        help='Data directory for INRIA sattelite dataset')
    parser.add_argument('-m', '--model', type=str, default='unet', help='')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=8,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        default=150,
                        help='Epoch to run')
    parser.add_argument('-es',
                        '--early-stopping',
                        type=int,
                        default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-lr',
                        '--learning-rate',
                        type=float,
                        default=1e-3,
                        help='Initial learning rate')
    parser.add_argument('-l',
                        '--criterion',
                        type=str,
                        default='bce',
                        help='Criterion')
    parser.add_argument('-o',
                        '--optimizer',
                        default='Adam',
                        help='Name of the optimizer')
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        default=None,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w',
                        '--workers',
                        default=8,
                        type=int,
                        help='Num workers')

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    # The training resolution is fixed for this script.
    image_size = (512, 512)

    train_loader, valid_loader = get_dataloaders(data_dir=data_dir,
                                                 batch_size=batch_size,
                                                 num_workers=num_workers,
                                                 image_size=image_size,
                                                 fast=args.fast)

    model = maybe_cuda(get_model(model_name, image_size=image_size))
    criterion = get_loss(args.criterion)
    optimizer = get_optimizer(optimizer_name, model.parameters(),
                              learning_rate)

    loaders = collections.OrderedDict()
    loaders["train"] = train_loader
    loaders["valid"] = valid_loader

    # Learning rate is multiplied by 0.3 after epochs 10, 20 and 40.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[10, 20, 40],
                                                     gamma=0.3)

    # model runner
    runner = SupervisedRunner()

    if args.checkpoint:
        # Only the model weights are restored; optimizer and scheduler start
        # from scratch.
        checkpoint = UtilsFactory.load_checkpoint(auto_file(args.checkpoint))
        UtilsFactory.unpack_checkpoint(checkpoint, model=model)

        checkpoint_epoch = checkpoint['epoch']
        print('Loaded model weights from', args.checkpoint)
        print('Epoch   :', checkpoint_epoch)
        print('Metrics:', checkpoint['epoch_metrics'])

    current_time = datetime.now().strftime('%b%d_%H_%M')
    prefix = f'{current_time}_{args.model}_{args.criterion}'
    log_dir = os.path.join('runs', prefix)
    # exist_ok=False: refuse to write into the directory of a previous run.
    os.makedirs(log_dir, exist_ok=False)

    print('Train session:', prefix)
    print('\tFast mode  :', args.fast)
    print('\tEpochs     :', num_epochs)
    print('\tWorkers    :', num_workers)
    print('\tData dir   :', data_dir)
    print('\tLog dir    :', log_dir)
    print('\tTrain size :', len(train_loader), len(train_loader.dataset))
    print('\tValid size :', len(valid_loader), len(valid_loader.dataset))
    print('Model:', model_name)
    print('\tParameters:', count_parameters(model))
    print('\tImage size:', image_size)
    print('Optimizer:', optimizer_name)
    print('\tLearning rate:', learning_rate)
    print('\tBatch size   :', batch_size)
    print('\tCriterion    :', args.criterion)

    # model training; the best epoch is the one with the highest 'jaccard'
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        callbacks=[
            PixelAccuracyMetric(),
            EpochJaccardMetric(),
            ShowPolarBatchesCallback(visualize_inria_predictions,
                                     metric='accuracy',
                                     minimize=False),
        ],
        loaders=loaders,
        logdir=log_dir,
        num_epochs=num_epochs,
        verbose=True,
        main_metric='jaccard',
        minimize_metric=False,
        state_kwargs={"cmd_args": vars(args)})
def main():
    """Train an INRIA segmentation model, optionally in distributed mode.

    Workflow, as implemented below:

    1. Parse the CLI arguments.  When ``WORLD_SIZE`` is present in the
       environment, ``torch.distributed`` is initialised with the NCCL backend
       and ``--local_rank`` selects the CUDA device.
    2. Build the model and optionally transfer weights (``--transfer``) and/or
       restore a checkpoint (``--checkpoint``).
    3. Build the datasets and samplers, optionally mixing in xView2 data
       (``--data-dir-xview2``) and an online-pseudolabeling dataset (``--opl``).
    4. When ``--epochs`` is greater than zero, train with ``SupervisedRunner``.
       Afterwards the master process exports the best checkpoint and writes a
       prediction for ``sample_color.jpg`` into the log directory.

    Raises:
        ValueError: if no data directory is supplied, neither through
            ``--data-dir`` nor through the ``INRIA_DATA_DIR`` environment variable.
        FileExistsError: on the master process, if ``runs/<prefix>`` already exists.
    """
    parser = argparse.ArgumentParser()

    ###########################################################################################
    # Distributed-training related stuff
    parser.add_argument("--local_rank", type=int, default=0)
    ###########################################################################################

    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument(
        "-dd",
        "--data-dir",
        type=str,
        help="Data directory for INRIA sattelite dataset",
        default=os.environ.get("INRIA_DATA_DIR"),
    )
    parser.add_argument(
        "-dd-xview2", "--data-dir-xview2", type=str, required=False, help="Data directory for external xView2 dataset"
    )
    parser.add_argument("-m", "--model", type=str, default="b6_unet32_s2", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")
    # Every -l/-lN occurrence appends one list of tokens (action="append" + nargs="+"),
    # so each criterion option is a list of lists.
    parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion")
    parser.add_argument(
        "-l2",
        "--criterion2",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 2 mask",
    )
    parser.add_argument(
        "-l4",
        "--criterion4",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 4 mask",
    )
    parser.add_argument(
        "-l8",
        "--criterion8",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 8 mask",
    )
    parser.add_argument(
        "-l16",
        "--criterion16",
        type=str,
        required=False,
        action="append",
        nargs="+",
        help="Criterion for stride 16 mask",
    )

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="hard", type=str, help="")
    parser.add_argument("-tm", "--train-mode", default="random", type=str, help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()

    # These attributes are stored on `args` so that they are dumped to the
    # session JSON file and into the checkpoint together with the CLI options.
    args.is_master = args.local_rank == 0
    args.distributed = False
    fp16 = args.fp16

    if "WORLD_SIZE" in os.environ:
        args.distributed = int(os.environ["WORLD_SIZE"]) > 1
        args.world_size = int(os.environ["WORLD_SIZE"])

        print("Initializing init_process_group", args.local_rank)

        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        print("Initialized init_process_group", args.local_rank)

    # A non-distributed run is always its own master.
    is_master = args.is_master or not args.distributed

    # Value handed to runner.train(fp16=...): a settings dict, or False when
    # neither distributed training nor mixed precision is requested.
    if args.distributed:
        distributed_params = {"rank": args.local_rank, "syncbn": True}
        if args.fp16:
            distributed_params["amp"] = True
    elif args.fp16:
        distributed_params = {"amp": True}
    else:
        distributed_params = False

    # Offset the seed by the local rank so that every process gets its own seed.
    set_manual_seed(args.seed + args.local_rank)
    catalyst.utils.set_global_seed(args.seed + args.local_rank)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

    data_dir = args.data_dir
    if data_dir is None:
        raise ValueError("--data-dir must be set")

    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    criterions2 = args.criterion2
    criterions4 = args.criterion4
    criterions8 = args.criterion8
    criterions16 = args.criterion16

    verbose = args.verbose
    show = args.show
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    # The first token of every --criterion group is the loss name.
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    custom_model_kwargs = {"full_size_mask": False}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    if any([criterions2, criterions4, criterions8, criterions16]):
        custom_model_kwargs["need_supervision_masks"] = True
        print("Enabling supervision masks")

    model: nn.Module = get_model(model_name, num_classes=16, **custom_model_kwargs).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    main_metric = "jaccard"

    current_time = datetime.now().strftime("%y%m%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if online_pseudolabeling:
        checkpoint_prefix += "_opl"

    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"

    # An explicit experiment name replaces the generated prefix entirely.
    if experiment is not None:
        checkpoint_prefix = experiment

    default_callbacks = [
        JaccardMetricPerImage(
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            prefix="jaccard",
            inputs_to_labels=depth2mask,
            outputs_to_labels=decode_depth_mask,
        ),
    ]

    # Checkpointing, hyper-parameter logging and visualisation run on the master only.
    if is_master:

        default_callbacks += [
            BestMetricCheckpointCallback(target_metric="jaccard", target_metric_minimize=False),
            HyperParametersCallback(
                hparam_dict={
                    "model": model_name,
                    "scheduler": scheduler_name,
                    "optimizer": optimizer_name,
                    "augmentations": augmentations,
                    "size": args.size,
                    "weight_decay": weight_decay,
                    "epochs": num_epochs,
                    "dropout": None if dropout is None else float(dropout),
                }
            ),
        ]

        if show:
            visualize_inria_predictions = partial(
                draw_inria_predictions,
                image_key=INPUT_IMAGE_KEY,
                image_id_key=INPUT_IMAGE_ID_KEY,
                targets_key=INPUT_MASK_KEY,
                outputs_key=OUTPUT_MASK_KEY,
                inputs_to_labels=depth2mask,
                outputs_to_labels=decode_depth_mask,
                max_images=16,
            )
            default_callbacks += [
                ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False),
                ShowPolarBatchesCallback(visualize_inria_predictions, metric="loss", minimize=True),
            ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        buildings_only=(train_mode == "tiles"),
        fast=fast,
        need_weight_mask=need_weight_mask,
        make_mask_target_fn=mask_to_ce_target,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        # get_datasets() may return no sampler (the pseudolabeling branch below
        # guards against the same case), so fall back to the dataset length
        # instead of failing on `None.num_samples`.
        if train_sampler is not None:
            num_samples = 2 * train_sampler.num_samples
        else:
            num_samples = 2 * len(train_ds)

        weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights, num_samples)

        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds), "samples")

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()

        if online_pseudolabeling:
            # Two views of the same unlabeled data: one without masks and
            # augmentations for inference, one with masks for training.
            unlabeled_label = get_pseudolabeling_dataset(
                data_dir, include_masks=False, augmentation=None, image_size=image_size
            )

            unlabeled_train = get_pseudolabeling_dataset(
                data_dir, include_masks=True, augmentation=augmentations, image_size=image_size
            )

            if args.distributed:
                label_sampler = DistributedSampler(unlabeled_label, args.world_size, args.local_rank, shuffle=False)
            else:
                label_sampler = None

            loaders["infer"] = DataLoader(
                unlabeled_label,
                batch_size=batch_size // 2,
                num_workers=num_workers,
                pin_memory=True,
                sampler=label_sampler,
                drop_last=False,
            )

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)
            weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_label))

            train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="infer",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]

            print("Using online pseudolabeling with ", len(unlabeled_label), "samples")

        valid_sampler = None
        if args.distributed:
            if train_sampler is not None:
                # Wrap the existing sampler so that it is sharded across processes.
                train_sampler = DistributedSamplerWrapper(
                    train_sampler, args.world_size, args.local_rank, shuffle=True
                )
            else:
                train_sampler = DistributedSampler(train_ds, args.world_size, args.local_rank, shuffle=True)
            valid_sampler = DistributedSampler(valid_ds, args.world_size, args.local_rank, shuffle=False)

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            # DataLoader does not accept shuffle=True together with a sampler.
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(
            valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, sampler=valid_sampler
        )

        loss_callbacks, loss_criterions = get_criterions(
            criterions, criterions2, criterions4, criterions8, criterions16
        )
        callbacks += loss_callbacks

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        # These schedulers are stepped once per batch rather than once per epoch.
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        log_dir = os.path.join("runs", checkpoint_prefix)

        if is_master:
            # exist_ok=False: refuse to write into the directory of a previous run.
            os.makedirs(log_dir, exist_ok=False)
            config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
            with open(config_fname, "w") as f:
                train_session_args = vars(args)
                f.write(json.dumps(train_session_args, indent=2))

            print("Train session    :", checkpoint_prefix)
            print("  FP16 mode      :", fp16)
            print("  Fast mode      :", args.fast)
            print("  Train mode     :", train_mode)
            print("  Epochs         :", num_epochs)
            print("  Workers        :", num_workers)
            print("  Data dir       :", data_dir)
            print("  Log dir        :", log_dir)
            print("  Augmentations  :", augmentations)
            print("  Train size     :", "batches", len(loaders["train"]), "dataset", len(train_ds))
            print("  Valid size     :", "batches", len(loaders["valid"]), "dataset", len(valid_ds))
            print("Model            :", model_name)
            print("  Parameters     :", count_parameters(model))
            print("  Image size     :", image_size)
            print("Optimizer        :", optimizer_name)
            print("  Learning rate  :", learning_rate)
            print("  Batch size     :", batch_size)
            print("  Criterion      :", criterions)
            print("  Use weight mask:", need_weight_mask)
            if args.distributed:
                print("Distributed")
                print("  World size     :", args.world_size)
                print("  Local rank     :", args.local_rank)
                print("  Is master      :", args.is_master)

        # model training
        runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda")
        runner.train(
            fp16=distributed_params,
            model=model,
            criterion=loss_criterions,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        if is_master:
            best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")

            model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
            clean_checkpoint(best_checkpoint, model_checkpoint)

            unpack_checkpoint(torch.load(model_checkpoint), model=model)

            mask = predict(
                model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size
            )
            # Binarise the prediction (positive values -> 255) for saving as an image.
            mask = ((mask > 0) * 255).astype(np.uint8)
            name = os.path.join(log_dir, "sample_color.jpg")
            cv2.imwrite(name, mask)
Esempio n. 13
0
def main():
    """Run inference with a trained checkpoint over the training images.

    Loads the model named by ``--model``, restores its weights from
    ``--checkpoint`` and predicts a mask for every file found in
    ``<data-dir>/train/images``.  Each mask is scaled by 255, cast to uint8
    and written under the original file name into an ``evaluation`` folder
    located two directory levels above the checkpoint file.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-m", "--model", type=str, default="unet", help="")
    cli.add_argument("-dd", "--data-dir", type=str, default=None, required=True, help="Data dir")
    cli.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        required=True,
        help="Checkpoint filename to use as initial model weights",
    )
    cli.add_argument("-b", "--batch-size", type=int, default=16, help="Batch size for inference")
    cli.add_argument("-tta", "--tta", default=None, type=str, help="Type of TTA to use [fliplr, d4]")
    options = cli.parse_args()

    weights_path = auto_file(options.checkpoint)

    # <run>/<subdir>/<checkpoint file>  ->  <run>/evaluation
    session_dir = os.path.dirname(os.path.dirname(weights_path))
    evaluation_dir = os.path.join(session_dir, "evaluation")
    os.makedirs(evaluation_dir, exist_ok=True)

    net = get_model(options.model)

    state = UtilsFactory.load_checkpoint(weights_path)
    epoch = state["epoch"]
    print("Loaded model weights from", options.checkpoint)
    print("Epoch   :", epoch)

    # Report the metrics stored in the checkpoint, train split first.
    train_metrics = state["epoch_metrics"]["train"]
    print("Metrics (Train):", "IoU:", train_metrics["jaccard"], "Acc:", train_metrics["accuracy"])
    valid_metrics = state["epoch_metrics"]["valid"]
    print("Metrics (Valid):", "IoU:", valid_metrics["jaccard"], "Acc:", valid_metrics["accuracy"])

    UtilsFactory.unpack_checkpoint(state, model=net)

    net = net.cuda().eval()

    image_paths = find_in_dir(os.path.join(options.data_dir, "train", "images"))
    for image_path in tqdm(image_paths, total=len(image_paths)):
        prediction = predict(
            net,
            read_rgb_image(image_path),
            tta=options.tta,
            image_size=(512, 512),
            batch_size=options.batch_size,
            activation="sigmoid",
        )
        as_bytes = (prediction * 255).astype(np.uint8)
        cv2.imwrite(os.path.join(evaluation_dir, os.path.basename(image_path)), as_bytes)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc",
                        "--accumulation-steps",
                        type=int,
                        default=1,
                        help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        required=True,
                        help="Data directory for INRIA sattelite dataset")
    parser.add_argument("-dd-xview2",
                        "--data-dir-xview2",
                        type=str,
                        required=False,
                        help="Data directory for external xView2 dataset")
    parser.add_argument("-m",
                        "--model",
                        type=str,
                        default="resnet34_fpncat128",
                        help="")
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=8,
                        help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e",
                        "--epochs",
                        type=int,
                        default=100,
                        help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr",
                        "--learning-rate",
                        type=float,
                        default=1e-3,
                        help="Initial learning rate")
    parser.add_argument("-l",
                        "--criterion",
                        type=str,
                        required=True,
                        action="append",
                        nargs="+",
                        help="Criterion")
    parser.add_argument("-o",
                        "--optimizer",
                        default="RAdam",
                        help="Name of the optimizer")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        help="Checkpoint filename to use as initial model weights")
    parser.add_argument("-w",
                        "--workers",
                        default=8,
                        type=int,
                        help="Num workers")
    parser.add_argument("-a",
                        "--augmentations",
                        default="hard",
                        type=str,
                        help="")
    parser.add_argument("-tm",
                        "--train-mode",
                        default="random",
                        type=str,
                        help="")
    parser.add_argument("--run-mode", default="fit_predict", type=str, help="")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("-s",
                        "--scheduler",
                        default="multistep",
                        type=str,
                        help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d",
                        "--dropout",
                        default=0.0,
                        type=float,
                        help="Dropout before head layer")
    parser.add_argument("--opl", action="store_true")
    parser.add_argument(
        "--warmup",
        default=0,
        type=int,
        help="Number of warmup epochs with reduced LR on encoder parameters")
    parser.add_argument("-wd",
                        "--weight-decay",
                        default=0,
                        type=float,
                        help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    online_pseudolabeling = args.opl
    criterions = args.criterion
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    use_dsv = args.dsv
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    extra_data_xview2 = args.data_dir_xview2

    run_train = num_epochs > 0
    need_weight_mask = any(c[0] == "wbce" for c in criterions)

    model: nn.Module = get_model(model_name, dropout=dropout).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY,
                              output_key=None,
                              device="cuda")
    main_metric = "optimized_jaccard"
    cmd_args = vars(args)

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if online_pseudolabeling:
        checkpoint_prefix += "_opl"

    if extra_data_xview2:
        checkpoint_prefix += "_with_xview2"

    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = [
        PixelAccuracyCallback(input_key=INPUT_MASK_KEY,
                              output_key=OUTPUT_MASK_KEY),
        JaccardMetricPerImage(input_key=INPUT_MASK_KEY,
                              output_key=OUTPUT_MASK_KEY,
                              prefix="jaccard"),
        OptimalThreshold(input_key=INPUT_MASK_KEY,
                         output_key=OUTPUT_MASK_KEY,
                         prefix="optimized_jaccard"),
        # OutputDistributionCallback(output_key=OUTPUT_MASK_KEY, activation=torch.sigmoid),
    ]

    if show:
        visualize_inria_predictions = partial(
            draw_inria_predictions,
            image_key=INPUT_IMAGE_KEY,
            image_id_key=INPUT_IMAGE_ID_KEY,
            targets_key=INPUT_MASK_KEY,
            outputs_key=OUTPUT_MASK_KEY,
        )
        default_callbacks += [
            ShowPolarBatchesCallback(visualize_inria_predictions,
                                     metric="accuracy",
                                     minimize=False)
        ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        train_mode=train_mode,
        fast=fast,
        need_weight_mask=need_weight_mask,
    )

    if extra_data_xview2 is not None:
        extra_train_ds, _ = get_xview2_extra_dataset(
            extra_data_xview2,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast,
            need_weight_mask=need_weight_mask,
        )

        weights = compute_sample_weight("balanced", [0] * len(train_ds) +
                                        [1] * len(extra_train_ds))
        train_sampler = WeightedRandomSampler(weights,
                                              train_sampler.num_samples * 2)

        train_ds = train_ds + extra_train_ds
        print("Using extra data from xView2 with", len(extra_train_ds),
              "samples")

    # Pretrain/warmup
    if warmup:
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []
        ignore_index = None

        for loss_name, loss_weight in criterions:
            criterion_callback = CriterionCallback(
                prefix="seg_loss/" + loss_name,
                input_key=INPUT_MASK_KEY if loss_name != "wbce" else
                [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY],
                output_key=OUTPUT_MASK_KEY,
                criterion_key=loss_name,
                multiplier=float(loss_weight),
            )

            criterions_dict[loss_name] = get_loss(loss_name,
                                                  ignore_index=ignore_index)
            callbacks.append(criterion_callback)
            losses.append(criterion_callback.prefix)
            print("Using loss", loss_name, loss_weight)

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps,
                              decouple_weight_decay=False),
        ]

        parameters = get_lr_decay_parameters(model.named_parameters(),
                                             learning_rate, {"encoder": 0.1})
        optimizer = get_optimizer("RAdam",
                                  parameters,
                                  learning_rate=learning_rate * 0.1)

        loaders = collections.OrderedDict()
        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds,
                                      batch_size=batch_size,
                                      num_workers=num_workers,
                                      pin_memory=True,
                                      shuffle=False,
                                      drop_last=False)

        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=None,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "warmup"),
            num_epochs=warmup,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": cmd_args},
        )

        del optimizer, loaders

        best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints",
                                       "best.pth")
        model_checkpoint = os.path.join(log_dir, "warmup", "checkpoints",
                                        f"{checkpoint_prefix}_warmup.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        torch.cuda.empty_cache()
        gc.collect()

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        ignore_index = None
        if online_pseudolabeling:
            ignore_index = UNLABELED_SAMPLE
            unlabeled_label = get_pseudolabeling_dataset(data_dir,
                                                         include_masks=False,
                                                         augmentation=None,
                                                         image_size=image_size)

            unlabeled_train = get_pseudolabeling_dataset(
                data_dir,
                include_masks=True,
                augmentation=augmentations,
                image_size=image_size)

            loaders["label"] = DataLoader(unlabeled_label,
                                          batch_size=batch_size // 2,
                                          num_workers=num_workers,
                                          pin_memory=True)

            if train_sampler is not None:
                num_samples = 2 * train_sampler.num_samples
            else:
                num_samples = 2 * len(train_ds)
            weights = compute_sample_weight("balanced", [0] * len(train_ds) +
                                            [1] * len(unlabeled_label))

            train_sampler = WeightedRandomSampler(weights,
                                                  num_samples,
                                                  replacement=True)
            train_ds = train_ds + unlabeled_train

            callbacks += [
                BCEOnlinePseudolabelingCallback2d(
                    unlabeled_train,
                    pseudolabel_loader="label",
                    prob_threshold=0.7,
                    output_key=OUTPUT_MASK_KEY,
                    unlabeled_class=UNLABELED_SAMPLE,
                    label_frequency=5,
                )
            ]

            print("Using online pseudolabeling with ", len(unlabeled_label),
                  "samples")

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
        )

        loaders["valid"] = DataLoader(valid_ds,
                                      batch_size=batch_size,
                                      num_workers=num_workers,
                                      pin_memory=True)

        # Create losses
        for loss_name, loss_weight in criterions:
            criterion_callback = CriterionCallback(
                prefix="seg_loss/" + loss_name,
                input_key=INPUT_MASK_KEY if loss_name != "wbce" else
                [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY],
                output_key=OUTPUT_MASK_KEY,
                criterion_key=loss_name,
                multiplier=float(loss_weight),
            )

            criterions_dict[loss_name] = get_loss(loss_name,
                                                  ignore_index=ignore_index)
            callbacks.append(criterion_callback)
            losses.append(criterion_callback.prefix)
            print("Using loss", loss_name, loss_weight)

        if use_dsv:
            print("Using DSV")
            criterions = "dsv"
            dsv_loss_name = "soft_bce"

            criterions_dict[criterions] = AdaptiveMaskLoss2d(
                get_loss(dsv_loss_name, ignore_index=ignore_index))

            for i, dsv_input in enumerate([
                    OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY,
                    OUTPUT_MASK_32_KEY
            ]):
                criterion_callback = CriterionCallback(
                    prefix="seg_loss_dsv/" + dsv_input,
                    input_key=OUTPUT_MASK_KEY,
                    output_key=dsv_input,
                    criterion_key=criterions,
                    multiplier=1.0,
                )
                callbacks.append(criterion_callback)
                losses.append(criterion_callback.prefix)

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps,
                              decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(optimizer_name,
                                  get_optimizable_parameters(model),
                                  learning_rate,
                                  weight_decay=weight_decay)
        scheduler = get_scheduler(scheduler_name,
                                  optimizer,
                                  lr=learning_rate,
                                  num_epochs=num_epochs,
                                  batches_in_epoch=len(loaders["train"]))
        if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)):
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session    :", checkpoint_prefix)
        print("\tFP16 mode      :", fp16)
        print("\tFast mode      :", args.fast)
        print("\tTrain mode     :", train_mode)
        print("\tEpochs         :", num_epochs)
        print("\tWorkers        :", num_workers)
        print("\tData dir       :", data_dir)
        print("\tLog dir        :", log_dir)
        print("\tAugmentations  :", augmentations)
        print("\tTrain size     :", len(loaders["train"]), len(train_ds))
        print("\tValid size     :", len(loaders["valid"]), len(valid_ds))
        print("Model            :", model_name)
        print("\tParameters     :", count_parameters(model))
        print("\tImage size     :", image_size)
        print("Optimizer        :", optimizer_name)
        print("\tLearning rate  :", learning_rate)
        print("\tBatch size     :", batch_size)
        print("\tCriterion      :", criterions)
        print("\tUse weight mask:", need_weight_mask)

        # model training
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": vars(args)},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = os.path.join(log_dir, "main", "checkpoints",
                                       "best.pth")

        model_checkpoint = os.path.join(log_dir, "main", "checkpoints",
                                        f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        unpack_checkpoint(torch.load(model_checkpoint), model=model)

        mask = predict(model,
                       read_inria_image("sample_color.jpg"),
                       image_size=image_size,
                       batch_size=args.batch_size)
        mask = ((mask > 0) * 255).astype(np.uint8)
        name = os.path.join(log_dir, "sample_color.jpg")
        cv2.imwrite(name, mask)

        del optimizer, loaders
def main():
    """Run out-of-fold inference for one or more checkpoints and print scores.

    Command-line arguments:
        model: one or more checkpoint paths (each resolved via ``fs.auto_file``).
        -o/--output-dir: where predictions are written; when omitted, a
            ``<checkpoint id>_oof_predictions`` directory next to the
            checkpoint file is used.
        --fold: fold to evaluate on; when omitted, the fold stored in the
            checkpoint's ``cmd_args`` is used.
        --size / -b: image size and batch size; a zero value falls back to
            the value stored in the checkpoint's ``cmd_args``.
        --no-save: do not save predictions (``output_dir`` stays ``None``).
        --fast, --tta, --workers, --data-dir, --fp16, --align: forwarded to
            the dataset / model / inference helpers.
        --activation: accepted for command-line compatibility; it is not
            used by this function.

    For every checkpoint the stored validation metrics are printed, the
    validation dataset is built, and ``run_inference_on_dataset_oof`` is
    invoked.  When more than one checkpoint was evaluated, mean and standard
    deviation of the scores are printed at the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("model", type=str, nargs="+")
    parser.add_argument("-o", "--output-dir", type=str, default=None)
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--tta", type=str, default=None)
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=1,
                        help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-w", "--workers", type=int, default=0, help="")
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        default="data",
                        help="Data directory")
    parser.add_argument("--size", default=1024, type=int)
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("--no-save", action="store_true")
    parser.add_argument("--fp16", action="store_true")
    # Kept so that existing command lines still parse; the value is unused.
    parser.add_argument("--activation", default="model", type=str)
    parser.add_argument("--align", action="store_true")

    args = parser.parse_args()

    # Options that do not depend on the checkpoint are read once, up front.
    fp16 = args.fp16
    workers = args.workers
    data_dir = args.data_dir
    fast = args.fast
    tta = args.tta
    no_save = args.no_save
    align = args.align

    average_score = []
    average_dmg = []
    average_loc = []

    for model_checkpoint in args.model:
        model_checkpoint = fs.auto_file(model_checkpoint)
        checkpoint = torch.load(model_checkpoint)
        valid_metrics = checkpoint["epoch_metrics"]["valid"]

        print("Model        :", model_checkpoint)
        print(
            "Metrics      :",
            valid_metrics["weighted_f1"],
            valid_metrics["weighted_f1/localization_f1"],
            valid_metrics["weighted_f1/damage_f1"],
        )

        # A size / batch size of 0 is meaningless, so a falsy value means
        # "take it from the checkpoint".  The checkpoint is only consulted
        # when a fallback is actually needed.
        image_size = args.size or checkpoint["checkpoint_data"]["cmd_args"][
            "size"]
        batch_size = args.batch_size or checkpoint["checkpoint_data"][
            "cmd_args"]["batch_size"]

        # Fold 0 is a valid fold, so test against None explicitly: with a
        # plain `or`, `--fold 0` would be silently replaced by the fold
        # stored in the checkpoint.
        if args.fold is not None:
            fold = args.fold
        else:
            fold = checkpoint["checkpoint_data"]["cmd_args"]["fold"]

        print("Image size :", image_size)
        print("Fold       :", fold)
        print("Align      :", align)
        print("Workers    :", workers)
        print("Save       :", not no_save)

        output_dir = None
        if not no_save:
            output_dir = args.output_dir or os.path.join(
                os.path.dirname(model_checkpoint),
                fs.id_from_fname(model_checkpoint) + "_oof_predictions")
            print("Output dir :", output_dir)

        # Load models
        model, info = model_from_checkpoint(model_checkpoint,
                                            tta=tta,
                                            activation_after=None,
                                            report=False)
        print(info)
        _, valid_ds, _ = get_datasets(data_dir=data_dir,
                                      image_size=(image_size, image_size),
                                      fast=fast,
                                      fold=fold,
                                      align_post=align)

        score, localization_f1, damage_f1, damage_f1s = run_inference_on_dataset_oof(
            model=model,
            dataset=valid_ds,
            output_dir=output_dir,
            batch_size=batch_size,
            workers=workers,
            save=not no_save,
            fp16=fp16)

        average_score.append(score)
        average_dmg.append(damage_f1)
        average_loc.append(localization_f1)

        print("Score        :", score)
        print("Localization :", localization_f1)
        print("Damage       :", damage_f1)
        print("Per class    :", damage_f1s)
        print()

    # Aggregate statistics only make sense for several checkpoints; the
    # header is printed together with them so it never dangles on its own.
    if len(average_score) > 1:
        print("Average")
        print("Score        :", np.mean(average_score), np.std(average_score))
        print("Localization :", np.mean(average_loc), np.std(average_loc))
        print("Damage       :", np.mean(average_dmg), np.std(average_dmg))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc",
                        "--accumulation-steps",
                        type=int,
                        default=1,
                        help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        required=True,
                        help="Data directory for INRIA sattelite dataset")
    parser.add_argument("-m",
                        "--model",
                        type=str,
                        default="resnet34_fpncat128",
                        help="")
    parser.add_argument("-b",
                        "--batch-size",
                        type=int,
                        default=8,
                        help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-e",
                        "--epochs",
                        type=int,
                        default=100,
                        help="Epoch to run")
    # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument("-lr",
                        "--learning-rate",
                        type=float,
                        default=1e-3,
                        help="Initial learning rate")
    parser.add_argument(
        "--disaster-type-loss",
        type=str,
        default=None,  # [["ce", 1.0]],
        action="append",
        nargs="+",
        help="Criterion for classifying disaster type",
    )
    parser.add_argument(
        "--damage-type-loss",
        type=str,
        default=None,  # [["bce", 1.0]],
        action="append",
        nargs="+",
        help=
        "Criterion for classifying presence of building with particular damage type",
    )

    parser.add_argument("-l",
                        "--criterion",
                        type=str,
                        default=None,
                        action="append",
                        nargs="+",
                        help="Criterion")
    parser.add_argument("--mask4",
                        type=str,
                        default=None,
                        action="append",
                        nargs="+",
                        help="Criterion for mask with stride 4")
    parser.add_argument("--mask8",
                        type=str,
                        default=None,
                        action="append",
                        nargs="+",
                        help="Criterion for mask with stride 8")
    parser.add_argument("--mask16",
                        type=str,
                        default=None,
                        action="append",
                        nargs="+",
                        help="Criterion for mask with stride 16")
    parser.add_argument("--mask32",
                        type=str,
                        default=None,
                        action="append",
                        nargs="+",
                        help="Criterion for mask with stride 32")
    parser.add_argument("--embedding", type=str, default=None)

    parser.add_argument("-o",
                        "--optimizer",
                        default="RAdam",
                        help="Name of the optimizer")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        help="Checkpoint filename to use as initial model weights")
    parser.add_argument("-w",
                        "--workers",
                        default=8,
                        type=int,
                        help="Num workers")
    parser.add_argument("-a",
                        "--augmentations",
                        default="safe",
                        type=str,
                        help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--size", default=512, type=int)
    parser.add_argument("--fold", default=0, type=int)
    parser.add_argument("-s",
                        "--scheduler",
                        default="multistep",
                        type=str,
                        help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d",
                        "--dropout",
                        default=0.0,
                        type=float,
                        help="Dropout before head layer")
    parser.add_argument("-pl", "--pseudolabeling", type=str, required=True)
    parser.add_argument("-wd",
                        "--weight-decay",
                        default=0,
                        type=float,
                        help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--dsv", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--only-buildings", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")
    parser.add_argument("--crops",
                        action="store_true",
                        help="Train on random crops")
    parser.add_argument("--post-transform", action="store_true")

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = args.size, args.size
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    segmentation_losses = args.criterion
    verbose = args.verbose
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    only_buildings = args.only_buildings
    freeze_bn = args.freeze_bn
    train_on_crops = args.crops
    enable_post_image_transform = args.post_transform
    disaster_type_loss = args.disaster_type_loss
    train_batch_size = args.batch_size
    embedding_criterion = args.embedding
    damage_type_loss = args.damage_type_loss
    pseudolabels_dir = args.pseudolabeling

    # Compute batch size for validation
    if train_on_crops:
        valid_batch_size = max(1,
                               (train_batch_size *
                                (image_size[0] * image_size[1])) // (1024**2))
    else:
        valid_batch_size = train_batch_size

    run_train = num_epochs > 0

    model: nn.Module = get_model(model_name, dropout=dropout).cuda()

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    if freeze_bn:
        torch_utils.freeze_bn(model)
        print("Freezing bn params")

    runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None)
    main_metric = "weighted_f1"
    cmd_args = vars(args)

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_{args.size}_fold{fold}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if pseudolabels_dir:
        checkpoint_prefix += "_pseudo"

    if train_on_crops:
        checkpoint_prefix += "_crops"

    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = [
        CompetitionMetricCallback(input_key=INPUT_MASK_KEY,
                                  output_key=OUTPUT_MASK_KEY,
                                  prefix="weighted_f1"),
        ConfusionMatrixCallback(
            input_key=INPUT_MASK_KEY,
            output_key=OUTPUT_MASK_KEY,
            class_names=[
                "land", "no_damage", "minor_damage", "major_damage",
                "destroyed"
            ],
            ignore_index=UNLABELED_SAMPLE,
        ),
    ]

    if show:
        default_callbacks += [
            ShowPolarBatchesCallback(draw_predictions,
                                     metric=main_metric + "_batch",
                                     minimize=False)
        ]

    train_ds, valid_ds, train_sampler = get_datasets(
        data_dir=data_dir,
        image_size=image_size,
        augmentation=augmentations,
        fast=fast,
        fold=fold,
        balance=balance,
        only_buildings=only_buildings,
        train_on_crops=train_on_crops,
        crops_multiplication_factor=1,
        enable_post_image_transform=enable_post_image_transform,
    )

    if run_train:
        loaders = collections.OrderedDict()
        callbacks = default_callbacks.copy()
        criterions_dict = {}
        losses = []

        unlabeled_train = get_pseudolabeling_dataset(
            data_dir,
            include_masks=True,
            image_size=image_size,
            augmentation="medium_nmd",
            train_on_crops=train_on_crops,
            enable_post_image_transform=enable_post_image_transform,
            pseudolabels_dir=pseudolabels_dir,
        )

        train_ds = train_ds + unlabeled_train

        print("Using online pseudolabeling with ", len(unlabeled_train),
              "samples")

        loaders["train"] = DataLoader(
            train_ds,
            batch_size=train_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            shuffle=True,
        )

        loaders["valid"] = DataLoader(valid_ds,
                                      batch_size=valid_batch_size,
                                      num_workers=num_workers,
                                      pin_memory=True)

        # Create losses
        for criterion in segmentation_losses:
            if isinstance(criterion, (list, tuple)) and len(criterion) == 2:
                loss_name, loss_weight = criterion
            else:
                loss_name, loss_weight = criterion[0], 1.0

            cd, criterion, criterion_name = get_criterion_callback(
                loss_name,
                prefix="segmentation",
                input_key=INPUT_MASK_KEY,
                output_key=OUTPUT_MASK_KEY,
                loss_weight=float(loss_weight),
            )
            criterions_dict.update(cd)
            callbacks.append(criterion)
            losses.append(criterion_name)
            print(INPUT_MASK_KEY, "Using loss", loss_name, loss_weight)

        if args.mask4 is not None:
            for criterion in args.mask4:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix="mask4",
                    input_key=INPUT_MASK_KEY,
                    output_key=OUTPUT_MASK_4_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(OUTPUT_MASK_4_KEY, "Using loss", loss_name, loss_weight)

        if args.mask8 is not None:
            for criterion in args.mask8:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix="mask8",
                    input_key=INPUT_MASK_KEY,
                    output_key=OUTPUT_MASK_8_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(OUTPUT_MASK_8_KEY, "Using loss", loss_name, loss_weight)

        if args.mask16 is not None:
            for criterion in args.mask16:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix="mask16",
                    input_key=INPUT_MASK_KEY,
                    output_key=OUTPUT_MASK_16_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(OUTPUT_MASK_16_KEY, "Using loss", loss_name, loss_weight)

        if args.mask32 is not None:
            for criterion in args.mask32:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix="mask32",
                    input_key=INPUT_MASK_KEY,
                    output_key=OUTPUT_MASK_32_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(OUTPUT_MASK_32_KEY, "Using loss", loss_name, loss_weight)

        if disaster_type_loss is not None:
            callbacks += [
                ConfusionMatrixCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    class_names=DISASTER_TYPES,
                    ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
                    prefix=f"{DISASTER_TYPE_KEY}/confusion_matrix",
                ),
                AccuracyCallback(
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    prefix=f"{DISASTER_TYPE_KEY}/accuracy",
                    activation="Softmax",
                ),
            ]

            for criterion in disaster_type_loss:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix=DISASTER_TYPE_KEY,
                    input_key=DISASTER_TYPE_KEY,
                    output_key=DISASTER_TYPE_KEY,
                    loss_weight=float(loss_weight),
                    ignore_index=UNKNOWN_DISASTER_TYPE_CLASS,
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(DISASTER_TYPE_KEY, "Using loss", loss_name, loss_weight)

        if damage_type_loss is not None:
            callbacks += [
                # MultilabelConfusionMatrixCallback(
                #     input_key=DAMAGE_TYPE_KEY,
                #     output_key=DAMAGE_TYPE_KEY,
                #     class_names=DAMAGE_TYPES,
                #     prefix=f"{DAMAGE_TYPE_KEY}/confusion_matrix",
                # ),
                AccuracyCallback(
                    input_key=DAMAGE_TYPE_KEY,
                    output_key=DAMAGE_TYPE_KEY,
                    prefix=f"{DAMAGE_TYPE_KEY}/accuracy",
                    activation="Sigmoid",
                    threshold=0.5,
                )
            ]

            for criterion in damage_type_loss:
                if isinstance(criterion, (list, tuple)):
                    loss_name, loss_weight = criterion
                else:
                    loss_name, loss_weight = criterion, 1.0

                cd, criterion, criterion_name = get_criterion_callback(
                    loss_name,
                    prefix=DAMAGE_TYPE_KEY,
                    input_key=DAMAGE_TYPE_KEY,
                    output_key=DAMAGE_TYPE_KEY,
                    loss_weight=float(loss_weight),
                )
                criterions_dict.update(cd)
                callbacks.append(criterion)
                losses.append(criterion_name)
                print(DAMAGE_TYPE_KEY, "Using loss", loss_name, loss_weight)

        if embedding_criterion is not None:
            cd, criterion, criterion_name = get_criterion_callback(
                embedding_criterion,
                prefix="embedding",
                input_key=INPUT_MASK_KEY,
                output_key=OUTPUT_EMBEDDING_KEY,
                loss_weight=1.0,
            )
            criterions_dict.update(cd)
            callbacks.append(criterion)
            losses.append(criterion_name)
            print(OUTPUT_EMBEDDING_KEY, "Using loss", embedding_criterion)

        callbacks += [
            CriterionAggregatorCallback(prefix="loss", loss_keys=losses),
            OptimizerCallback(accumulation_steps=accumulation_steps,
                              decouple_weight_decay=False),
        ]

        optimizer = get_optimizer(optimizer_name,
                                  get_optimizable_parameters(model),
                                  learning_rate,
                                  weight_decay=weight_decay)
        scheduler = get_scheduler(scheduler_name,
                                  optimizer,
                                  lr=learning_rate,
                                  num_epochs=num_epochs,
                                  batches_in_epoch=len(loaders["train"]))
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        print("Train session    :", checkpoint_prefix)
        print("  FP16 mode      :", fp16)
        print("  Fast mode      :", args.fast)
        print("  Epochs         :", num_epochs)
        print("  Workers        :", num_workers)
        print("  Data dir       :", data_dir)
        print("  Log dir        :", log_dir)
        print("Data             ")
        print("  Augmentations  :", augmentations)
        print("  Train size     :", len(loaders["train"]), len(train_ds))
        print("  Valid size     :", len(loaders["valid"]), len(valid_ds))
        print("  Image size     :", image_size)
        print("  Train on crops :", train_on_crops)
        print("  Balance        :", balance)
        print("  Buildings only :", only_buildings)
        print("  Post transform :", enable_post_image_transform)
        print("  Pseudolabels   :", pseudolabels_dir)
        print("Model            :", model_name)
        print("  Parameters     :", count_parameters(model))
        print("  Dropout        :", dropout)
        print("Optimizer        :", optimizer_name)
        print("  Learning rate  :", learning_rate)
        print("  Weight decay   :", weight_decay)
        print("  Scheduler      :", scheduler_name)
        print("  Batch sizes    :", train_batch_size, valid_batch_size)
        print("  Criterion      :", segmentation_losses)
        print("  Damage type    :", damage_type_loss)
        print("  Disaster type  :", disaster_type_loss)
        print(" Embedding      :", embedding_criterion)

        # model training
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "opl"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=False,
            checkpoint_data={"cmd_args": cmd_args},
        )

        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = os.path.join(log_dir, "main", "checkpoints",
                                       "best.pth")

        model_checkpoint = os.path.join(log_dir, "main", "checkpoints",
                                        f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        del optimizer, loaders
def main():
    """Run test-set inference with a single checkpoint or an ensemble.

    Command-line driven: loads every checkpoint listed in ``models``, prints a
    table of the info records returned by ``model_from_checkpoint`` together
    with mean/std of their ``score``, ``localization`` and ``damage`` fields,
    builds the final model and passes it to ``run_inference_on_dataset``.

    When more than one checkpoint is given the models are wrapped in an
    ``Ensembler`` and, depending on ``--tta`` / ``--activation``, in TTA and
    softmax wrappers. A single checkpoint is used exactly as loaded, without
    TTA (a warning is printed if ``--tta`` was requested).

    If any checkpoint fails to load, the error and the offending path are
    printed and the function returns without running inference. An invalid
    ``--weights`` length terminates the program through ``parser.error``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("models", nargs="+")
    parser.add_argument("-o", "--output-dir", type=str)
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--tta", type=str, default=None)
    parser.add_argument("-b", "--batch-size", type=int, default=1, help="Batch Size during training, e.g. -b 64")
    parser.add_argument("-w", "--workers", type=int, default=0, help="")
    parser.add_argument("-dd", "--data-dir", type=str, default="data", help="Data directory")
    parser.add_argument("-p", "--postprocessing", type=str, default="dominant")
    parser.add_argument("--size", default=1024, type=int)
    parser.add_argument("--activation", default="model", type=str)
    parser.add_argument("--weights", default=None, type=float, nargs="+")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--align", action="store_true")

    args = parser.parse_args()

    workers = args.workers
    data_dir = args.data_dir
    fast = args.fast
    tta = args.tta
    image_size = args.size, args.size
    model_checkpoints = args.models
    batch_size = args.batch_size
    activation_after = args.activation
    fp16 = args.fp16
    align = args.align
    postprocessing = args.postprocessing
    weights = args.weights
    # Validate user input through argparse rather than ``assert``: assertions are
    # stripped under ``python -O`` and give the user a traceback instead of usage help.
    if weights is not None and len(weights) != 5:
        parser.error(f"--weights expects exactly 5 values, got {len(weights)}")

    current_time = datetime.now().strftime("%b%d_%H_%M")
    if args.output_dir is None and len(model_checkpoints) == 1:
        # Default location for a single checkpoint: next to the checkpoint file,
        # with suffixes describing the inference settings.
        output_dir = os.path.join(
            os.path.dirname(model_checkpoints[0]), fs.id_from_fname(model_checkpoints[0]) + "_test_predictions"
        )
        if weights is not None:
            output_dir += "_weighted"
        if tta is not None:
            output_dir += f"_{tta}"
    else:
        output_dir = args.output_dir or f"output_dir_{current_time}"

    print("Size", image_size)
    print("Output dir", output_dir)
    print("Postproc  ", postprocessing)

    # Load models
    models = []
    infos = []
    for model_checkpoint in model_checkpoints:
        try:
            # TTA is deliberately not requested here; it is applied once on top
            # of the whole ensemble further below.
            model, info = model_from_checkpoint(
                fs.auto_file(model_checkpoint), tta=None, activation_after=activation_after, report=False
            )
            models.append(model)
            infos.append(info)
        except Exception as e:
            # Abort the whole run: an ensemble with a missing member is not meaningful.
            print(e)
            print(model_checkpoint)
            return

    df = pd.DataFrame.from_records(infos)
    print(df)

    print("score        ", df["score"].mean(), df["score"].std())
    print("localization ", df["localization"].mean(), df["localization"].std())
    print("damage       ", df["damage"].mean(), df["damage"].std())

    if len(models) > 1:
        model = Ensembler(models, [OUTPUT_MASK_KEY])
        if activation_after == "ensemble":
            model = ApplySoftmaxTo(model, OUTPUT_MASK_KEY)
            print("Applying activation after ensemble")

        if tta == "multiscale":
            print(f"Using {tta}")
            model = MultiscaleTTA(model, outputs=[OUTPUT_MASK_KEY], size_offsets=[-128, +128], average=True)

        if tta == "flip":
            print(f"Using {tta}")
            model = HFlipTTA(model, outputs=[OUTPUT_MASK_KEY], average=True)

        if tta == "flipscale":
            print(f"Using {tta}")
            model = HFlipTTA(model, outputs=[OUTPUT_MASK_KEY], average=True)
            model = MultiscaleTTA(model, outputs=[OUTPUT_MASK_KEY], size_offsets=[-128, +128], average=True)

        if tta == "multiscale_d4":
            print(f"Using {tta}")
            model = D4TTA(model, outputs=[OUTPUT_MASK_KEY], average=True)
            model = MultiscaleTTA(model, outputs=[OUTPUT_MASK_KEY], size_offsets=[-128, +128], average=True)

        if activation_after == "tta":
            model = ApplySoftmaxTo(model, OUTPUT_MASK_KEY)

    else:
        model = models[0]
        if tta is not None:
            # The checkpoint was loaded with tta=None and the TTA wrappers are
            # only applied in the ensemble branch, so make the no-op visible.
            print(f"Warning: TTA '{tta}' is only applied to ensembles and is ignored for a single model")

    test_ds = get_test_dataset(data_dir=data_dir, image_size=image_size, fast=fast, align_post=align)

    run_inference_on_dataset(
        model=model,
        dataset=test_ds,
        output_dir=output_dir,
        batch_size=batch_size,
        workers=workers,
        weights=weights,
        fp16=fp16,
        postprocessing=postprocessing,
    )
def main():
    """Fit per-checkpoint rounding coefficients on out-of-fold predictions.

    For every checkpoint given on the command line, the sibling directory
    ``<checkpoint id>_oof_predictions`` is scanned for ``_post_`` prediction
    files, which are paired by file id with the ``_post_`` target masks found
    under ``<data-dir>/tier3/masks`` and ``<data-dir>/train/masks``. An
    ``OptimizedRounder`` is scored once with all-ones coefficients ("raw") and
    once with the coefficients it fitted ("opt"). The accumulated results are
    written to ``optimized_weights_<timestamp>.csv`` after each checkpoint.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoints", nargs="+")
    parser.add_argument("-w", "--workers", type=int, default=1, help="")
    parser.add_argument("-dd", "--data-dir", type=str, default="data", help="Data directory")
    parser.add_argument("-a", "--activation", type=str, default="pre", help="")
    args = parser.parse_args()

    tier3_masks = fs.find_in_dir(os.path.join(args.data_dir, "tier3", "masks"))
    train_masks = fs.find_in_dir(os.path.join(args.data_dir, "train", "masks"))
    # Only post-disaster masks take part in the evaluation; index them by file id.
    targets_post = {
        fs.id_from_fname(mask_fname): mask_fname
        for mask_fname in tier3_masks + train_masks
        if "_post_" in mask_fname
    }

    report = defaultdict(list)
    current_time = datetime.now().strftime("%b%d_%H_%M")

    print("Checkpoints ", args.checkpoints)
    print("Activation  ", args.activation)

    for checkpoint_arg in args.checkpoints:
        checkpoint_path = fs.auto_file(checkpoint_arg)
        checkpoint_id = fs.id_from_fname(checkpoint_path)
        predictions_dir = os.path.join(os.path.dirname(checkpoint_path), checkpoint_id + "_oof_predictions")

        prediction_files_post = {
            fs.id_from_fname(pred_fname): pred_fname
            for pred_fname in fs.find_in_dir(predictions_dir)
            if "_post_" in pred_fname
        }

        # Ground truth and predictions are aligned through the shared file id.
        y_true_filenames = [targets_post[image_id] for image_id in prediction_files_post]
        y_pred_filenames = list(prediction_files_post.values())

        rounder = OptimizedRounder(workers=args.workers, apply_softmax=args.activation)

        # Baseline: unit coefficients for all five entries.
        raw_score, raw_localization_f1, raw_damage_f1, _ = rounder.predict(
            y_pred_filenames, y_true_filenames, np.ones(5, dtype=np.float32)
        )

        rounder.fit(y_pred_filenames, y_true_filenames)

        score, localization_f1, damage_f1, _ = rounder.predict(
            y_pred_filenames, y_true_filenames, rounder.coefficients()
        )

        print(rounder.coefficients())

        row = [
            ("checkpoint", fs.id_from_fname(checkpoint_path)),
            ("coefficients", rounder.coefficients()),
            ("samples", len(y_true_filenames)),
            ("raw_score", raw_score),
            ("raw_localization", raw_localization_f1),
            ("raw_damage", raw_damage_f1),
            ("opt_score", score),
            ("opt_localization", localization_f1),
            ("opt_damage", damage_f1),
        ]
        for column, value in row:
            report[column].append(value)

        # Re-written after every checkpoint so partial results survive an interrupted run.
        pd.DataFrame.from_dict(report).to_csv(f"optimized_weights_{current_time}.csv", index=None)
        print(report)
def main():
    """Compare raw regression predictions with a LightGBM classifier on stacked features.

    Loads eight pickled per-fold data frames (two model families, four folds
    each; the file names suggest aptos2019- and idrid-trained models evaluated
    on aptos2019 folds), tags each with its fold index and concatenates them
    per family. The ``features`` columns of both families are stacked
    horizontally, split 75/25 (stratified on ``diagnosis_true``),
    standardised, and used to fit an ``LGBMClassifier``. The quadratic kappa
    of the original regression-derived classes and of the LightGBM
    predictions on the held-out part are printed.
    """
    aptos_pickles = (
        'reg_seresnext50_rms_512_medium_mse_aptos2019_fold0_awesome_babbage_on_aptos2019_fold0.pkl',
        'reg_seresnext50_rms_512_medium_mse_aptos2019_fold1_hopeful_khorana_on_aptos2019_fold1.pkl',
        'reg_seresnext50_rms_512_medium_mse_aptos2019_fold2_trusting_nightingale_on_aptos2019_fold2.pkl',
        'reg_seresnext50_rms_512_medium_mse_aptos2019_fold3_epic_wing_on_aptos2019_fold3.pkl',
    )
    idrid_pickles = (
        'reg_seresnext50_rms_512_medium_mse_idrid_fold0_heuristic_ptolemy_on_aptos2019_fold0.pkl',
        'reg_seresnext50_rms_512_medium_mse_idrid_fold1_gifted_visvesvaraya_on_aptos2019_fold1.pkl',
        'reg_seresnext50_rms_512_medium_mse_idrid_fold2_sharp_brattain_on_aptos2019_fold2.pkl',
        'reg_seresnext50_rms_512_medium_mse_idrid_fold3_vibrant_minsky_on_aptos2019_fold3.pkl',
    )

    aptos_frames = []
    idrid_frames = []
    for fold, (aptos_name, idrid_name) in enumerate(zip(aptos_pickles, idrid_pickles)):
        aptos_frame = pd.read_pickle(fs.auto_file(aptos_name))
        aptos_frame['fold'] = fold
        aptos_frames.append(aptos_frame)

        idrid_frame = pd.read_pickle(fs.auto_file(idrid_name))
        idrid_frame['fold'] = fold
        idrid_frames.append(idrid_frame)

    df_aptos15 = pd.concat(aptos_frames)
    df_idrid = pd.concat(idrid_frames)

    print(*[len(frame) for frame in aptos_frames], len(df_aptos15))

    regression = np.array(df_aptos15['regression'].values.tolist())

    # Feature vectors of both model families are placed side by side per sample;
    # this relies on both frames listing the samples in the same order.
    aptos_features = np.array(df_aptos15['features'].values.tolist())
    idrid_features = np.array(df_idrid['features'].values.tolist())
    X = np.hstack((aptos_features, idrid_features))

    Y = np.array(df_aptos15['diagnosis_true'].values.tolist())

    print(X.shape, Y.shape)

    # The regression-derived class labels are split together with X and Y so the
    # baseline is scored on exactly the same held-out samples.
    baseline_classes = to_numpy(regression_to_class(regression))
    x_train, x_test, y_train, y_test, y_hat_train, y_hat_test = train_test_split(
        X, Y, baseline_classes, stratify=Y, test_size=0.25, random_state=0)

    from sklearn.preprocessing import StandardScaler
    scaler = StandardScaler()
    x_train = scaler.fit_transform(x_train)
    x_test = scaler.transform(x_test)

    classifier = LGBMClassifier(n_estimators=64,
                                max_depth=10,
                                random_state=0,
                                num_leaves=256)
    classifier.fit(x_train, y_train)
    y_pred_class = classifier.predict(x_test)

    # cohen_kappa_score here returns a 3-tuple; only the score itself is reported.
    raw_score, _, _ = cohen_kappa_score(y_hat_test, y_test, weights='quadratic')
    print('raw_score', raw_score)

    lgb_score, _, _ = cohen_kappa_score(y_pred_class, y_test, weights='quadratic')
    print('lgb_score', lgb_score)
def main():
    """Train and/or apply a single-logit train-vs-test image classifier.

    Behaviour is selected with ``--run-mode``:

    * ``fit`` / ``fit_predict``: build the model, optionally transfer weights
      (``--transfer``) or resume from a checkpoint (``--checkpoint``), and
      train it with a ``SupervisedRunner`` under ``runs/adversarial/...``
      using AUC as the main metric.
    * ``predict`` / ``fit_predict`` (skipped in ``--fast`` mode): load
      ``best.pth`` from the log directory, score every image listed in
      ``<data-dir>/train.csv`` and write the sigmoid outputs as column
      ``is_test`` to ``test_in_train.csv`` in the log directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default='data',
                        help='Data directory for INRIA sattelite dataset')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default='cls_resnet18',
                        help='')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=8,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        default=100,
                        help='Epoch to run')
    parser.add_argument('-es',
                        '--early-stopping',
                        type=int,
                        default=None,
                        help='Maximum number of epochs without improvement')
    parser.add_argument('-fe', '--freeze-encoder', action='store_true')
    parser.add_argument('-lr',
                        '--learning-rate',
                        type=float,
                        default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('-l',
                        '--criterion',
                        type=str,
                        default='bce',
                        help='Criterion')
    parser.add_argument('-o',
                        '--optimizer',
                        default='Adam',
                        help='Name of the optimizer')
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        default=None,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w',
                        '--workers',
                        default=multiprocessing.cpu_count(),
                        type=int,
                        help='Num workers')
    parser.add_argument('-a',
                        '--augmentations',
                        default='hard',
                        type=str,
                        help='')
    parser.add_argument('-tta',
                        '--tta',
                        default=None,
                        type=str,
                        help='Type of TTA to use [fliplr, d4]')
    parser.add_argument('-tm',
                        '--train-mode',
                        default='random',
                        type=str,
                        help='')
    parser.add_argument('-rm',
                        '--run-mode',
                        default='fit_predict',
                        type=str,
                        help='')
    parser.add_argument('--transfer', default=None, type=str, help='')
    parser.add_argument('--fp16', action='store_true')

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    early_stopping = args.early_stopping
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    train_mode = args.train_mode
    run_mode = args.run_mode
    log_dir = None
    fp16 = args.fp16
    freeze_encoder = args.freeze_encoder

    run_train = run_mode in ('fit_predict', 'fit')
    run_predict = run_mode in ('fit_predict', 'predict')

    model = maybe_cuda(get_model(model_name, num_classes=1))

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transfering weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint['model_state_dict']

        # Load tensors one at a time so a single incompatible entry
        # (e.g. a shape mismatch) is reported and skipped instead of
        # aborting the whole transfer.
        for name, value in pretrained_dict.items():
            try:
                model.load_state_dict(collections.OrderedDict([(name, value)]),
                                      strict=False)
            except Exception as e:
                print(e)

    checkpoint = None
    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        checkpoint_epoch = checkpoint['epoch']
        print('Loaded model weights from:', args.checkpoint)
        print('Epoch                    :', checkpoint_epoch)
        print('Metrics (Train):', 'f1  :',
              checkpoint['epoch_metrics']['train']['f1_score'], 'loss:',
              checkpoint['epoch_metrics']['train']['loss'])
        print('Metrics (Valid):', 'f1  :',
              checkpoint['epoch_metrics']['valid']['f1_score'], 'loss:',
              checkpoint['epoch_metrics']['valid']['loss'])

        # Two levels above the checkpoint file; used by the predict step to
        # locate best.pth when no training run replaces it below.
        log_dir = os.path.dirname(
            os.path.dirname(fs.auto_file(args.checkpoint)))

    if run_train:

        if freeze_encoder:
            set_trainable(model.encoder, trainable=False, freeze_bn=True)

        criterion = get_loss(args.criterion)
        parameters = get_optimizable_parameters(model)
        optimizer = get_optimizer(optimizer_name, parameters, learning_rate)

        if checkpoint is not None:
            # Best effort: the optimizer state may not match the current setup.
            try:
                unpack_checkpoint(checkpoint, optimizer=optimizer)
                print('Restored optimizer state from checkpoint')
            except Exception as e:
                print('Failed to restore optimizer state from checkpoint', e)

        train_loader, valid_loader = get_dataloaders(
            data_dir=data_dir,
            batch_size=batch_size,
            num_workers=num_workers,
            image_size=image_size,
            augmentation=augmentations,
            fast=fast)

        loaders = collections.OrderedDict()
        loaders["train"] = train_loader
        loaders["valid"] = valid_loader

        current_time = datetime.now().strftime('%b%d_%H_%M')
        prefix = f'adversarial/{args.model}/{current_time}_{args.criterion}'

        if fp16:
            prefix += '_fp16'

        if fast:
            prefix += '_fast'

        log_dir = os.path.join('runs', prefix)
        # exist_ok=False: refuse to write into an already existing run directory.
        os.makedirs(log_dir, exist_ok=False)

        scheduler = MultiStepLR(optimizer,
                                milestones=[10, 30, 50, 70, 90],
                                gamma=0.5)

        print('Train session    :', prefix)
        print('\tFP16 mode      :', fp16)
        print('\tFast mode      :', args.fast)
        print('\tTrain mode     :', train_mode)
        print('\tEpochs         :', num_epochs)
        print('\tEarly stopping :', early_stopping)
        print('\tWorkers        :', num_workers)
        print('\tData dir       :', data_dir)
        print('\tLog dir        :', log_dir)
        print('\tAugmentations  :', augmentations)
        print('\tTrain size     :', len(train_loader),
              len(train_loader.dataset))
        print('\tValid size     :', len(valid_loader),
              len(valid_loader.dataset))
        print('Model            :', model_name)
        print('\tParameters     :', count_parameters(model))
        print('\tImage size     :', image_size)
        print('\tFreeze encoder :', freeze_encoder)
        print('Optimizer        :', optimizer_name)
        print('\tLearning rate  :', learning_rate)
        print('\tBatch size     :', batch_size)
        print('\tCriterion      :', args.criterion)

        # model training
        visualization_fn = partial(draw_classification_predictions,
                                   class_names=['Train', 'Test'])

        callbacks = [
            F1ScoreCallback(),
            AUCCallback(),
            ShowPolarBatchesCallback(visualization_fn,
                                     metric='f1_score',
                                     minimize=False),
        ]

        if early_stopping:
            callbacks += [
                EarlyStoppingCallback(early_stopping,
                                      metric='auc',
                                      minimize=False)
            ]

        runner = SupervisedRunner(input_key='image')
        runner.train(fp16=fp16,
                     model=model,
                     criterion=criterion,
                     optimizer=optimizer,
                     scheduler=scheduler,
                     callbacks=callbacks,
                     loaders=loaders,
                     logdir=log_dir,
                     num_epochs=num_epochs,
                     verbose=True,
                     main_metric='auc',
                     minimize_metric=False,
                     state_kwargs={"cmd_args": vars(args)})

    if run_predict and not fast:
        # Training is finished. Let's run predictions using best checkpoint weights
        best_checkpoint = load_checkpoint(
            fs.auto_file('best.pth', where=log_dir))
        unpack_checkpoint(best_checkpoint, model=model)

        model.eval()

        # The *training* images are scored: the output tells how test-like
        # each training image looks to the classifier.
        train_csv = pd.read_csv(os.path.join(data_dir, 'train.csv'))
        train_csv['id_code'] = train_csv['id_code'].apply(
            lambda x: os.path.join(data_dir, 'train_images', f'{x}.png'))
        test_ds = RetinopathyDataset(train_csv['id_code'],
                                     None,
                                     get_test_aug(image_size),
                                     target_as_array=True)
        test_dl = DataLoader(test_ds,
                             batch_size,
                             pin_memory=True,
                             num_workers=num_workers)

        test_ids = []
        test_preds = []

        # torch.no_grad() only disables autograd while it is entered as a
        # context manager (or used as a decorator); a bare call has no effect.
        with torch.no_grad():
            for batch in tqdm(test_dl, desc='Inference'):
                images = batch['image'].cuda()
                outputs = model(images)
                predictions = to_numpy(outputs['logits'].sigmoid().squeeze(1))
                test_ids.extend(batch['image_id'])
                test_preds.extend(predictions)

        df = pd.DataFrame.from_dict({
            'id_code': test_ids,
            'is_test': test_preds
        })
        df.to_csv(os.path.join(log_dir, 'test_in_train.csv'), index=None)
Esempio n. 21
0
def _build_stage_callbacks(default_callbacks, loss_callbacks, accumulation_steps, hparams):
    """Return a new callback list for one training stage.

    The result is a fresh list (default callbacks, then loss callbacks, then
    an optimizer callback and a hyper-parameter callback), so a stage may
    extend it without touching ``default_callbacks``.
    """
    return (
        default_callbacks
        + loss_callbacks
        + [
            OptimizerCallback(accumulation_steps=accumulation_steps, decouple_weight_decay=False),
            HyperParametersCallback(hparam_dict=dict(hparams)),
        ]
    )


def _build_stage_loaders(train_ds, valid_ds, train_sampler, train_batch_size, valid_batch_size, num_workers):
    """Return an ordered ``{"train": ..., "valid": ...}`` mapping of data loaders.

    The train loader drops the last incomplete batch and shuffles only when no
    sampler is given, because ``shuffle`` and ``sampler`` are mutually exclusive.
    """
    loaders = collections.OrderedDict()
    loaders["train"] = DataLoader(
        train_ds,
        batch_size=train_batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        shuffle=train_sampler is None,
        sampler=train_sampler,
    )
    loaders["valid"] = DataLoader(valid_ds, batch_size=valid_batch_size, num_workers=num_workers, pin_memory=True)
    return loaders


def _print_stage_summary(
    args,
    *,
    checkpoint_prefix,
    log_dir,
    image_size,
    model,
    loaders,
    train_ds,
    valid_ds,
    stage_epochs,
    augmentations,
    obliterate_p,
    negative_image_dir,
    mixup,
    cutmix,
    tsa,
    optimizer_name,
    learning_rate,
    scheduler_name,
    batch_sizes,
):
    """Print the settings that are in effect for one training stage.

    Run-wide values are read from ``args``; everything that can differ between
    the warmup, main and fine-tune stages is passed explicitly so the printout
    reflects what the stage really uses.
    """
    dropout_note = ("(Non-default)",) if args.dropout is not None else ()

    print("Train session    :", checkpoint_prefix)
    print("  FP16 mode      :", args.fp16)
    print("  Fast mode      :", args.fast)
    print("  Epochs         :", stage_epochs)
    print("  Workers        :", args.workers)
    print("  Data dir       :", args.data_dir)
    print("  Log dir        :", log_dir)
    print("  Cache          :", args.cache)
    print("Data              ")
    print("  Augmentations  :", augmentations)
    print("  Obliterate (%) :", obliterate_p)
    print("  Negative images:", negative_image_dir)
    print("  Train size     :", len(loaders["train"]), "batches", len(train_ds), "samples")
    print("  Valid size     :", len(loaders["valid"]), "batches", len(valid_ds), "samples")
    print("  Image size     :", image_size)
    print("  Balance        :", args.balance)
    print("  Mixup          :", mixup)
    print("  CutMix         :", cutmix)
    print("  TSA            :", tsa)
    print("Model            :", args.model)
    print("  Parameters     :", count_parameters(model))
    print("  Dropout        :", args.dropout, *dropout_note)
    print("Optimizer        :", optimizer_name)
    print("  Learning rate  :", learning_rate)
    print("  Weight decay   :", args.weight_decay)
    print("  Scheduler      :", scheduler_name)
    print("  Batch sizes    :", *batch_sizes)
    print("Losses            ")
    print("  Flag           :", args.modification_flag_loss)
    print("  Type           :", args.modification_type_loss)
    print("  Embedding      :", args.embedding_loss)
    print("  Feature maps   :", args.feature_maps_loss)
    print("  Mask           :", args.mask_loss)
    print("  Bits           :", args.bits_loss)


def main():
    """Command-line entry point: train a model in up to three stages.

    Stages, each optional and run in this order on the same model instance:

    * warmup (``--warmup N``): N epochs with the "Ranger" optimizer at a fixed
      3e-4 learning rate, no scheduler and no obliteration; the encoder is
      frozen first when ``--freeze-encoder`` is given.
    * main (``--epochs N`` with N > 0): the optimizer/scheduler chosen on the
      command line, optionally with extra negative images.
    * fine-tune (``--fine-tune N``): N epochs of "SGD" with the "cos"
      scheduler, "light" augmentations and mixup/cutmix/TSA disabled.

    Every stage logs into its own sub-directory of ``runs/<prefix>`` and its
    best checkpoint is exported next to it via ``clean_checkpoint``.

    Exits with status 2 (``parser.error``) when none of the flag, type or
    embedding losses is configured. Raises ``FileExistsError`` when the run
    directory already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--obliterate", type=float, default=0, help="Chance of obliteration")
    parser.add_argument(
        "-nid", "--negative-image-dir", type=str, default=None, help="Directory with additional negative images"
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--cache", action="store_true")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-m", "--model", type=str, default="resnet34", help="")
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch Size during training, e.g. -b 64")
    parser.add_argument(
        "-wbs",
        "--warmup-batch-size",
        type=int,
        default=None,
        help="Batch size during warmup (defaults to --batch-size)",
    )
    parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run")
    parser.add_argument(
        "-es", "--early-stopping", type=int, default=None, help="Maximum number of epochs without improvement"
    )
    parser.add_argument("-fe", "--freeze-encoder", action="store_true", help="Freeze encoder parameters for N epochs")
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate")

    # Each loss option may be repeated; every occurrence takes one or more
    # tokens, e.g. "-l ce 1.0" yields [["ce", "1.0"]].
    parser.add_argument("-l", "--modification-flag-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--modification-type-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--embedding-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--feature-maps-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--mask-loss", type=str, default=None, action="append", nargs="+")
    parser.add_argument("--bits-loss", type=str, default=None, action="append", nargs="+")

    parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer")
    parser.add_argument(
        "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights"
    )
    parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers")
    parser.add_argument("-a", "--augmentations", default="safe", type=str, help="Level of image augmentations")
    parser.add_argument("--transfer", default=None, type=str, help="")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--cutmix", action="store_true")
    parser.add_argument("--tsa", action="store_true")
    parser.add_argument("--fold", default=None, type=int)
    parser.add_argument("-s", "--scheduler", default=None, type=str, help="")
    parser.add_argument("-x", "--experiment", default=None, type=str, help="")
    parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer")
    parser.add_argument(
        "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters"
    )
    parser.add_argument(
        "--fine-tune", default=0, type=int, help="Number of fine-tuning epochs to run after the main stage"
    )
    parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay")
    parser.add_argument("--show", action="store_true")
    parser.add_argument("--balance", action="store_true")
    parser.add_argument("--freeze-bn", action="store_true")

    args = parser.parse_args()

    # Validate before doing any work. parser.error (unlike assert) still runs
    # under "python -O" and reports the problem as a normal usage error.
    if not (args.modification_flag_loss or args.modification_type_loss or args.embedding_loss):
        parser.error("At least one of losses must be set")

    set_manual_seed(args.seed)

    modification_flag_loss = args.modification_flag_loss
    modification_type_loss = args.modification_type_loss
    embedding_loss = args.embedding_loss
    feature_maps_loss = args.feature_maps_loss
    mask_loss = args.mask_loss
    bits_loss = args.bits_loss

    freeze_encoder = args.freeze_encoder
    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    learning_rate = args.learning_rate
    model_name: str = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)
    fast = args.fast
    augmentations = args.augmentations
    fp16 = args.fp16
    scheduler_name = args.scheduler
    experiment = args.experiment
    dropout = args.dropout
    verbose = args.verbose
    warmup = args.warmup
    show = args.show
    accumulation_steps = args.accumulation_steps
    weight_decay = args.weight_decay
    fold = args.fold
    balance = args.balance
    freeze_bn = args.freeze_bn
    train_batch_size = args.batch_size
    mixup = args.mixup
    cutmix = args.cutmix
    tsa = args.tsa
    fine_tune = args.fine_tune
    obliterate_p = args.obliterate
    negative_image_dir = args.negative_image_dir
    warmup_batch_size = args.warmup_batch_size or args.batch_size

    # Validation uses the same batch size as training.
    valid_batch_size = train_batch_size
    run_train = num_epochs > 0

    custom_model_kwargs = {}
    if dropout is not None:
        custom_model_kwargs["dropout"] = float(dropout)

    if embedding_loss is not None:
        custom_model_kwargs["need_embedding"] = True

    model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda()
    # NOTE(review): if required_features is a plain list attribute, the append
    # below also changes the model's own list - confirm this is intended.
    required_features = model.required_features

    if mask_loss is not None:
        required_features.append(INPUT_TRUE_MODIFICATION_MASK)

    if args.transfer:
        transfer_checkpoint = fs.auto_file(args.transfer)
        print("Transferring weights from model checkpoint", transfer_checkpoint)
        checkpoint = load_checkpoint(transfer_checkpoint)
        pretrained_dict = checkpoint["model_state_dict"]

        transfer_weights(model, pretrained_dict)

    if args.checkpoint:
        checkpoint = load_checkpoint(fs.auto_file(args.checkpoint))
        unpack_checkpoint(checkpoint, model=model)

        print("Loaded model weights from:", args.checkpoint)
        report_checkpoint(checkpoint)

    if freeze_bn:
        from pytorch_toolbelt.optimization.functional import freeze_model

        freeze_model(model, freeze_bn=True)
        print("Freezing bn params")

    main_metric = "loss"
    main_metric_minimize = True

    current_time = datetime.now().strftime("%b%d_%H_%M")
    checkpoint_prefix = f"{current_time}_{args.model}_fold{fold}"

    if fp16:
        checkpoint_prefix += "_fp16"

    if fast:
        checkpoint_prefix += "_fast"

    if mixup:
        checkpoint_prefix += "_mixup"

    if cutmix:
        checkpoint_prefix += "_cutmix"

    # An explicit experiment name replaces the generated prefix entirely.
    if experiment is not None:
        checkpoint_prefix = experiment

    log_dir = os.path.join("runs", checkpoint_prefix)
    # exist_ok=False: refuse to write into the directory of an earlier run.
    os.makedirs(log_dir, exist_ok=False)

    config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json")
    with open(config_fname, "w") as f:
        train_session_args = vars(args)
        f.write(json.dumps(train_session_args, indent=2))

    default_callbacks = []

    if show:
        default_callbacks += [ShowPolarBatchesCallback(draw_predictions, metric="loss", minimize=True)]

    # Keyword arguments shared by every stage.
    dataset_kwargs = dict(data_dir=data_dir, balance=balance, fast=fast, fold=fold, features=required_features)
    loss_kwargs = dict(
        modification_flag=modification_flag_loss,
        modification_type=modification_type_loss,
        embedding_loss=embedding_loss,
        feature_maps_loss=feature_maps_loss,
        mask_loss=mask_loss,
        bits_loss=bits_loss,
    )
    # Logged with every stage; describes the run as configured on the command line.
    hparams = {
        "model": model_name,
        "scheduler": scheduler_name,
        "optimizer": optimizer_name,
        "augmentations": augmentations,
        "size": image_size[0],
        "weight_decay": weight_decay,
    }
    summary_kwargs = dict(checkpoint_prefix=checkpoint_prefix, log_dir=log_dir, image_size=image_size, model=model)

    # Pretrain/warmup
    if warmup:
        warmup_optimizer_name = "Ranger"
        warmup_learning_rate = 3e-4

        train_ds, valid_ds, train_sampler = get_datasets(augmentation=augmentations, obliterate_p=0, **dataset_kwargs)

        criterions_dict, loss_callbacks = get_criterions(
            num_epochs=warmup, mixup=mixup, cutmix=cutmix, tsa=tsa, **loss_kwargs
        )

        callbacks = _build_stage_callbacks(default_callbacks, loss_callbacks, accumulation_steps, hparams)
        loaders = _build_stage_loaders(
            train_ds, valid_ds, train_sampler, warmup_batch_size, warmup_batch_size, num_workers
        )

        if freeze_encoder:
            from pytorch_toolbelt.optimization.functional import freeze_model

            freeze_model(model.encoder, freeze_parameters=True, freeze_bn=None)

        optimizer = get_optimizer(
            warmup_optimizer_name,
            get_optimizable_parameters(model),
            weight_decay=weight_decay,
            learning_rate=warmup_learning_rate,
        )
        scheduler = None

        _print_stage_summary(
            args,
            loaders=loaders,
            train_ds=train_ds,
            valid_ds=valid_ds,
            stage_epochs=warmup,
            augmentations=augmentations,
            obliterate_p=0,
            negative_image_dir=None,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
            optimizer_name=warmup_optimizer_name,
            learning_rate=warmup_learning_rate,
            scheduler_name=None,
            batch_sizes=(warmup_batch_size, warmup_batch_size),
            **summary_kwargs,
        )

        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "warmup"),
            num_epochs=warmup,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        # Export the best warmup checkpoint. The in-memory model is not
        # reloaded from it; the next stage starts from the current weights.
        best_checkpoint = os.path.join(log_dir, "warmup", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_warmup.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        torch.cuda.empty_cache()
        gc.collect()

    if run_train:
        train_ds, valid_ds, train_sampler = get_datasets(
            augmentation=augmentations, obliterate_p=obliterate_p, **dataset_kwargs
        )

        if negative_image_dir:
            negatives_ds = get_negatives_ds(
                negative_image_dir, fold=fold, features=required_features, max_images=16536
            )
            train_ds = train_ds + negatives_ds
            train_sampler = None  # TODO: Add proper support of sampler
            print("Adding", len(negatives_ds), "negative samples to training set")

        criterions_dict, loss_callbacks = get_criterions(
            num_epochs=num_epochs, mixup=mixup, cutmix=cutmix, tsa=tsa, **loss_kwargs
        )

        callbacks = _build_stage_callbacks(default_callbacks, loss_callbacks, accumulation_steps, hparams)
        loaders = _build_stage_loaders(
            train_ds, valid_ds, train_sampler, train_batch_size, valid_batch_size, num_workers
        )

        _print_stage_summary(
            args,
            loaders=loaders,
            train_ds=train_ds,
            valid_ds=valid_ds,
            stage_epochs=num_epochs,
            augmentations=augmentations,
            obliterate_p=obliterate_p,
            negative_image_dir=negative_image_dir,
            mixup=mixup,
            cutmix=cutmix,
            tsa=tsa,
            optimizer_name=optimizer_name,
            learning_rate=learning_rate,
            scheduler_name=scheduler_name,
            batch_sizes=(train_batch_size, valid_batch_size),
            **summary_kwargs,
        )

        optimizer = get_optimizer(
            optimizer_name, get_optimizable_parameters(model), learning_rate=learning_rate, weight_decay=weight_decay
        )
        scheduler = get_scheduler(
            scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"])
        )
        # CyclicLR gets a per-batch scheduler callback.
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "main"),
            num_epochs=num_epochs,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        del optimizer, loaders, runner, callbacks

        # Export the best main-stage checkpoint; the in-memory model keeps
        # its current weights.
        best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth")
        clean_checkpoint(best_checkpoint, model_checkpoint)

        torch.cuda.empty_cache()
        gc.collect()

    if fine_tune:
        finetune_augmentations = "light"
        finetune_optimizer_name = "SGD"
        finetune_scheduler_name = "cos"

        train_ds, valid_ds, train_sampler = get_datasets(
            augmentation=finetune_augmentations, obliterate_p=obliterate_p, **dataset_kwargs
        )

        criterions_dict, loss_callbacks = get_criterions(
            num_epochs=fine_tune, mixup=False, cutmix=False, tsa=False, **loss_kwargs
        )

        callbacks = _build_stage_callbacks(default_callbacks, loss_callbacks, accumulation_steps, hparams)
        loaders = _build_stage_loaders(
            train_ds, valid_ds, train_sampler, train_batch_size, valid_batch_size, num_workers
        )

        _print_stage_summary(
            args,
            loaders=loaders,
            train_ds=train_ds,
            valid_ds=valid_ds,
            stage_epochs=fine_tune,
            augmentations=finetune_augmentations,
            obliterate_p=obliterate_p,
            negative_image_dir=None,
            mixup=False,
            cutmix=False,
            tsa=False,
            optimizer_name=finetune_optimizer_name,
            learning_rate=learning_rate,
            scheduler_name=finetune_scheduler_name,
            batch_sizes=(train_batch_size, valid_batch_size),
            **summary_kwargs,
        )

        optimizer = get_optimizer(
            finetune_optimizer_name,
            get_optimizable_parameters(model),
            learning_rate=learning_rate,
            weight_decay=weight_decay,
        )
        scheduler = get_scheduler(
            finetune_scheduler_name,
            optimizer,
            lr=learning_rate,
            num_epochs=fine_tune,
            batches_in_epoch=len(loaders["train"]),
        )
        if isinstance(scheduler, CyclicLR):
            callbacks += [SchedulerCallback(mode="batch")]

        # model training
        runner = SupervisedRunner(input_key=required_features, output_key=None)
        runner.train(
            fp16=fp16,
            model=model,
            criterion=criterions_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders=loaders,
            logdir=os.path.join(log_dir, "finetune"),
            num_epochs=fine_tune,
            verbose=verbose,
            main_metric=main_metric,
            minimize_metric=main_metric_minimize,
            checkpoint_data={"cmd_args": vars(args)},
        )

        best_checkpoint = os.path.join(log_dir, "finetune", "checkpoints", "best.pth")
        model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}_finetune.pth")

        # Unlike the earlier stages, the exported best weights are loaded back
        # into the model here.
        clean_checkpoint(best_checkpoint, model_checkpoint)
        unpack_checkpoint(load_checkpoint(model_checkpoint), model=model)

        del optimizer, loaders, runner, callbacks
Esempio n. 22
0
def main():
    """Command-line entry point: predict masks for all training images.

    Loads the checkpoint given by ``--checkpoint`` into the model selected by
    ``--model``, prints the epoch and the train/valid metrics stored in the
    checkpoint, then runs ``predict`` on every file found under
    ``<data-dir>/train/images`` and writes the result, scaled to 8-bit, into
    an ``evaluation`` directory two levels above the checkpoint file.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-m", "--model", type=str, default="unet", help="")
    cli.add_argument("-dd", "--data-dir", type=str, default=None, required=True, help="Data dir")
    cli.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default=None,
        required=True,
        help="Checkpoint filename to use as initial model weights",
    )
    cli.add_argument("-b", "--batch-size", type=int, default=16, help="Batch size for inference")
    cli.add_argument("-tta", "--tta", default=None, type=str, help="Type of TTA to use [fliplr, d4]")
    args = cli.parse_args()

    data_dir = args.data_dir
    checkpoint_file = auto_file(args.checkpoint)

    # Results go next to the directory that contains the checkpoint.
    checkpoint_parent = os.path.dirname(os.path.dirname(checkpoint_file))
    out_dir = os.path.join(checkpoint_parent, "evaluation")
    os.makedirs(out_dir, exist_ok=True)

    model = get_model(args.model)

    checkpoint = UtilsFactory.load_checkpoint(checkpoint_file)
    epoch = checkpoint["epoch"]
    print("Loaded model weights from", args.checkpoint)
    print("Epoch   :", epoch)

    # Report the metrics stored in the checkpoint, train first, then valid.
    for title, split in (("Metrics (Train):", "train"), ("Metrics (Valid):", "valid")):
        split_metrics = checkpoint["epoch_metrics"][split]
        print(title, "IoU:", split_metrics["jaccard"], "Acc:", split_metrics["accuracy"])

    UtilsFactory.unpack_checkpoint(checkpoint, model=model)

    model = model.cuda()
    model = model.eval()

    images_dir = os.path.join(data_dir, "train", "images")
    train_images = find_in_dir(images_dir)
    for image_path in tqdm(train_images, total=len(train_images)):
        prediction = predict(
            model,
            read_inria_rgb(image_path),
            tta=args.tta,
            image_size=(512, 512),
            batch_size=args.batch_size,
            activation="sigmoid",
        )
        # Scale by 255 and store as 8-bit, keeping the source file name.
        mask_u8 = (prediction * 255).astype(np.uint8)
        destination = os.path.join(out_dir, os.path.basename(image_path))
        cv2.imwrite(destination, mask_u8)