Ejemplo n.º 1
0
def run_dual_inference_on_dataset_oof(model,
                                      dataset,
                                      output_dir,
                                      batch_size=1,
                                      workers=0):
    """Run dual (pre/post image) inference over a dataset and dump raw masks as .npy.

    For each batch the model receives pre- and post-disaster images and returns
    two masks. Both are resized to 1024x1024 if needed, converted to float32
    numpy arrays, and saved per image into ``output_dir``: the localization
    mask under the pre-image id and the damage mask under the matching
    post-image id (derived by replacing "_pre" with "_post").

    Args:
        model: Model callable as ``model(image_pre=..., image_post=...)``
            returning a dict with OUTPUT_MASK_PRE_KEY / OUTPUT_MASK_POST_KEY.
        dataset: Dataset yielding dicts with pre/post image tensors and ids.
        output_dir: Directory where per-image .npy predictions are written.
        batch_size: Inference batch size.
        workers: Number of DataLoader workers.
    """
    model = model.cuda()
    # Spread inference across all visible GPUs
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model = model.eval()

    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             pin_memory=True,
                             num_workers=workers)

    os.makedirs(output_dir, exist_ok=True)

    for batch in tqdm(data_loader):
        image_pre = batch[INPUT_IMAGE_PRE_KEY].cuda(non_blocking=True)
        image_post = batch[INPUT_IMAGE_POST_KEY].cuda(non_blocking=True)
        image_ids = batch[INPUT_IMAGE_ID_KEY]

        output = model(image_pre=image_pre, image_post=image_post)

        masks_pre = output[OUTPUT_MASK_PRE_KEY]
        # Upsample to the full 1024x1024 resolution if the model predicts smaller masks
        if masks_pre.size(2) != 1024 or masks_pre.size(3) != 1024:
            masks_pre = F.interpolate(masks_pre,
                                      size=(1024, 1024),
                                      mode="bilinear",
                                      align_corners=False)
        # squeeze(1) drops the channel dim — presumably the pre-mask is
        # single-channel (binary localization); TODO confirm against the model
        masks_pre = to_numpy(masks_pre.squeeze(1)).astype(np.float32)

        masks_post = output[OUTPUT_MASK_POST_KEY]
        if masks_post.size(2) != 1024 or masks_post.size(3) != 1024:
            masks_post = F.interpolate(masks_post,
                                       size=(1024, 1024),
                                       mode="bilinear",
                                       align_corners=False)
        # Note: no squeeze here — the post mask keeps its channel dimension
        # (looks like a multi-class damage prediction; verify against the model)
        masks_post = to_numpy(masks_post).astype(np.float32)

        for i, image_id in enumerate(image_ids):
            localization_image = masks_pre[i]
            damage_image = masks_post[i]

            localization_fname = os.path.join(
                output_dir, fs.change_extension(image_id, ".npy"))
            np.save(localization_fname, localization_image)

            damage_fname = os.path.join(
                output_dir,
                fs.change_extension(image_id.replace("_pre", "_post"), ".npy"))
            np.save(damage_fname, damage_image)

    del data_loader
Ejemplo n.º 2
0
def main():
    """Write per-8x8-block change masks for each stego method against its cover.

    For every cover image and its JMiPOD/JUNIWARD/UERD counterparts, decodes
    BGR pixels from the sibling DCT (.npz) files, marks every 8x8 block where
    the stego image differs from the cover, and saves binary masks as .png
    files next to the images. The cover mask is the union of the three
    per-method masks.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        default=os.environ.get("KAGGLE_2020_ALASKA2"))

    args = parser.parse_args()
    data_dir = args.data_dir

    # Keep distinct names for the file lists so the decoded images below
    # do not shadow them (the original code reused `cover`, `jimi`, ...)
    cover_files = fs.find_images_in_dir(os.path.join(data_dir, "Cover"))
    jimi_files = fs.find_images_in_dir(os.path.join(data_dir, "JMiPOD"))
    juni_files = fs.find_images_in_dir(os.path.join(data_dir, "JUNIWARD"))
    uerd_files = fs.find_images_in_dir(os.path.join(data_dir, "UERD"))

    for cover_fname, jimi_fname, juni_fname, uerd_fname in zip(
            tqdm(cover_files), jimi_files, juni_files, uerd_files):
        cover = decode_bgr_from_dct(fs.change_extension(cover_fname, ".npz"))
        jimi = decode_bgr_from_dct(fs.change_extension(jimi_fname, ".npz"))
        juni = decode_bgr_from_dct(fs.change_extension(juni_fname, ".npz"))
        uerd = decode_bgr_from_dct(fs.change_extension(uerd_fname, ".npz"))

        # A block counts as modified if any pixel within it differs from the cover
        jimi_mask = block8_sum(np.abs(cover - jimi).sum(axis=2)) > 0
        juni_mask = block8_sum(np.abs(cover - juni).sum(axis=2)) > 0
        uerd_mask = block8_sum(np.abs(cover - uerd).sum(axis=2)) > 0

        cover_mask = jimi_mask | juni_mask | uerd_mask

        # Cast to uint8: cv2.imwrite does not reliably accept bool/int64 arrays
        cv2.imwrite(fs.change_extension(cover_fname, ".png"),
                    (cover_mask * 255).astype(np.uint8))
        cv2.imwrite(fs.change_extension(jimi_fname, ".png"),
                    (jimi_mask * 255).astype(np.uint8))
        cv2.imwrite(fs.change_extension(juni_fname, ".png"),
                    (juni_mask * 255).astype(np.uint8))
        cv2.imwrite(fs.change_extension(uerd_fname, ".png"),
                    (uerd_mask * 255).astype(np.uint8))
Ejemplo n.º 3
0
def compute_mean_std(dataset):
    """Estimate per-channel (Y, Cb, Cr) mean and std over a dataset.

    https://stats.stackexchange.com/questions/25848/how-to-sum-a-standard-deviation

    Note: the returned std averages per-image variances and therefore ignores
    the variance of the per-image means — it is an approximation of the
    global std, not the exact value.

    Args:
        dataset: Iterable of image file names; DCT coefficients are read from
            the sibling .npz file of each image.

    Returns:
        Tuple ``(mean, std)`` of float64 numpy arrays with shape (3,).

    Raises:
        ValueError: If ``dataset`` yields no items (would otherwise divide by zero).
    """

    global_mean = np.zeros(3, dtype=np.float64)
    global_var = np.zeros(3, dtype=np.float64)

    n_items = 0

    for image_fname in dataset:
        dct_file = np.load(fs.change_extension(image_fname, ".npz"))
        # This normalization roughly puts values into zero mean and unit variance
        y = idct8v2(dct_file["dct_y"])
        cb = idct8v2(dct_file["dct_cb"])
        cr = idct8v2(dct_file["dct_cr"])

        global_mean[0] += y.mean()
        global_mean[1] += cb.mean()
        global_mean[2] += cr.mean()

        global_var[0] += y.std()**2
        global_var[1] += cb.std()**2
        global_var[2] += cr.std()**2

        n_items += 1

    # Guard against an empty dataset instead of silently returning nan
    if n_items == 0:
        raise ValueError("Cannot compute statistics of an empty dataset")

    return global_mean / n_items, np.sqrt(global_var / n_items)
def compute_mean_std(dataset):
    """Compute running mean/std over stacked (Y, Cb, Cr) DCT planes of a dataset.

    https://stats.stackexchange.com/questions/25848/how-to-sum-a-standard-deviation

    Each image's DCT coefficient planes are read from the sibling .npz file,
    stacked into a (1, 3, H, W) float tensor, transformed by ``sd2`` and fed
    into a RunningStatistics accumulator.

    Args:
        dataset: Iterable of image file names.

    Returns:
        Tuple ``(mean, std)`` as accumulated by RunningStatistics.
    """
    # Removed an unused `n_items` counter and stale commented-out code
    # (the RunningStatistics accumulator replaces the manual accumulation).
    s = RunningStatistics()

    for image_fname in dataset:
        dct_file = np.load(fs.change_extension(image_fname, ".npz"))
        y = torch.from_numpy(dct_file["dct_y"])
        cb = torch.from_numpy(dct_file["dct_cb"])
        cr = torch.from_numpy(dct_file["dct_cr"])

        dct = torch.stack([y, cb, cr], dim=0).unsqueeze(0).float()
        dct = sd2(dct)[0]
        s.update(dct)

    return s.mean, s.std
Ejemplo n.º 5
0
def main():
    """Dump JPEG quality factor / quantization table stats for the Test set to CSV.

    For every test image, reads the quantization matrices from the sibling
    .npz file, derives the quality factor, records the file size, and writes
    the resulting table to ``test_dataset_qf_qt.csv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        default=os.environ.get("KAGGLE_2020_ALASKA2"))
    data_dir = parser.parse_args().data_dir

    images = fs.find_images_in_dir(os.path.join(data_dir, "Test"))

    records = defaultdict(list)
    for fname in tqdm(images):
        payload = np.load(fs.change_extension(fname, ".npz"))
        qm0 = payload["qm0"]
        qm1 = payload["qm1"]

        records["image_id"].append(os.path.basename(fname))
        records["quality"].append(quality_factror_from_qm(qm0))
        records["qm0"].append(qm0.flatten().tolist())
        records["qm1"].append(qm1.flatten().tolist())
        records["file_size"].append(os.stat(fname).st_size)

    pd.DataFrame.from_dict(records).to_csv("test_dataset_qf_qt.csv",
                                           index=False)
Ejemplo n.º 6
0
def compute_statistics(cover_fname):
    """Compare a cover image against its JMiPOD/JUNIWARD/UERD stego counterparts.

    For each method, decodes the stego image from DCT, counts pixel
    differences against the cover, and writes a binary change mask as a .png
    next to the stego image.

    Args:
        cover_fname: Path to a cover image (stego paths are derived by
            replacing "Cover" with the method name).

    Returns:
        defaultdict(list) with columns "image", "method" and "pd"
        (pixel-difference count), one row per method.
    """
    results_df = defaultdict(list)

    cover = read_from_dct(cover_fname)

    for method_name in ["JMiPOD", "JUNIWARD", "UERD"]:
        stego_fname = cover_fname.replace("Cover", method_name)
        stego = read_from_dct(stego_fname)

        mask_fname = fs.change_extension(stego_fname, ".png")
        mask = compute_mask(cover, stego)

        results_df["image"].append(os.path.basename(cover_fname))
        # method_name is a plain name, not a path — the original
        # os.path.basename(method_name) call was a no-op and is removed
        results_df["method"].append(method_name)
        results_df["pd"].append(count_pixel_difference(cover, stego))

        cv2.imwrite(mask_fname, ((mask > 0) * 255).astype(np.uint8))

    return results_df
Ejemplo n.º 7
0
def main():
    """Run retinopathy inference for a single checkpoint on the Aptos-2019 test set.

    Loads the checkpoint, downloads the fundus photographs referenced by the
    test csv, runs inference with flip-lr TTA and pickles the resulting
    predictions DataFrame next to the checkpoint name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input', nargs='+')
    parser.add_argument('--need-features', action='store_true')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=multiprocessing.cpu_count(),
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-w', '--workers', type=int, default=4, help='')

    args = parser.parse_args()
    need_features = args.need_features
    batch_size = args.batch_size
    num_workers = args.workers
    # BUG FIX: args.input is a LIST (nargs='+'); passing the list to
    # os.path.join / fs.change_extension would raise. Only a single
    # checkpoint is supported, so take the first element.
    checkpoint_fname = args.input[0]

    # Fixed (not changing) variables
    data_dir = '/opt/ml/code/'
    checkpoint_path = os.path.join(data_dir, 'model', checkpoint_fname)

    def current_milli_time():
        """Return the current time in milliseconds as a string."""
        return str(round(time.time() * 1000))

    if torch.cuda.is_available():
        checkpoint = torch.load(checkpoint_path)
    else:
        # No GPU available — remap tensors to CPU
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    params = checkpoint['checkpoint_data']['cmd_args']

    # Make OOF predictions
    images_dir = os.path.join(data_dir, "ratinopathy", current_milli_time())

    retino = pd.read_csv(os.path.join(data_dir, 'aptos-2019', 'test.csv'))

    # Download fundus photography files
    # NOTE(review): every id_code downloads the same s3 key
    # ("aptos-2019/train.csv") to a different local path — looks wrong;
    # verify the intended per-image s3 key.
    for id_code in retino['id_code']:
        download_from_s3(s3_filename="aptos-2019/train.csv",
                         local_path=os.path.join(images_dir, id_code))

    image_paths = retino['id_code'].apply(
        lambda x: image_with_name_in_dir(images_dir, x))

    # Now run inference on Aptos2019 public test, will return a pd.DataFrame having image_id, logits, regrssions, ordinal, features
    ratinopathy = run_model_inference(checkpoint=checkpoint,
                                      params=params,
                                      apply_softmax=True,
                                      need_features=need_features,
                                      retino=retino,
                                      image_paths=image_paths,
                                      batch_size=batch_size,
                                      tta='fliplr',
                                      workers=num_workers,
                                      crop_black=True)
    ratinopathy.to_pickle(
        fs.change_extension(checkpoint_fname, '_ratinopathy_predictions.pkl'))
Ejemplo n.º 8
0
def csv_from_excel(fname, sheet_name='Feuil1') -> pd.DataFrame:
    """Convert one sheet of an Excel workbook to CSV and load it back as a DataFrame.

    The CSV is written next to the workbook (same name, .csv extension).

    Args:
        fname: Path to the .xls/.xlsx workbook.
        sheet_name: Sheet to export (default is the French "Feuil1").

    Returns:
        The exported sheet re-read with pandas.
    """
    wb = xlrd.open_workbook(fname)
    sh = wb.sheet_by_name(sheet_name)
    csv_file = change_extension(fname, '.csv')
    # Context manager guarantees the file is closed even if writing fails;
    # newline='' is required by the csv module so it controls line endings itself
    with open(csv_file, 'w', encoding='utf-8', newline='') as your_csv_file:
        wr = csv.writer(your_csv_file, quoting=csv.QUOTE_ALL)

        for rownum in range(sh.nrows):
            wr.writerow(sh.row_values(rownum))

    return pd.read_csv(csv_file)
Ejemplo n.º 9
0
def compute_mean_std(dataset):
    """Compute running per-channel mean/std of JPEG decoding residuals.

    https://stats.stackexchange.com/questions/25848/how-to-sum-a-standard-deviation

    For each image, the residual between the decoded pixels and the DCT
    representation (sibling .npz file) is accumulated channel-wise.

    Args:
        dataset: Iterable of image file names.

    Returns:
        Tuple ``(mean, std)`` as accumulated by RunningStatistics.
    """
    # Removed an unused `n_items` counter and stale commented-out accumulators.
    s = RunningStatistics()

    for image_fname in dataset:
        dct_file = fs.change_extension(image_fname, ".npz")
        residual = compute_decoding_residual(cv2.imread(image_fname), dct_file)
        # HWC -> CHW so statistics are accumulated per channel
        s.update(torch.from_numpy(residual).permute(2, 0, 1))

    return s.mean, s.std
Ejemplo n.º 10
0
def get_predictions(models: List[str], datasets: List[str]) -> List[str]:
    """Build prediction-file paths for every (dataset, model checkpoint) pair.

    Args:
        models: Model names (keys of MODELS) or raw checkpoint paths.
        datasets: Dataset identifiers; must be one of the known set.

    Returns:
        List of ``<checkpoint>_<dataset>_predictions.pkl`` paths.
    """
    known_datasets = {
        'aptos2015_test_private', 'aptos2015_test_public',
        'aptos2015_train', 'aptos2019_test', 'messidor2_train',
        'idrid_test'
    }

    prediction_files = []
    for dataset in datasets:
        assert dataset in known_datasets

        for model_name in models:
            # Well-known model names expand to their checkpoint list;
            # anything else is treated as a single checkpoint path
            checkpoints = MODELS.get(model_name, [model_name])
            prediction_files.extend(
                fs.change_extension(ckpt, f'_{dataset}_predictions.pkl')
                for ckpt in checkpoints)

    return prediction_files
Ejemplo n.º 11
0
def compute_features_proc(image_fname):
    """Compute a 9-element feature vector for one image.

    The image is decoded from its DCT (.npz) representation, scaled to
    roughly [-1, 1], and summarized per BGR channel as mean, std and entropy.

    Returns:
        List of 9 floats: [mean_0..2, std_0..2, entropy_0..2].
    """
    dct_file = fs.change_extension(image_fname, ".npz")
    image = 2 * (decode_bgr_from_dct(dct_file) / 140 - 0.5)

    channels = [image[..., c] for c in range(3)]

    means = [ch.mean() for ch in channels]
    stds = [ch.std() for ch in channels]
    entropies = [entropy(ch.flatten()) for ch in channels]

    return means + stds + entropies
Ejemplo n.º 12
0
def get_predictions_csv(experiment,
                        metric: str,
                        type: str,
                        tta: str = None,
                        need_embedding=False):
    """Resolve the predictions file path for an experiment.

    Args:
        experiment: Experiment name, or a list of names (mapped element-wise).
        metric: Checkpoint selection metric: "loss", "bauc" or "cauc".
        type: Prediction split: "test", "holdout", "oof" or "train".
        tta: Optional test-time augmentation flavor: None, "d4" or "hv".
        need_embedding: When True, point at the pickle file that also
            contains embeddings instead of the plain csv.

    Returns:
        A path string, or a list of them when ``experiment`` is a list.
    """
    # A list of experiments maps element-wise through the same resolution
    if isinstance(experiment, list):
        return [
            get_predictions_csv(single,
                                metric=metric,
                                type=type,
                                tta=tta,
                                need_embedding=need_embedding)
            for single in experiment
        ]

    assert type in {"test", "holdout", "oof", "train"}
    assert metric in {"loss", "bauc", "cauc"}
    assert tta in {None, "d4", "hv"}

    checkpoints_dir = {
        "loss": "checkpoints",
        "bauc": "checkpoints_auc",
        "cauc": "checkpoints_auc_classifier"
    }[metric]

    embedding_suffix = "_w_emb" if need_embedding else ""
    csv = os.path.join("models", experiment, "main", checkpoints_dir,
                       f"best_{type}_predictions{embedding_suffix}.csv")

    if tta == "d4":
        csv = as_d4_tta([csv])[0]
    elif tta == "hv":
        csv = as_hv_tta([csv])[0]

    # Embedding-bearing predictions are stored as pickle, not csv
    if need_embedding:
        csv = fs.change_extension(csv, ".pkl")

    return csv
Ejemplo n.º 13
0
def as_d4_tta(predictions):
    """Map each predictions file name to its D4 test-time-augmentation variant."""
    renamed = []
    for fname in predictions:
        renamed.append(fs.change_extension(fname, "_d4_tta.csv"))
    return renamed
Ejemplo n.º 14
0
def main():
    """Random-search an XGBoost stacking classifier over model probabilities.

    Loads holdout and leaderboard probability tables for the selected model
    columns, standardizes the features, appends one-hot encoded JPEG quality,
    runs RandomizedSearchCV with a GroupKFold over image ids, then writes the
    test submission csv and dumps the best hyperparameters as json.
    """
    import json

    output_dir = os.path.dirname(__file__)

    checksum = "DCTR_JRM_B4_B5_B6_MixNet_XL_SRNET"
    # Active feature columns; disabled candidates kept below for reference
    columns = [
        # "DCTR",
        # "JRM",
        # "MixNet_xl_pc",
        # "MixNet_xl_pjm",
        # "MixNet_xl_pjuni",
        # "MixNet_xl_puerd",
        # "efn_b4_pc",
        # "efn_b4_pjm",
        # "efn_b4_pjuni",
        # "efn_b4_puerd",
        # "efn_b2_pc",
        # "efn_b2_pjm",
        # "efn_b2_pjuni",
        # "efn_b2_puerd",
        # "MixNet_s_pc",
        # "MixNet_s_pjm",
        # "MixNet_s_pjuni",
        # "MixNet_s_puerd",
        # "SRNet_pc",
        # "SRNet_pjm",
        # "SRNet_pjuni",
        # "SRNet_puerd",
        # "SRNet_noPC70_pc",
        # "SRNet_noPC70_pjm",
        # "SRNet_noPC70_pjuni",
        # "SRNet_noPC70_puerd",
        "efn_b4_mish_pc",
        "efn_b4_mish_pjm",
        "efn_b4_mish_pjuni",
        "efn_b4_mish_puerd",
        "efn_b5_mish_pc",
        "efn_b5_mish_pjm",
        "efn_b5_mish_pjuni",
        "efn_b5_mish_puerd",
        # "efn_b2_NR_mish_pc",
        # "efn_b2_NR_mish_pjm",
        # "efn_b2_NR_mish_pjuni",
        # "efn_b2_NR_mish_puerd",
        "MixNet_xl_mish_pc",
        "MixNet_xl_mish_pjm",
        "MixNet_xl_mish_pjuni",
        "MixNet_xl_mish_puerd",
        "efn_b6_NR_mish_pc",
        "efn_b6_NR_mish_pjm",
        "efn_b6_NR_mish_pjuni",
        "efn_b6_NR_mish_puerd",
        "SRNet_noPC70_mckpt_pc",
        "SRNet_noPC70_mckpt_pjm",
        "SRNet_noPC70_mckpt_pjuni",
        "SRNet_noPC70_mckpt_puerd",
    ]
    x, y, quality_h, image_ids = get_x_y_for_stacking(
        "probabilities_zoo_holdout_0718.csv", columns)
    print(x.shape, y.shape)

    x_test, _, quality_t, image_ids_test = get_x_y_for_stacking(
        "probabilities_zoo_lb_0718.csv", columns)
    print(x_test.shape)

    # Standardize features: fit on holdout, apply the same scaling to test.
    # (A dead `if False:` PCA branch from the original was removed.)
    sc = StandardScaler()
    x = sc.fit_transform(x)
    x_test = sc.transform(x_test)

    # One-hot encode JPEG quality (3 levels) and append it to the features
    quality_h = F.one_hot(torch.tensor(quality_h).long(),
                          3).numpy().astype(np.float32)
    quality_t = F.one_hot(torch.tensor(quality_t).long(),
                          3).numpy().astype(np.float32)

    x = np.column_stack([x, quality_h])
    x_test = np.column_stack([x_test, quality_t])

    # Group by image id so augmented/derived rows of one image stay in one fold
    group_kfold = GroupKFold(n_splits=5)

    params = {
        "booster": ["gbtree", "gblinear"],
        "min_child_weight": [1, 5, 10],
        # L2 reg
        "lambda": [0, 0.01, 0.1, 1],
        # L1 reg
        "alpha": [0, 0.01, 0.1, 1],
        "subsample": [0.6, 0.8, 1.0],
        "colsample_bytree": [0.6, 0.8, 1.0],
        "max_depth": [2, 3, 4, 5, 6],
        "n_estimators": [16, 32, 64, 128, 256, 1000],
        "learning_rate": [0.001, 0.01, 0.05, 0.2, 1],
    }

    xgb = XGBClassifier(objective="binary:logistic", gamma=1e-4, nthread=1)

    random_search = RandomizedSearchCV(
        xgb,
        param_distributions=params,
        scoring=make_scorer(alaska_weighted_auc,
                            greater_is_better=True,
                            needs_proba=True),
        n_jobs=4,
        n_iter=100,
        cv=group_kfold.split(x, y, groups=image_ids),
        verbose=3,
        random_state=42,
    )

    # Here we go
    random_search.fit(x, y)

    print("\n All results:")
    print(random_search.cv_results_)
    print("\n Best estimator:")
    print(random_search.best_estimator_)
    print(random_search.best_score_)
    print("\n Best hyperparameters:")
    print(random_search.best_params_)
    results = pd.DataFrame(random_search.cv_results_)
    results.to_csv("xgb-2-random-grid-search-results-01.csv", index=False)

    test_pred = random_search.predict_proba(x_test)[:, 1]

    submit_fname = os.path.join(
        output_dir,
        f"xgb_cls2_gs_{random_search.best_score_:.4f}_{checksum}_.csv")

    df = {}
    df["Label"] = test_pred
    df["Id"] = image_ids_test
    pd.DataFrame.from_dict(df).to_csv(submit_fname, index=False)
    print("Saved submission to ", submit_fname)

    with open(fs.change_extension(submit_fname, ".json"), "w") as f:
        json.dump(random_search.best_params_, f, indent=2)

    print("Features importance")
    print(random_search.best_estimator_.feature_importances_)
Ejemplo n.º 15
0
def convert_dir(df: pd.DataFrame, dir) -> pd.DataFrame:
    """Cut per-building crops from pre/post disaster image pairs.

    For each row of ``df``, loads the pre/post images and the post-disaster
    instance labels (json derived from the mask path), then for every labeled
    building extracts a padded bounding-box crop from both images, writes the
    crops to ``<dir>/crops`` and collects one metadata record per building.

    Args:
        df: DataFrame with image/mask/label file name columns and fold info.
        dir: Root directory containing the images and label files.

    Returns:
        DataFrame of crop metadata records (one row per saved building crop).
    """
    crops_dir = os.path.join(dir, "crops")
    os.makedirs(crops_dir, exist_ok=True)

    building_crops = []

    # Monotonic counter so crop file names are unique across all rows
    global_crop_index = 0

    for i, row in tqdm(df.iterrows(), total=len(df)):
        image_fname_pre = read_image(os.path.join(dir, row["image_fname_pre"]))
        image_fname_post = read_image(
            os.path.join(dir, row["image_fname_post"]))

        # The instance labels json lives under "labels" mirroring the "masks" path
        mask_fname_post = row["mask_fname_post"]
        json_fname_post = fs.change_extension(
            mask_fname_post.replace("masks", "labels"), ".json")
        inference_data = open_json(os.path.join(dir, json_fname_post))
        instance_image, labels = create_instance_image(inference_data)

        # Instance ids are 1..max; zip pairs each id with its damage label
        # (assumes labels are ordered by instance id — TODO confirm)
        for label_index, damage_label in zip(
                range(1,
                      instance_image.max() + 1), labels):
            try:
                instance_mask = instance_image == label_index
                rmin, rmax, cmin, cmax = bbox1(instance_mask)

                # Ignore tiny buildings — crops below 16px are not useful
                max_size = max(rmax - rmin, cmax - cmin)
                if max_size < 16:
                    print("Skipping crop since it's too small",
                          fs.id_from_fname(mask_fname_post), "label_index",
                          label_index, "min_size", max_size)
                    continue

                # Pad the bounding box by 25% on each axis for context
                rpadding = (rmax - rmin) // 4
                cpadding = (cmax - cmin) // 4

                pre_crop = image_fname_pre[max(0, rmin -
                                               rpadding):rmax + rpadding,
                                           max(0, cmin - cpadding):cmax +
                                           cpadding]
                post_crop = image_fname_post[max(0, rmin -
                                                 rpadding):rmax + rpadding,
                                             max(0, cmin - cpadding):cmax +
                                             cpadding]

                image_id_pre = row["image_id_pre"]
                image_id_post = row["image_id_post"]

                pre_crop_fname = f"{global_crop_index:06}_{image_id_pre}.png"
                post_crop_fname = f"{global_crop_index:06}_{image_id_post}.png"
                global_crop_index += 1

                cv2.imwrite(os.path.join(crops_dir, pre_crop_fname), pre_crop)
                cv2.imwrite(os.path.join(crops_dir, post_crop_fname),
                            post_crop)

                # NOTE(review): key asymmetry — "pre_crop_fname" vs "post_crop"
                # (not "post_crop_fname"); downstream code may depend on the
                # current names, so it is documented rather than changed.
                building_crops.append({
                    "pre_crop_fname": pre_crop_fname,
                    "post_crop": post_crop_fname,
                    "label": damage_label,
                    "event_name": row["event_name_post"],
                    "fold": row["fold_post"],
                    "rmin": rmin,
                    "rmax": rmax,
                    "cmin": cmin,
                    "cmax": cmax,
                    "max_size": max_size,
                    "rpadding": rpadding,
                    "cpadding": cpadding
                })
            except Exception as e:
                # Best-effort: a failed building should not abort the whole run
                print(e)
                print(mask_fname_post)

    df = pd.DataFrame.from_records(building_crops)
    return df
Ejemplo n.º 16
0
def main():
    """Compute and cache train/holdout/test predictions (with embeddings).

    For each checkpoint given on the command line: restores the model,
    optionally wraps it with D4 test-time augmentation, then runs inference
    on the train-except-holdout, holdout and test datasets, pickling the
    results next to the checkpoint. Existing prediction files are reused
    unless --force-recompute is passed.
    """
    # Give no chance to randomness
    torch.manual_seed(0)
    np.random.seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint", type=str, nargs="+")
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-b", "--batch-size", type=int, default=1)
    parser.add_argument("-w", "--workers", type=int, default=0)
    parser.add_argument("-d4", "--d4-tta", action="store_true")
    parser.add_argument("-hv", "--hv-tta", action="store_true")
    parser.add_argument("-f", "--force-recompute", action="store_true")
    parser.add_argument("-fp16", "--fp16", action="store_true")

    args = parser.parse_args()

    checkpoint_fnames = args.checkpoint
    data_dir = args.data_dir
    batch_size = args.batch_size
    workers = args.workers
    fp16 = args.fp16
    d4_tta = args.d4_tta
    force_recompute = args.force_recompute
    need_embedding = True

    outputs = [
        OUTPUT_PRED_MODIFICATION_FLAG, OUTPUT_PRED_MODIFICATION_TYPE,
        OUTPUT_PRED_EMBEDDING
    ]
    embedding_suffix = "_w_emb" if need_embedding else ""

    for checkpoint_fname in checkpoint_fnames:
        model, checkpoints, required_features = ensemble_from_checkpoints(
            [checkpoint_fname],
            strict=True,
            outputs=outputs,
            activation=None,
            tta=None,
            need_embedding=need_embedding)

        report_checkpoint(checkpoints[0])

        model = model.cuda()
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.eval()

        if fp16:
            model = model.half()

        train_ds = get_train_except_holdout(data_dir,
                                            features=required_features)
        holdout_ds = get_holdout(data_dir, features=required_features)
        test_ds = get_test_dataset(data_dir, features=required_features)

        if d4_tta:
            model = wrap_model_with_tta(model,
                                        "d4",
                                        inputs=required_features,
                                        outputs=outputs).eval()
            tta_suffix = "_d4_tta"
        else:
            tta_suffix = ""

        # The three splits share identical predict-and-cache logic, so loop
        # instead of repeating the block three times (original duplication)
        for split_name, split_ds in (("train", train_ds),
                                     ("holdout", holdout_ds),
                                     ("test", test_ds)):
            predictions_pkl = fs.change_extension(
                checkpoint_fname,
                f"_{split_name}_predictions{embedding_suffix}{tta_suffix}.pkl")
            if force_recompute or not os.path.exists(predictions_pkl):
                predictions = compute_trn_predictions(model,
                                                      split_ds,
                                                      fp16=fp16,
                                                      batch_size=batch_size,
                                                      workers=workers)
                predictions.to_pickle(predictions_pkl)
Ejemplo n.º 17
0
def main():
    """Evaluate checkpoints on IStego100k/ALASKA2 holdout variants and print weighted AUC.

    For each checkpoint, restores the model and computes (or reuses cached)
    predictions on four IStego100k test variants plus the ALASKA2 holdout,
    then prints binary (bAUC) and classifier (cAUC) weighted AUC scores.
    """
    # Give no chance to randomness
    torch.manual_seed(0)
    np.random.seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint", type=str, nargs="+")
    parser.add_argument("-dd", "--data-dir", type=str, default=os.environ.get("KAGGLE_2020_ALASKA2"))
    parser.add_argument("-b", "--batch-size", type=int, default=1)
    parser.add_argument("-w", "--workers", type=int, default=0)
    parser.add_argument("-d4", "--d4-tta", action="store_true")
    parser.add_argument("-hv", "--hv-tta", action="store_true")
    parser.add_argument("-f", "--force-recompute", action="store_true")
    parser.add_argument("-oof", "--need-oof", action="store_true")

    args = parser.parse_args()

    checkpoint_fnames = args.checkpoint
    data_dir = args.data_dir
    batch_size = args.batch_size
    workers = args.workers

    # NOTE(review): d4_tta/hv_tta are parsed but never used to wrap the model
    # below — TTA flags appear to be dead here; confirm intent.
    d4_tta = args.d4_tta
    hv_tta = args.hv_tta
    force_recompute = args.force_recompute
    outputs = [OUTPUT_PRED_MODIFICATION_FLAG, OUTPUT_PRED_MODIFICATION_TYPE]

    for checkpoint_fname in checkpoint_fnames:
        model, checkpoints, required_features = ensemble_from_checkpoints(
            [checkpoint_fname], strict=True, outputs=outputs, activation=None, tta=None
        )

        report_checkpoint(checkpoints[0])

        model = model.cuda()
        # Use all available GPUs
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.eval()

        # Holdout
        # NOTE(review): the "holdout" entry uses a hard-coded Windows path
        # ("d:\datasets\ALASKA2") instead of data_dir — looks like a
        # debugging leftover; confirm before running elsewhere.
        variants = {
            "istego100k_test_same_center_crop": get_istego100k_test_same(
                data_dir, features=required_features, output_size="center_crop"
            ),
            "istego100k_test_same_full": get_istego100k_test_same(
                data_dir, features=required_features, output_size="full"
            ),
            "istego100k_test_other_center_crop": get_istego100k_test_other(
                data_dir, features=required_features, output_size="center_crop"
            ),
            "istego100k_test_other_full": get_istego100k_test_other(
                data_dir, features=required_features, output_size="full"
            ),
            "holdout": get_holdout("d:\datasets\ALASKA2", features=required_features),
        }

        for name, dataset in variants.items():
            print("Making predictions for ", name, len(dataset))

            predictions_csv = fs.change_extension(checkpoint_fname, f"_{name}_predictions.csv")
            if force_recompute or not os.path.exists(predictions_csv):
                # "full"-size images are larger, so shrink the batch to fit memory
                holdout_predictions = compute_oof_predictions(
                    model, dataset, batch_size=batch_size // 4 if "full" in name else batch_size, workers=workers
                )
                holdout_predictions.to_csv(predictions_csv, index=False)
                # Re-read the csv so the metrics below see the same dtypes a
                # later consumer of the file would (string round-trip)
                holdout_predictions = pd.read_csv(predictions_csv)

                print(name)
                print(
                    "\tbAUC",
                    alaska_weighted_auc(
                        holdout_predictions[INPUT_TRUE_MODIFICATION_FLAG].values,
                        holdout_predictions[OUTPUT_PRED_MODIFICATION_FLAG].apply(sigmoid).values,
                    ),
                )

                print(
                    "\tcAUC",
                    alaska_weighted_auc(
                        holdout_predictions[INPUT_TRUE_MODIFICATION_FLAG].values,
                        holdout_predictions[OUTPUT_PRED_MODIFICATION_TYPE].apply(parse_classifier_probas).values,
                    ),
                )
def _check_output_dir(fname):
    """Validate that the directory that should receive ``fname`` exists and
    is writable.

    An empty dirname (a bare filename) is treated as the current directory,
    so an output path like "out.png" does not spuriously fail the check.

    Returns 0 on success, -1 when the directory does not exist and -2 when
    it exists but is not writable.
    """
    out_dir = os.path.dirname(fname) or "."
    if not os.path.isdir(out_dir):
        print("Output directory does not exist", fname)
        return -1
    if not os.access(out_dir, os.W_OK):
        print("Output directory does not have write access", fname)
        return -2
    return 0


def main():
    """CLI entry point: predict building localization and damage masks for
    one pre/post disaster image pair using a multi-fold model ensemble with
    horizontal-flip and multi-scale TTA.

    Positional arguments: pre-disaster image path, post-disaster image path,
    output localization mask path, output damage mask path.
    Flags:
      --raw        also dump raw float predictions as .npy next to dmg_image
      --color-mask save colorized PNG masks instead of raw class indices
      --gpu        run the ensemble on CUDA

    Returns -1/-2 when an output directory is missing or not writable,
    otherwise None.
    """
    start = time.time()

    # Inference here is GPU/ensemble bound; extra CPU threads only add
    # contention.
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)

    parser = argparse.ArgumentParser()
    parser.add_argument("pre_image", type=str)
    parser.add_argument("post_image", type=str)
    parser.add_argument("loc_image", type=str)
    parser.add_argument("dmg_image", type=str)
    parser.add_argument("--raw", action="store_true")
    parser.add_argument("--color-mask", action="store_true")
    parser.add_argument("--gpu", action="store_true")
    args = parser.parse_args()

    pre_image = args.pre_image
    post_image = args.post_image
    localization_fname = args.loc_image
    damage_fname = args.dmg_image
    save_raw = args.raw
    color_mask = args.color_mask
    use_gpu = args.gpu

    size = 1024
    postprocess = "naive"
    image_size = size, size

    print("pre_image   ", pre_image)
    print("post_image  ", post_image)
    print("loc_image   ", localization_fname)
    print("dmg_image   ", damage_fname)
    print("Size        ", image_size)
    print("Postprocess ", postprocess)
    print("Colorize    ", color_mask)
    raw_predictions_file = fs.change_extension(damage_fname, ".npy")
    print("raw_predictions_file", raw_predictions_file)
    print(*torch.__config__.show().split("\n"), sep="\n")

    # Fail fast if either output location is unusable.
    for fname in (localization_fname, damage_fname):
        error = _check_output_dir(fname)
        if error:
            return error

    # Per-fold checkpoints with per-class ensembling coefficients
    # (background + 4 damage classes), tuned offline.  The long numeric
    # comments are the offline validation scores for the coefficients.
    fold_0_models_dict = [
        # 0.762814651939279 0.854002889559006 0.7237339786736817 [0.9186602573598759, 0.5420118318644089, 0.7123870673168781, 0.8405837378060299] coeffs [0.51244243 1.42747062 1.23648384 0.90290896 0.88912514]
        (
            "Dec30_15_34_resnet34_unet_v2_512_fold0_fp16_pseudo_crops.pth",
            [0.51244243, 1.42747062, 1.23648384, 0.90290896, 0.88912514],
        ),
        # 0.7673669954814148 0.8582940771677703 0.7283982461872626 [0.919932857782992, 0.5413880912001547, 0.731840942842999, 0.8396640419159087] coeffs [0.50847073 1.15392272 1.2059733  1.1340391  1.03196719]
        (
            "Dec30_15_34_resnet101_fpncatv2_256_512_fold0_fp16_pseudo_crops.pth",
            [0.50847073, 1.15392272, 1.2059733, 1.1340391, 1.03196719],
        ),
    ]

    fold_1_models_dict = [
        (
            "Dec22_22_24_seresnext50_unet_v2_512_fold1_fp16_crops.pth",
            [0.54324459, 1.76890163, 1.20782899, 0.85128004, 0.83100698],
        ),
        (
            "Dec31_02_09_resnet34_unet_v2_512_fold1_fp16_pseudo_crops.pth",
            # Maybe suboptimal
            [0.48269921, 1.22874469, 1.38328066, 0.96695393, 0.91348539],
        ),
        (
            "Dec31_03_55_densenet201_fpncatv2_256_512_fold1_fp16_pseudo_crops.pth",
            [0.48804137, 1.14809462, 1.24851827, 1.11798428, 1.00790482],
        ),
    ]

    fold_2_models_dict = [
        # 0.7674290884579319 0.8107652756500724 0.7488564368041575 [0.9228529822124596, 0.5900700454049471, 0.736806959757804, 0.8292099253270483] coeffs [0.34641084 1.63486251 1.14186036 0.86668715 1.12193125]
        (
            "Dec17_19_12_inceptionv4_fpncatv2_256_512_fold2_fp16_crops.pth",
            [0.34641084, 1.63486251, 1.14186036, 0.86668715, 1.12193125],
        ),
        # 0.7683650436367244 0.8543981047493 0.7314937317313349 [0.9248137307721042, 0.5642011151253543, 0.7081016179096937, 0.831720163492164] coeffs [0.51277498 1.4475809  0.8296623  0.97868596 1.34180805]
        (
            "Dec27_14_08_densenet169_unet_v2_512_fold2_fp16_crops.pth",
            [0.55429115, 1.34944309, 1.1087044, 0.89542089, 1.17257541],
        ),
        (
            "Dec31_12_45_resnet34_unet_v2_512_fold2_fp16_pseudo_crops.pth",
            # Copied from Dec17_19_19_resnet34_unet_v2_512_fold2_fp16_crops
            [0.65977938, 1.50252452, 0.97098732, 0.74048182, 1.08712367],
        ),
    ]

    fold_3_models_dict = [
        (
            "Dec15_23_24_resnet34_unet_v2_512_fold3_crops.pth",
            [0.84090623, 1.02953555, 1.2526516, 0.9298182, 0.94053529],
        ),
        (
            "Dec21_11_50_seresnext50_unet_v2_512_fold3_fp16_crops.pth",
            [0.43108046, 1.30222898, 1.09660616, 0.94958969, 1.07063753],
        ),
        (
            "Dec31_18_17_efficientb4_fpncatv2_256_512_fold3_fp16_pseudo_crops.pth",
            # Copied from Dec19_14_59_efficientb4_fpncatv2_256_512_fold3_fp16_crops
            [0.59338243, 1.17347438, 1.186104, 1.06860638, 1.03041829],
        ),
    ]

    fold_4_models_dict = [
        (
            "Dec19_06_18_resnet34_unet_v2_512_fold4_fp16_crops.pth",
            [0.83915734, 1.02560309, 0.77639015, 1.17487775, 1.05632771],
        ),
        (
            "Dec27_14_37_resnet101_unet_v2_512_fold4_fp16_crops.pth",
            [0.57414314, 1.19599486, 1.05561912, 0.98815567, 1.2274592],
        ),
    ]

    infos = []

    # Pre/post images are stacked channel-wise into one 6-channel array,
    # resized to the network resolution and normalized with ImageNet
    # statistics repeated for both halves of the stack.
    resize = A.Resize(1024, 1024)
    normalize = A.Normalize(mean=(0.485, 0.456, 0.406, 0.485, 0.456, 0.406),
                            std=(0.229, 0.224, 0.225, 0.229, 0.224, 0.225))
    transform = A.Compose([resize, normalize])

    # Very dumb way but it matches 1:1 with inferencing
    pre, post = read_image(pre_image), read_image(post_image)
    image = np.dstack([pre, post])
    image = transform(image=image)["image"]

    models = []
    for models_dict in [
            fold_0_models_dict,
            fold_1_models_dict,
            fold_2_models_dict,
            fold_3_models_dict,
            fold_4_models_dict,
    ]:
        for checkpoint, weights in models_dict:
            model, info = weighted_model(checkpoint,
                                         weights,
                                         activation="model")
            models.append(model)
            infos.append(info)

    # Ensemble all folds, then wrap with flip and multi-scale TTA.
    model = Ensembler(models, outputs=[OUTPUT_MASK_KEY])
    model = HFlipTTA(model, outputs=[OUTPUT_MASK_KEY], average=True)
    model = MultiscaleTTA(model,
                          outputs=[OUTPUT_MASK_KEY],
                          size_offsets=[-128, +128],
                          average=True)
    model = model.eval()

    df = pd.DataFrame.from_records(infos)
    pd.set_option("display.max_rows", None)
    pd.set_option("display.max_columns", None)
    pd.set_option("display.width", None)
    # -1 was deprecated (and later removed) by pandas; None means unlimited.
    pd.set_option("display.max_colwidth", None)

    print(df)
    print("score        ", df["score"].mean(), df["score"].std())
    print("localization ", df["localization"].mean(), df["localization"].std())
    print("damage       ", df["damage"].mean(), df["damage"].std())

    # `image` already holds the normalized 6-channel pre+post stack; no need
    # to split it into two 3-channel slices and re-stack them.
    input_image = tensor_from_rgb_image(image).unsqueeze(0)

    if use_gpu:
        print("Using GPU for inference")
        input_image = input_image.cuda()
        model = model.cuda()

    output = model(input_image)
    masks = output[OUTPUT_MASK_KEY]
    predictions = to_numpy(masks.squeeze(0)).astype(np.float32)

    if save_raw:
        np.save(raw_predictions_file, predictions)

    localization_image, damage_image = make_predictions_naive(predictions)

    if color_mask:
        # Colorized masks are saved as PIL images.
        localization_image = colorize_mask(localization_image)
        localization_image.save(localization_fname)

        damage_image = colorize_mask(damage_image)
        damage_image.save(damage_fname)
    else:
        # Raw class-index masks go through OpenCV.
        cv2.imwrite(localization_fname, localization_image)
        cv2.imwrite(damage_fname, damage_image)

    print("Saved output to ", localization_fname, damage_fname)

    done = time.time()
    elapsed = done - start
    print("Inference time", elapsed, "(s)")
Ejemplo n.º 19
0
def run_inference_on_dataset_oof(model,
                                 dataset,
                                 output_dir,
                                 batch_size=1,
                                 workers=0,
                                 save=True,
                                 fp16=False):
    """Run out-of-fold inference over ``dataset`` and score the predictions.

    The model is moved to CUDA (wrapped in ``nn.DataParallel`` when several
    GPUs are visible), put into eval mode and optionally converted to half
    precision.  For every sample the predicted damage mask is optionally
    written to ``output_dir`` as a float16 ``.npy`` file, a metric row is
    accumulated, and when the model also emits ``DAMAGE_TYPE_KEY`` logits
    their sigmoid is saved per image as well.

    Returns the aggregated competition metric computed from all rows.
    """
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model = model.eval()
    if fp16:
        model = model.half()

    loader = DataLoader(dataset,
                        num_workers=workers,
                        batch_size=batch_size,
                        pin_memory=True)

    if save:
        os.makedirs(output_dir, exist_ok=True)

    metric_rows = []

    for batch in tqdm(loader):
        images = batch[INPUT_IMAGE_KEY]
        if fp16:
            # Inputs must match the half-precision weights.
            images = images.half()
        images = images.cuda(non_blocking=True)

        ids = batch[INPUT_IMAGE_ID_KEY]
        true_masks = to_numpy(batch[INPUT_MASK_KEY]).astype(np.float32)

        outputs = model(images)
        pred_masks = to_numpy(outputs[OUTPUT_MASK_KEY])

        for sample_index, sample_id in enumerate(ids):
            pred_mask = pred_masks[sample_index]

            if save:
                # Mask files are named after the post-disaster image.
                dst_fname = os.path.join(
                    output_dir,
                    fs.change_extension(sample_id.replace("_pre", "_post"),
                                        ".npy"))
                np.save(dst_fname, pred_mask.astype(np.float16))

            loc_pred, dmg_pred = make_predictions_naive(pred_mask)
            true_mask = true_masks[sample_index]
            metric_rows.append(
                CompetitionMetricCallback.get_row_pair(loc_pred, dmg_pred,
                                                       true_mask, true_mask))

        if save and DAMAGE_TYPE_KEY in outputs:
            # Per-image damage-type probabilities, saved separately.
            damage_type = to_numpy(
                outputs[DAMAGE_TYPE_KEY].sigmoid()).astype(np.float32)

            for sample_index, sample_id in enumerate(ids):
                dst_fname = os.path.join(
                    output_dir,
                    fs.change_extension(
                        sample_id.replace("_pre", "_damage_type"), ".npy"))
                np.save(dst_fname, damage_type[sample_index])

    del loader

    return CompetitionMetricCallback.compute_metrics(metric_rows)
Ejemplo n.º 20
0
def read_from_dct(image_fname):
    """Decode a BGR image from the ``.npz`` DCT file paired with ``image_fname``."""
    dct_fname = fs.change_extension(image_fname, ".npz")
    return decode_bgr_from_dct(dct_fname)
Ejemplo n.º 21
0
def main():
    """Stack per-model OOF predictions with an XGBoost classifier.

    Loads holdout/test prediction CSVs for the given experiments, builds a
    feature matrix (model outputs, optionally logits, plus one-hot JPEG
    quality), runs a randomized hyper-parameter search with grouped CV on
    the holdout set, and writes a submission CSV plus the best parameters
    as JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("experiments", nargs="+", type=str)
    parser.add_argument("-o", "--output", type=str, required=False)
    parser.add_argument("-dd",
                        "--data-dir",
                        type=str,
                        default=os.environ.get("KAGGLE_2020_ALASKA2"))
    args = parser.parse_args()

    output_dir = os.path.dirname(__file__)
    data_dir = args.data_dir
    experiments = args.experiments
    output_file = args.output

    # Per-experiment prediction CSVs: checkpoint selected by cAUC, D4 TTA.
    holdout_predictions = get_predictions_csv(experiments, "cauc", "holdout",
                                              "d4")
    test_predictions = get_predictions_csv(experiments, "cauc", "test", "d4")
    checksum = compute_checksum_v2(experiments)

    # JPEG quality factor (3 levels), one-hot encoded, is appended to the
    # feature matrix further below.
    holdout_ds = get_holdout("", features=[INPUT_IMAGE_KEY])
    image_ids_h = [fs.id_from_fname(x) for x in holdout_ds.images]
    quality_h = F.one_hot(torch.tensor(holdout_ds.quality).long(),
                          3).numpy().astype(np.float32)

    test_ds = get_test_dataset("", features=[INPUT_IMAGE_KEY])
    quality_t = F.one_hot(torch.tensor(test_ds.quality).long(),
                          3).numpy().astype(np.float32)

    with_logits = True
    x, y = get_x_y_for_stacking(holdout_predictions,
                                with_logits=with_logits,
                                tta_logits=with_logits)
    # Force target to be binary
    y = (y > 0).astype(int)
    print(x.shape, y.shape)

    x_test, _ = get_x_y_for_stacking(test_predictions,
                                     with_logits=with_logits,
                                     tta_logits=with_logits)
    print(x_test.shape)

    # Disabled experiment: append hand-crafted image features (entropy etc.)
    # computed from the original JPEG files.
    if False:
        image_fnames_h = [
            os.path.join(data_dir, INDEX_TO_METHOD[method], f"{image_id}.jpg")
            for (image_id, method) in zip(image_ids_h, y)
        ]
        test_image_ids = pd.read_csv(test_predictions[0]).image_id.tolist()
        image_fnames_t = [
            os.path.join(data_dir, "Test", image_id)
            for image_id in test_image_ids
        ]

        entropy_t = compute_image_features(image_fnames_t)
        x_test = np.column_stack([x_test, entropy_t])

        # entropy_h = entropy_t.copy()
        # x = x_test.copy()

        entropy_h = compute_image_features(image_fnames_h)
        x = np.column_stack([x, entropy_h])
        print("Added image features", entropy_h.shape, entropy_t.shape)

    # Standardize features; the scaler is fit on holdout only and applied
    # to test to avoid leakage.
    if True:
        sc = StandardScaler()
        x = sc.fit_transform(x)
        x_test = sc.transform(x_test)

    # Disabled experiment: PCA dimensionality reduction.
    if False:
        sc = PCA(n_components=16)
        x = sc.fit_transform(x)
        x_test = sc.transform(x_test)

    # Append the one-hot JPEG quality features.
    if True:
        x = np.column_stack([x, quality_h])
        x_test = np.column_stack([x_test, quality_t])

    # Group CV by image id so TTA variants of one image never span folds.
    group_kfold = GroupKFold(n_splits=5)

    params = {
        "min_child_weight": [1, 5, 10],
        "gamma": [1e-3, 1e-2, 1e-2, 0.5, 2],  # NOTE(review): 1e-2 appears twice — likely a typo for 1e-1; confirm intent
        "subsample": [0.6, 0.8, 1.0],
        "colsample_bytree": [0.6, 0.8, 1.0],
        "max_depth": [2, 3, 4, 5, 6],
        "n_estimators": [16, 32, 64, 128, 256, 1000],
        "learning_rate": [0.001, 0.01, 0.05, 0.2, 1],
    }

    xgb = XGBClassifier(objective="binary:logistic", nthread=1)

    # NOTE(review): `needs_proba` is deprecated in newer scikit-learn
    # (replaced by `response_method`) — verify the installed version.
    random_search = RandomizedSearchCV(
        xgb,
        param_distributions=params,
        scoring=make_scorer(alaska_weighted_auc,
                            greater_is_better=True,
                            needs_proba=True),
        n_jobs=4,
        n_iter=25,
        cv=group_kfold.split(x, y, groups=image_ids_h),
        verbose=3,
        random_state=42,
    )

    # Here we go
    random_search.fit(x, y)

    print("\n All results:")
    print(random_search.cv_results_)
    print("\n Best estimator:")
    print(random_search.best_estimator_)
    print(random_search.best_score_)
    print("\n Best hyperparameters:")
    print(random_search.best_params_)
    results = pd.DataFrame(random_search.cv_results_)
    results.to_csv("xgb-random-grid-search-results-01.csv", index=False)

    # Probability of the positive (modified) class for the submission.
    test_pred = random_search.predict_proba(x_test)[:, 1]

    # Default submission name encodes CV score and experiment checksum.
    if output_file is None:
        with_logits_sfx = "_with_logits" if with_logits else ""
        submit_fname = os.path.join(
            output_dir,
            f"xgb_cls_gs_{random_search.best_score_:.4f}_{checksum}{with_logits_sfx}.csv"
        )
    else:
        submit_fname = output_file

    df = pd.read_csv(test_predictions[0]).rename(columns={"image_id": "Id"})
    df["Label"] = test_pred
    df[["Id", "Label"]].to_csv(submit_fname, index=False)
    print("Saved submission to ", submit_fname)

    import json

    # Persist the winning hyper-parameters next to the submission.
    with open(fs.change_extension(submit_fname, ".json"), "w") as f:
        json.dump(random_search.best_params_, f, indent=2)
def main():
    """Run inference with every supplied checkpoint on several external
    retinopathy datasets (IDRID test, Messidor-2 train, APTOS 2019 test,
    APTOS 2015 private test) and pickle the predictions next to each
    checkpoint file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input', nargs='+')
    parser.add_argument('--need-features', action='store_true')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=None,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-w', '--workers', type=int, default=4, help='')

    args = parser.parse_args()

    need_features = args.need_features
    batch_size = args.batch_size
    num_workers = args.workers

    checkpoints = args.input
    for i, checkpoint_fname in enumerate(checkpoints):
        print(i, checkpoint_fname)

        # Make OOF predictions
        # The checkpoint stores the original training CLI arguments; they
        # are used to recover dataset location and input size.
        checkpoint = torch.load(checkpoint_fname)
        params = checkpoint['checkpoint_data']['cmd_args']
        # NOTE(review): image_size is only used by the commented-out OOF
        # block below — currently unused.
        image_size = params['size']
        data_dir = params['data_dir']

        # train_ds, valid_ds, train_sizes = get_datasets(data_dir=params['data_dir'],
        #                                                use_aptos2019=params['use_aptos2019'],
        #                                                use_aptos2015=params['use_aptos2015'],
        #                                                use_idrid=params['use_idrid'],
        #                                                use_messidor=params['use_messidor'],
        #                                                use_unsupervised=False,
        #                                                image_size=(image_size, image_size),
        #                                                augmentation=params['augmentations'],
        #                                                preprocessing=params['preprocessing'],
        #                                                target_dtype=int,
        #                                                coarse_grading=params.get('coarse', False),
        #                                                fold=i,
        #                                                folds=4)
        # print(len(valid_ds))
        # oof_predictions = run_model_inference_via_dataset(checkpoint_fname,
        #                                                   valid_ds,
        #                                                   apply_softmax=True,
        #                                                   need_features=need_features,
        #                                                   batch_size=batch_size,
        #                                                   workers=num_workers)

        # dst_fname = fs.change_extension(checkpoint_fname, '_oof_predictions.pkl')
        # oof_predictions.to_pickle(dst_fname)

        # Now run inference on holdout IDRID Test dataset
        idrid_test = run_model_inference(
            model_checkpoint=checkpoint_fname,
            apply_softmax=True,
            need_features=need_features,
            test_csv=pd.read_csv(
                os.path.join(data_dir, 'idrid', 'test_labels.csv')),
            data_dir=os.path.join(data_dir, 'idrid'),
            images_dir='test_images_768',
            batch_size=batch_size,
            tta='fliplr',
            workers=num_workers,
            crop_black=True)
        idrid_test.to_pickle(
            fs.change_extension(checkpoint_fname,
                                '_idrid_test_predictions.pkl'))

        # Now run inference on Messidor 2 Test dataset
        messidor2_train = run_model_inference(
            model_checkpoint=checkpoint_fname,
            apply_softmax=True,
            need_features=need_features,
            test_csv=pd.read_csv(
                os.path.join(data_dir, 'messidor_2', 'train_labels.csv')),
            data_dir=os.path.join(data_dir, 'messidor_2'),
            images_dir='train_images_768',
            batch_size=batch_size,
            tta='fliplr',
            workers=num_workers,
            crop_black=True)
        messidor2_train.to_pickle(
            fs.change_extension(checkpoint_fname,
                                '_messidor2_train_predictions.pkl'))

        # Now run inference on Aptos2019 public test
        aptos2019_test = run_model_inference(
            model_checkpoint=checkpoint_fname,
            apply_softmax=True,
            need_features=need_features,
            test_csv=pd.read_csv(
                os.path.join(data_dir, 'aptos-2019', 'test.csv')),
            data_dir=os.path.join(data_dir, 'aptos-2019'),
            images_dir='test_images_768',
            batch_size=batch_size,
            tta='fliplr',
            workers=num_workers,
            crop_black=True)
        aptos2019_test.to_pickle(
            fs.change_extension(checkpoint_fname,
                                '_aptos2019_test_predictions.pkl'))

        # Now run inference on Aptos2015 private test
        # NOTE(review): the following if True/if False literals are manual
        # toggles for which APTOS 2015 split(s) to process.
        if True:
            aptos2015_df = pd.read_csv(
                os.path.join(data_dir, 'aptos-2015', 'test_labels.csv'))
            aptos2015_df = aptos2015_df[aptos2015_df['Usage'] == 'Private']
            aptos2015_test = run_model_inference(
                model_checkpoint=checkpoint_fname,
                apply_softmax=True,
                need_features=need_features,
                test_csv=aptos2015_df,
                data_dir=os.path.join(data_dir, 'aptos-2015'),
                images_dir='test_images_768',
                batch_size=batch_size,
                tta='fliplr',
                workers=num_workers,
                crop_black=True)
            aptos2015_test.to_pickle(
                fs.change_extension(checkpoint_fname,
                                    '_aptos2015_test_private_predictions.pkl'))

        # Disabled: APTOS 2015 public test split.
        if False:
            aptos2015_df = pd.read_csv(
                os.path.join(data_dir, 'aptos-2015', 'test_labels.csv'))
            aptos2015_df = aptos2015_df[aptos2015_df['Usage'] == 'Public']
            aptos2015_test = run_model_inference(
                model_checkpoint=checkpoint_fname,
                apply_softmax=True,
                need_features=need_features,
                test_csv=aptos2015_df,
                data_dir=os.path.join(data_dir, 'aptos-2015'),
                images_dir='test_images_768',
                batch_size=batch_size,
                tta='fliplr',
                workers=num_workers,
                crop_black=True)
            aptos2015_test.to_pickle(
                fs.change_extension(checkpoint_fname,
                                    '_aptos2015_test_public_predictions.pkl'))

        # Disabled: APTOS 2015 train split.
        if False:
            aptos2015_df = pd.read_csv(
                os.path.join(data_dir, 'aptos-2015', 'train_labels.csv'))
            aptos2015_test = run_model_inference(
                model_checkpoint=checkpoint_fname,
                apply_softmax=True,
                need_features=need_features,
                test_csv=aptos2015_df,
                data_dir=os.path.join(data_dir, 'aptos-2015'),
                images_dir='train_images_768',
                batch_size=batch_size,
                tta='fliplr',
                workers=num_workers,
                crop_black=True)
            aptos2015_test.to_pickle(
                fs.change_extension(checkpoint_fname,
                                    '_aptos2015_train_predictions.pkl'))