Ejemplo n.º 1
0
def main():
    """Evaluate trained segmentation checkpoints on a labelled 2-D test set.

    Reads the YAML config given by ``-c/--config``. For every fold listed
    under ``config["test"]["folds"]`` it resolves one or more checkpoint
    epochs, loads each checkpoint, predicts every test image, optionally
    writes hard/soft/overlay visualisations, and logs two score sets:

    * "scores ver1" -- per-image metrics averaged over the loader
      (IoU, precision, recall, F2, sensitivity, specificity, accuracy, Dice);
    * "scores ver2" -- metrics computed from TP/FP/FN pooled over all images
      (IoU, precision, recall, Dice).

    Finally ``utils.metrics.get_scores_v1`` is called on every
    prediction / ground-truth pair collected across all folds and epochs.

    Returns:
        tuple[list, list]: ``(gts, prs)``, the ground-truth masks and the
        rounded predictions, one entry per evaluated image, in evaluation
        order across all folds and epochs.
    """
    import sys

    import network.models as models

    # Keeps the ratio metrics finite when a mask or a prediction is empty;
    # same value already used for the pooled precision/recall.
    eps = 1e-07

    parser = ArgumentParser()
    parser.add_argument("-c",
                        "--config",
                        required=False,
                        default="configs/scwc_isic_bcdu.yaml")
    args = parser.parse_args()

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)

    gts = []
    prs = []

    folds = config["test"]["folds"]
    print(folds)
    # Last path component of the first test directory; used in log output.
    dataset = config["dataset"]["test_data_path"][0].split("/")[-1]
    if len(folds.keys()) == 1:
        logger.add(
            f'logs/test_{config["model"]["arch"]}_{list(folds.keys())[0]}_{dataset}.log',
            rotation="10 MB",
        )
    else:
        logger.add(
            f'logs/test_{config["model"]["arch"]}_kfold.log',
            rotation="10 MB",
        )

    # FOR ORIDATASET: every test directory holds "images/" and "masks/".
    # The file lists do not depend on the fold, so the loader is built once.
    test_img_paths = []
    test_mask_paths = []
    for data_dir in config["dataset"]["test_data_path"]:
        test_img_paths.extend(glob(os.path.join(data_dir, "images", "*")))
        test_mask_paths.extend(glob(os.path.join(data_dir, "masks", "*")))
    test_img_paths.sort()
    test_mask_paths.sort()

    # Images and masks are paired by position after sorting, so both lists
    # must have the same length and matching basenames. A mismatch is an
    # error, hence the non-zero exit status.
    if len(test_img_paths) != len(test_mask_paths):
        print(f"Found {len(test_img_paths)} images but "
              f"{len(test_mask_paths)} masks")
        sys.exit(1)
    for mask_path, img_path in zip(test_mask_paths, test_img_paths):
        if os.path.basename(mask_path) != os.path.basename(img_path):
            print(f"{mask_path} and {img_path} is different")
            sys.exit(1)

    test_transform = Augmenter(**config["test"]["augment"])
    test_loader = get_loader(
        test_img_paths,
        test_mask_paths,
        transform=test_transform,
        **config["test"]["dataloader"],
        type="test",
    )
    num_batches = len(test_loader)

    # Settings that are identical for every fold/epoch.
    model_prams = config["model"]
    arch = model_prams["arch"]
    dev = config["test"]["dev"]
    device = torch.device(dev)
    if "save_dir" not in model_prams:
        save_dir = os.path.join("snapshots", arch + "_kfold")
    else:
        save_dir = model_prams["save_dir"]
    if "visualize_dir" not in config["test"]:
        visualize_dir = "results"
    else:
        visualize_dir = config["test"]["visualize_dir"]

    for fold_id in folds:
        # Map the fold's epoch spec onto checkpoint epoch numbers:
        #   int n  -> [5 * (n // 5) + 2]
        #   [a, b] -> [3*i - 1 for i in range(a // 3 + 1, (b + 1) // 3 + 1)]
        #   [n]    -> [5 * (n // 5 + 1) - 1]
        epochs = folds[fold_id]
        if not isinstance(epochs, list):
            epochs = [5 * (epochs // 5) + 2]
        elif len(epochs) == 2:
            epochs = [
                3 * i - 1
                for i in range(epochs[0] // 3 + 1, (epochs[1] + 1) // 3 + 1)
            ]
        elif len(epochs) == 1:
            epochs = [5 * (epochs[0] // 5 + 1) - 1]
        else:
            # A malformed spec aborts the remaining folds.
            logger.error(
                f"Invalid epoch spec for fold {fold_id}: {epochs!r} "
                "(expected an int or a list of 1 or 2 ints)")
            break

        for e in epochs:
            logger.info("Loading model")
            # Architectures whose constructor needs arguments (e.g. the
            # TransUNet variant, which takes a ViT config, image size and
            # class count) are not supported by this no-argument call.
            model = models.__dict__[arch]()  # Pranet

            # A save_dir that already names a .pth file is used as-is.
            if ".pth" in save_dir:
                model_path = save_dir
            else:
                model_path = os.path.join(
                    save_dir,
                    f"PraNetDG-fold{fold_id}-{e}.pth",
                )
            # .to(device) honours an explicit index such as "cuda:1".
            model.to(device)
            model.eval()

            logger.info(f"Loading from {model_path}")
            checkpoint = torch.load(model_path, map_location=device)
            # Training checkpoints wrap the weights under "model_state_dict";
            # otherwise the file is treated as a bare state dict.
            if isinstance(checkpoint,
                          dict) and "model_state_dict" in checkpoint:
                checkpoint = checkpoint["model_state_dict"]
            model.load_state_dict(checkpoint)

            logger.info(f"Start testing fold{fold_id} epoch {e}")
            logger.info(
                f"Start testing {num_batches} images in {dataset} dataset")
            if num_batches == 0:
                logger.warning("Test loader is empty; no scores computed")
                continue

            visualize = config["test"]["visualize"]
            overwrite = config["test"]["vis_overwrite"]
            vis_x = config["test"]["vis_x"]

            tp_all = fp_all = fn_all = 0
            sum_precision = sum_recall = sum_iou = sum_dice = 0
            sum_f2 = sum_acc = sum_se = sum_spe = 0

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                for image, gt, filename, img in tqdm.tqdm(test_loader):
                    name, ext = os.path.splitext(filename[0])
                    # Only the first sample of each batch is scored --
                    # presumably the test loader uses batch size 1; confirm.
                    gt = np.asarray(gt[0][0], np.float32).round()

                    # The model returns four outputs; the last one is scored.
                    _, _, _, res2 = model(image.to(device))
                    res = F.interpolate(res2,
                                        size=gt.shape,
                                        mode="bilinear",
                                        align_corners=False)
                    res = res.sigmoid().cpu().numpy().squeeze()
                    # Min-max rescale to [0, 1]; 1e-8 avoids 0/0 on flat maps.
                    res = (res - res.min()) / (res.max() - res.min() + 1e-8)
                    pr = res.round()

                    if visualize:
                        pred_dir = os.path.join(visualize_dir,
                                                "PR_" + str(arch))
                        save_img(os.path.join(pred_dir, "Hard", name + ext),
                                 pr * 255, "cv2", overwrite)
                        save_img(os.path.join(pred_dir, "Soft", name + ext),
                                 res * 255, "cv2", overwrite)
                        # Overlay: prediction added to channel 1 and ground
                        # truth to channel 0, each scaled by vis_x; channel
                        # order is reversed before saving with "cv2".
                        pr_zeros = np.zeros_like(pr)
                        gt_zeros = np.zeros_like(gt)
                        mask_img = (
                            np.asarray(img[0]) +
                            vis_x * np.stack((pr_zeros, pr, pr_zeros), axis=-1)
                            + vis_x * np.stack((gt, gt_zeros, gt_zeros),
                                               axis=-1))
                        mask_img = mask_img[:, :, ::-1]
                        save_img(
                            os.path.join(visualize_dir, "GT_PR_" + str(arch),
                                         name + ext),
                            mask_img,
                            "cv2",
                            overwrite,
                        )

                    prs.append(pr)
                    gts.append(gt)

                    # Pixel-level confusion counts for this image.
                    tp = np.sum(gt * pr)
                    fp = np.sum(pr) - tp
                    fn = np.sum(gt) - tp
                    tn = np.sum((1 - pr) * (1 - gt))
                    tp_all += tp
                    fp_all += fp
                    fn_all += fn

                    precision = precision_m(gt, pr)
                    recall = recall_m(gt, pr)
                    sum_precision += precision
                    sum_recall += recall
                    sum_iou += jaccard_m(gt, pr)
                    sum_dice += dice_m(gt, pr)
                    sum_f2 += (5 * precision * recall) / (4 * precision +
                                                          recall + eps)
                    sum_acc += (tp + tn) / (tp + tn + fp + fn)
                    # eps keeps an empty mask / all-foreground mask from
                    # turning the dataset average into NaN.
                    sum_se += tp / (tp + fn + eps)
                    sum_spe += tn / (tn + fp + eps)

            mean_precision = sum_precision / num_batches
            mean_recall = sum_recall / num_batches
            mean_iou = sum_iou / num_batches
            mean_dice = sum_dice / num_batches
            mean_F2 = sum_f2 / num_batches
            mean_acc = sum_acc / num_batches
            mean_se = sum_se / num_batches
            mean_spe = sum_spe / num_batches

            logger.info(
                "scores ver1: {:.3f} {:.3f} {:.3f} {:.3f} {:.3f} {:.3f} {:.3f} {:.3f}"
                .format(
                    mean_iou,
                    mean_precision,
                    mean_recall,
                    mean_F2,
                    mean_se,
                    mean_spe,
                    mean_acc,
                    mean_dice,
                ))

            # Dataset-level scores from the pooled confusion counts.
            precision_all = tp_all / (tp_all + fp_all + eps)
            recall_all = tp_all / (tp_all + fn_all + eps)
            dice_all = (2 * precision_all * recall_all /
                        (precision_all + recall_all + eps))
            iou_all = (recall_all * precision_all /
                       (recall_all + precision_all -
                        recall_all * precision_all + eps))
            logger.info("scores ver2: {:.3f} {:.3f} {:.3f} {:.3f}".format(
                iou_all, precision_all, recall_all, dice_all))

    from utils.metrics import get_scores_v1

    get_scores_v1(gts, prs, logger)

    return gts, prs
Ejemplo n.º 2
0
def main():
    """Run 3-D segmentation checkpoints on a test set and export predictions.

    Reads the YAML config given by ``-c/--config``. For every fold listed
    under ``config["test"]["folds"]`` and each resolved checkpoint epoch,
    the model is loaded on the GPU and every test entry is predicted. The
    argmax label map, cropped to 240x240x155, is scored against the target
    with ``softmax_output_dice``. If ``config["test"]["visualize"]`` is set,
    the map is written (labels remapped 1->1, 2->2, 3->4) to
    ``<visualize_dir>/submission/<name[:-8]>_pred.nii.gz``. The running
    average of the scores is logged once per checkpoint.
    """
    import network.models as models

    parser = ArgumentParser()
    # NOTE: argparse never uses ``default`` when ``required=True``.
    parser.add_argument("-c",
                        "--config",
                        required=True,
                        default="configs/default_config.yaml")
    args = parser.parse_args()

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)

    folds = config["test"]["folds"]
    print(folds)
    # Last path component of the first test directory; used in log output.
    dataset = config["dataset"]["test_data_path"][0].split("/")[-1]
    if len(folds.keys()) == 1:
        logger.add(
            f'logs/test_{config["model"]["arch"]}_{str(datetime.now())}_{list(folds.keys())[0]}_{dataset}.log',
            rotation="10 MB",
        )
    else:
        logger.add(
            f'logs/test_{config["model"]["arch"]}_{str(datetime.now())}_kfold.log',
            rotation="10 MB",
        )

    # Every entry of each test directory is passed as both the "image" and
    # the "mask" path -- presumably one entry per case, with get_loader
    # reading the volumes itself (TODO confirm). The listing does not depend
    # on the fold, so the loader is built once.
    test_paths = []
    for data_dir in config["dataset"]["test_data_path"]:
        test_paths.extend(glob(os.path.join(data_dir, "*")))
    test_paths.sort()
    test_loader = get_loader(
        test_paths,
        list(test_paths),
        transform=None,
        **config["test"]["dataloader"],
        type="test",
    )

    # Settings that are identical for every fold/epoch.
    model_prams = config["model"]
    arch = model_prams["arch"]
    if "save_dir" not in model_prams:
        save_dir = os.path.join("snapshots", arch + "_kfold")
    else:
        save_dir = model_prams["save_dir"]
    if "visualize_dir" not in config["test"]:
        visualize_dir = "results"
    else:
        visualize_dir = config["test"]["visualize_dir"]
    # Spatial extent kept from each prediction/target before scoring/export.
    H, W, T = 240, 240, 155

    for fold_id in folds:
        # Epoch spec: int n -> [3*(n//3)+2]; [a, b] -> range(a, b);
        # [n] -> [3*(n//3)+2].
        epochs = folds[fold_id]
        if not isinstance(epochs, list):
            epochs = [3 * (epochs // 3) + 2]
        elif len(epochs) == 2:
            epochs = list(range(epochs[0], epochs[1]))
        elif len(epochs) == 1:
            epochs = [3 * (epochs[0] // 3) + 2]
        else:
            # A malformed spec aborts the remaining folds.
            logger.error(
                f"Invalid epoch spec for fold {fold_id}: {epochs!r} "
                "(expected an int or a list of 1 or 2 ints)")
            break

        for e in epochs:
            logger.info("Loading model")
            model = models.__dict__[arch]()  # Pranet
            model_path = os.path.join(
                save_dir,
                f"PraNetDG-fold{fold_id}-{e}.pth",
            )
            model.cuda()
            model.eval()

            logger.info(f"Loading from {model_path}")
            checkpoint = torch.load(model_path)
            # Training checkpoints wrap the weights under "model_state_dict";
            # otherwise the file is treated as a bare state dict.
            if isinstance(checkpoint,
                          dict) and "model_state_dict" in checkpoint:
                checkpoint = checkpoint["model_state_dict"]
            model.load_state_dict(checkpoint)

            logger.info(f"Start testing fold{fold_id} epoch {e}")
            logger.info(
                f"Start testing {len(test_loader)} images in {dataset} dataset"
            )
            visualize = config["test"]["visualize"]
            overwrite = config["test"]["vis_overwrite"]
            vals = AvgMeter()

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                for image, gt, filename, img in tqdm.tqdm(test_loader):
                    name = os.path.splitext(filename[0])[0]
                    # First sample of the batch; a NumPy array from here on.
                    gt = np.asarray(gt[0], np.float32)

                    # The model returns four outputs; the last one is used.
                    _, _, _, res2 = model(image.cuda())
                    # (num_classes, H, W, T) scores of the first sample,
                    # cropped, then argmax over classes -> label map.
                    output = res2[0, :, :H, :W, :T].cpu().numpy().argmax(0)

                    # gt is already an ndarray, so it is sliced directly.
                    target_cpu = gt[:H, :W, :T]
                    scores = softmax_output_dice(output, target_cpu)
                    vals.update(np.array(scores))

                    # Remap predicted classes 1, 2, 3 to labels 1, 2, 4.
                    # np.where keeps this valid when output is smaller than
                    # (H, W, T).
                    seg_img = np.zeros(shape=(H, W, T), dtype=np.uint8)
                    seg_img[np.where(output == 1)] = 1
                    seg_img[np.where(output == 2)] = 2
                    seg_img[np.where(output == 3)] = 4
                    logger.info(
                        f'1:{np.sum(seg_img==1)} | 2: {np.sum(seg_img==2)} | 4: {np.sum(seg_img==4)}'
                    )
                    logger.info(
                        f'WT: {np.sum((seg_img==1)|(seg_img==2)|(seg_img==4))} | TC: {np.sum((seg_img==1)|(seg_img==4))} | ET: {np.sum(seg_img==4)}'
                    )

                    if visualize:
                        oname = os.path.join(visualize_dir, 'submission',
                                             name[:-8] + '_pred.nii.gz')
                        save_img(
                            oname,
                            seg_img,
                            "nib",
                            overwrite,
                        )
            logger.info(vals.avg)
Ejemplo n.º 3
0
def main():
    """Train a 3-D segmentation model on one fold described by a config.

    Parses ``-c/--config`` and lists the training and validation entries
    from the directories under ``config["dataset"]``. It then builds the
    loaders and instantiates the model, optimizer, scheduler and loss named
    in the config. Finally it runs ``network.models.Trainer3D.fit``.
    """
    parser = ArgumentParser()
    # NOTE: argparse never uses ``default`` when ``required=True``.
    parser.add_argument("-c",
                        "--config",
                        required=True,
                        default="configs/default_config.yaml")
    args = parser.parse_args()

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)
    logger.add(
        f'logs/train_{config["model"]["arch"]}_{str(datetime.now())}_{config["dataset"]["fold"]}.log',
        rotation="10 MB",
    )

    logger.info(f"Load config from {config_path}")
    logger.info(f"{config}")

    def _list_entries(dirs):
        # Sorted listing of every entry in each directory (For 3d .nii).
        paths = []
        for data_dir in dirs:
            paths.extend(glob(os.path.join(data_dir, "*")))
        return sorted(paths)

    # GET_DATA_PATH
    logger.info("Getting datapath")

    # For 3d .nii the same listing serves as both image and mask paths.
    train_img_paths = _list_entries(config["dataset"]["train_data_path"])
    train_mask_paths = list(train_img_paths)
    logger.info(f"There are {len(train_img_paths)} images to train")

    val_img_paths = _list_entries(config["dataset"]["val_data_path"])
    val_mask_paths = list(val_img_paths)
    logger.info(f'There are {len(val_mask_paths)} images to val')

    # DATALOADER
    logger.info("Loading data")
    # 3d .nii no need augment: the configured "train_transforms" entry is
    # handed to the loader unchanged.
    train_transform = config["train"]["augment"]["train_transforms"]
    train_loader = get_loader(
        train_img_paths,
        train_mask_paths,
        transform=train_transform,
        **config["train"]["dataloader"],
        type="train",
    )
    total_step = len(train_loader)
    logger.info(f"{total_step} batches to train")

    # Validation runs without any transform.
    val_loader = get_loader(
        val_img_paths,
        val_mask_paths,
        transform=None,
        **config["test"]["dataloader"],
        type="val",
    )

    # USE MODEL
    logger.info("Loading model")
    model_prams = config["model"]
    if "save_dir" not in model_prams:
        save_dir = os.path.join("snapshots", model_prams["arch"] + "_kfold")
    else:
        save_dir = model_prams["save_dir"]

    import network.models as models

    model = models.__dict__[model_prams["arch"]]()  # Pranet
    model = model.cuda()

    # USE OPTIMIZER: looked up by lower-cased name; note the configured
    # learning rate is divided by 8 here.
    opt_params = config["optimizer"]
    import network.optim.optimizers as optims

    lr = opt_params["lr"]
    optimizer = optims.__dict__[opt_params["name"].lower()](
        model.parameters(), lr / 8)

    # USE SCHEDULE
    import network.optim.schedulers as schedulers

    scheduler = schedulers.__dict__[opt_params["scheduler"]](
        optimizer, model_prams["num_epochs"], opt_params["num_warmup_epoch"])

    # USE LOSS
    import network.optim.losses as losses

    loss = losses.__dict__[opt_params["loss"]]()

    # TRAINER
    fold = config["dataset"]["fold"]
    logger.info("#" * 20 + f"Start Training Fold {fold}" + "#" * 20)
    from network.models import Trainer3D

    trainer = Trainer3D(model, optimizer, loss, scheduler, save_dir,
                        model_prams["save_from"], logger)

    trainer.fit(
        train_loader=train_loader,
        is_val=config["train"]["is_val"],
        val_loader=val_loader,
        img_size=config["train"]["dataloader"]["img_size"],
        start_from=model_prams["start_from"],
        num_epochs=model_prams["num_epochs"],
        batchsize=config["train"]["dataloader"]["batchsize"],
        fold=fold,
        size_rates=config["train"]["size_rates"],
    )
Ejemplo n.º 4
0
def main():
    """Predict masks for an unlabelled ("private") test set and save PNGs.

    For every fold in ``config["test"]["folds"]`` and each resolved
    checkpoint epoch, every image under ``<test dir>/images`` is predicted.
    The probability map is resized to the original image size. When
    ``config["test"]["visualize"]`` is set, the binarised mask is written to
    ``<visualize_dir>/fold<id>/<arch>/<name>_segmentation.png``. Images
    whose output file already exists are skipped unless
    ``config["test"]["vis_overwrite"]`` is set.
    """
    import network.models as models

    parser = ArgumentParser()
    # NOTE: argparse never uses ``default`` when ``required=True``.
    parser.add_argument("-c",
                        "--config",
                        required=True,
                        default="configs/default_config.yaml")
    args = parser.parse_args()

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)

    folds = config["test"]["folds"]
    print(folds)
    # Last path component of the first test directory; used in log output.
    dataset = config["dataset"]["test_data_path"][0].split("/")[-1]
    if len(folds.keys()) == 1:
        logger.add(
            f'logs/test_{config["model"]["arch"]}_{str(datetime.now())}_{list(folds.keys())[0]}_{dataset}.log',
            rotation="10 MB",
        )
    else:
        logger.add(
            f'logs/test_{config["model"]["arch"]}_{str(datetime.now())}_kfold.log',
            rotation="10 MB",
        )

    # FOR ORIDATASET: only "images/" is listed. The same list is passed as
    # the mask paths, presumably ignored by the "private" loader (confirm).
    # The listing does not depend on the fold, so the loader is built once.
    test_img_paths = []
    for data_dir in config["dataset"]["test_data_path"]:
        test_img_paths.extend(glob(os.path.join(data_dir, "images", "*")))
    test_img_paths.sort()

    test_transform = Augmenter(**config["test"]["augment"])
    test_loader = get_loader(
        test_img_paths,
        list(test_img_paths),
        transform=test_transform,
        **config["test"]["dataloader"],
        type="private",
    )

    # Settings that are identical for every fold/epoch.
    model_prams = config["model"]
    arch = model_prams["arch"]
    if "save_dir" not in model_prams:
        save_dir = os.path.join("snapshots", arch + "_kfold")
    else:
        save_dir = model_prams["save_dir"]
    if "visualize_dir" not in config["test"]:
        visualize_dir = "results"
    else:
        visualize_dir = config["test"]["visualize_dir"]

    for fold_id in folds:
        # Epoch spec: int n -> [3*(n//3)+2]; [a, b] -> range(a, b);
        # [n] -> [3*(n//3)+2].
        epochs = folds[fold_id]
        if not isinstance(epochs, list):
            epochs = [3 * (epochs // 3) + 2]
        elif len(epochs) == 2:
            epochs = list(range(epochs[0], epochs[1]))
        elif len(epochs) == 1:
            epochs = [3 * (epochs[0] // 3) + 2]
        else:
            # A malformed spec aborts the remaining folds.
            logger.error(
                f"Invalid epoch spec for fold {fold_id}: {epochs!r} "
                "(expected an int or a list of 1 or 2 ints)")
            break

        test_fold = "fold" + str(fold_id)
        for e in epochs:
            logger.info("Loading model")
            model = models.__dict__[arch]()  # Pranet
            model_path = os.path.join(
                save_dir,
                f"PraNetDG-fold{fold_id}-{e}.pth",
            )
            model.cuda()
            model.eval()

            logger.info(f"Loading from {model_path}")
            checkpoint = torch.load(model_path)
            # Training checkpoints wrap the weights under "model_state_dict";
            # otherwise the file is treated as a bare state dict.
            if isinstance(checkpoint,
                          dict) and "model_state_dict" in checkpoint:
                checkpoint = checkpoint["model_state_dict"]
            model.load_state_dict(checkpoint)

            logger.info(f"Start testing fold{fold_id} epoch {e}")
            logger.info(
                f"Start testing {len(test_loader)} images in {dataset} dataset"
            )
            visualize = config["test"]["visualize"]
            overwrite = config["test"]["vis_overwrite"]

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                for image, filename, org_image in tqdm.tqdm(test_loader):
                    name = os.path.splitext(filename[0])[0]
                    out_path = os.path.join(
                        visualize_dir,
                        test_fold,
                        str(arch),
                        name + "_segmentation" + ".png",
                    )
                    # Skip images that were already predicted.
                    if not overwrite and os.path.exists(out_path):
                        continue

                    # The model returns four outputs; the last one is used.
                    _, _, _, res2 = model(image.cuda())
                    # Resize to the original image size; assumes org_image
                    # is (batch, H, W, ...) -- TODO confirm.
                    res = F.interpolate(res2,
                                        size=org_image.shape[1:3],
                                        mode="bilinear",
                                        align_corners=False)
                    res = res.sigmoid().cpu().numpy().squeeze()
                    # Min-max rescale to [0, 1]; 1e-8 avoids 0/0 on flat maps.
                    res = (res - res.min()) / (res.max() - res.min() + 1e-8)

                    if visualize:
                        save_img(
                            out_path,
                            res.round() * 255,
                            "cv2",
                            overwrite,
                        )
Ejemplo n.º 5
0
def main():
    """Run every configured model on the inference images and report results.

    The YAML config path comes from ``-c/--config``. For each
    ``(arch, model_path)`` pair in ``config["infer"]["models"]`` the model is
    built from ``network.models``, its checkpoint is loaded, and each image in
    ``config["infer"]["img_paths"]`` is segmented. Per-image scores are
    logged, and hard predictions, soft predictions and colour overlays are
    written below the visualisation directory via ``save_img``.

    When there are between 1 and ``MAX_IMAGE_SHOW - 1`` masks, a comparison
    grid (input image, ground truth, then one row per model) is built, saved
    to ``config["infer"]["compare_fig"]`` and shown. Per-image metrics are
    written to ``config["infer"]["compare_csv"]``. The metric lists are reset
    for every model, so the CSV holds the metrics of the last model only.
    """
    MAX_IMAGE_SHOW = 10
    parser = ArgumentParser()
    parser.add_argument("-c",
                        "--config",
                        required=False,
                        default="configs/gcee_isicbcdu_config.yaml")
    args = parser.parse_args()

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)

    folds = config["test"]["folds"]
    print(folds)
    dev = config["test"]["dev"]
    device = torch.device(dev)
    infer_cfg = config["infer"]

    # mask_paths / img_paths are either a directory (every file inside is
    # used) or an explicit list of file paths.
    if isinstance(infer_cfg["mask_paths"], list):
        mask_paths = infer_cfg["mask_paths"]
    else:
        mask_paths = glob.glob(os.path.join(infer_cfg["mask_paths"], "*"))
    if isinstance(infer_cfg["img_paths"], list):
        img_paths = infer_cfg["img_paths"]
    else:
        img_paths = glob.glob(os.path.join(infer_cfg["img_paths"], "*"))

    # Images and masks are paired positionally, so both lists are sorted.
    img_paths.sort()
    mask_paths.sort()

    # Output settings are loop-invariant; read them once.
    overwrite = infer_cfg["vis_overwrite"]
    vis_x = infer_cfg["vis_x"]
    visualize_dir = infer_cfg.get("visualize_dir", "outputs/infer")
    os.makedirs(visualize_dir, exist_ok=True)

    import matplotlib.pyplot as plt
    import network.models as models

    img_size = (config["test"]["dataloader"]["img_size"],
                config["test"]["dataloader"]["img_size"])
    arch_path = infer_cfg["models"]
    print(len(img_paths), len(arch_path) + 2, len(mask_paths))

    # The grid needs at least one column. squeeze=False keeps ``axs`` 2-D
    # even when there is only a single mask column.
    show_grid = 0 < len(mask_paths) < MAX_IMAGE_SHOW
    fig = axs = None
    if show_grid:
        fig, axs = plt.subplots(len(arch_path) + 2,
                                len(mask_paths),
                                constrained_layout=True,
                                figsize=(15, 15),
                                squeeze=False)
        cols = [os.path.basename(p) for p in mask_paths]
        rows = ["Images", "GT"] + [entry[0] for entry in arch_path]
        print(rows)
        for ax, col in zip(axs[0], cols):
            ax.set_title(col, fontsize=20)
        for ax, row in zip(axs[:, 0], rows):
            ax.set_ylabel(row, rotation="horizontal", fontsize=20)

    # Column order of the output CSV (after the leading "name" column).
    metric_names = [
        "mean_dices", "mean_ious", "mean_precisions", "mean_recalls",
        "mean_accs", "mean_ses", "mean_spes", "tps", "fps", "fns", "tns"
    ]
    metrics = {key: [] for key in metric_names}

    for model_idx, (arch, model_path) in enumerate(arch_path):
        model = models.__dict__[arch]()
        model.to(device)
        model.eval()
        logger.info(f"Loading from {model_path}")

        # A checkpoint is either a dict wrapping "model_state_dict" or a bare
        # state dict. Load the file once and unwrap it if needed, so real
        # load_state_dict errors (e.g. key mismatches) are not masked.
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            checkpoint = checkpoint["model_state_dict"]
        model.load_state_dict(checkpoint)

        # Fresh metric lists per model.
        metrics = {key: [] for key in metric_names}

        for col_idx, (img_path, mask_path) in enumerate(
                zip(img_paths, mask_paths)):
            image_ = imread(img_path)  # presumably h, w, 3 (0-255) numpy
            if os.path.exists(mask_path):
                mask = np.array(Image.open(mask_path).convert("L"))
            else:
                print("not exist")
                mask = np.zeros(image_.shape[:2], dtype=np.float64)

            # Resize, scale to [0, 1] and reorder HWC -> 1xCxHxW.
            image = cv2.resize(image_, img_size).astype("float32") / 255
            image = image.transpose((2, 0, 1))[np.newaxis]

            name, ext = os.path.splitext(os.path.basename(img_path))
            img = np.asarray(image_)
            gt = np.asarray(mask, np.float32)
            gt /= gt.max() + 1e-8

            image = torch.tensor(image).float().to(device)
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                if arch == "UNet":
                    res = model(image)
                else:
                    # Other architectures return four outputs; the last one is used.
                    _, _, _, res = model(image)
                res = F.interpolate(res,
                                    size=gt.shape,
                                    mode="bilinear",
                                    align_corners=False)
                res = res.sigmoid().cpu().numpy().squeeze()
            # Min-max normalise the probability map to [0, 1].
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)

            pr = res.round()
            tp = np.sum(gt * pr)
            fp = np.sum(pr) - tp
            fn = np.sum(gt) - tp
            tn = np.sum((1 - pr) * (1 - gt))

            mean_acc = (tp + tn) / (tp + tn + fp + fn)
            mean_se = tp / (tp + fn)
            mean_spe = tn / (tn + fp)

            mean_precision = precision_m(gt, pr)
            mean_recall = recall_m(gt, pr)
            mean_iou = jaccard_m(gt, pr)
            mean_dice = dice_m(gt, pr)
            logger.info("scores ver1: {:.3f} {:.3f} {:.3f} {:.3f}".format(
                mean_iou, mean_precision, mean_recall, mean_dice))

            metrics["mean_dices"].append(mean_dice)
            metrics["mean_ious"].append(mean_iou)
            metrics["mean_precisions"].append(mean_precision)
            metrics["mean_recalls"].append(mean_recall)
            # Accuracy / sensitivity / specificity from the confusion counts.
            metrics["mean_accs"].append(mean_acc)
            metrics["mean_ses"].append(mean_se)
            metrics["mean_spes"].append(mean_spe)
            metrics["tps"].append(tp / (tp + fn))
            metrics["fps"].append(fp / (fp + tn))
            metrics["fns"].append(fn / (fn + tp))
            metrics["tns"].append(tn / (tn + fp))

            precision_all = tp / (tp + fp + 1e-07)
            recall_all = tp / (tp + fn + 1e-07)
            dice_all = 2 * precision_all * recall_all / (precision_all +
                                                         recall_all)
            iou_all = (
                recall_all * precision_all /
                (recall_all + precision_all - recall_all * precision_all))
            logger.info("scores ver2: {:.3f} {:.3f} {:.3f} {:.3f}".format(
                iou_all, precision_all, recall_all, dice_all))

            # Overlay layers: GT in channel 0, prediction in channel 1.
            gt_zeros = np.zeros_like(gt)
            pr_zeros = np.zeros_like(pr)
            gt_layer = vis_x * np.stack((gt, gt_zeros, gt_zeros), axis=-1)
            pr_layer = vis_x * np.stack((pr_zeros, pr, pr_zeros), axis=-1)

            if show_grid:
                axs[0][col_idx].imshow(img / (img.max() + 1e-8), cmap="gray")
                axs[0][col_idx].set_axis_off()
                axs[1][col_idx].imshow(gt, cmap="gray")
                axs[1][col_idx].set_axis_off()
                axs[model_idx + 2][col_idx].imshow(pr, cmap="gray")
                axs[model_idx + 2][col_idx].set_axis_off()

            ##### HARD PR
            save_img(
                os.path.join(
                    visualize_dir,
                    str(arch),
                    name + "_pr" + str(arch) + ext,
                ),
                pr * 255,
                "cv2",
                overwrite,
            )

            ##### MASK GT (channel order reversed for the "cv2" writer)
            mask_img = (np.asarray(img) + gt_layer)[:, :, ::-1]
            save_img(
                os.path.join(
                    visualize_dir,
                    str(arch),
                    name + "_gt" + str(arch) + ext,
                ),
                mask_img,
                "cv2",
                overwrite,
            )

            ##### SOFT PR
            save_img(
                os.path.join(
                    visualize_dir,
                    "soft_" + str(arch),
                    name + "_soft_pr" + str(arch) + ext,
                ),
                res * 255,
                "cv2",
                overwrite,
            )

            ##### MASK GT_PR
            mask_img = (np.asarray(img) + pr_layer + gt_layer)[:, :, ::-1]
            # NOTE(review): unlike the other outputs this file name has no "_"
            # before "mask_pr"; kept as-is so existing outputs keep their names.
            save_img(
                os.path.join(
                    visualize_dir,
                    "mask_" + str(arch),
                    name + "mask_pr" + str(arch) + ext,
                ),
                mask_img,
                "cv2",
                overwrite,
            )

    if show_grid:
        plt.subplots_adjust(wspace=0.1, hspace=0.1)
        # Write the figure before the (possibly blocking) interactive display.
        fig.savefig(os.path.join(visualize_dir, infer_cfg["compare_fig"]))
        plt.show()

    names = [os.path.basename(p) for p in mask_paths]
    s = [
        list(row)
        for row in zip(names, *(metrics[key] for key in metric_names))
    ]
    # Sort ascending by Dice so the worst cases are logged first.
    s.sort(key=lambda row: row[1])
    for row in s:
        logger.info(row)

    import pandas as pd
    pd.DataFrame(s, columns=["name"] + metric_names).to_csv(
        os.path.join(visualize_dir, infer_cfg["compare_csv"]))
Ejemplo n.º 6
0
def main():
    """Predict soft masks for every test sample and save them as images.

    The YAML config path comes from ``-c/--config``. For every fold id in
    ``config["test"]["folds"]`` the listed epoch checkpoints
    (``PraNetDG-fold{id}-{epoch}.pth`` under the model save dir) are loaded.
    Each test image is segmented, and ``res * 255`` thresholded at 0.5 is
    written to ``<visualize_dir>/PR_<arch>/Soft/<name><ext>``.

    Returns:
        tuple: ``(gts, prs)``; both lists are returned empty by this function.
    """
    parser = ArgumentParser()
    # required=False so the declared default config is actually usable.
    parser.add_argument("-c",
                        "--config",
                        required=False,
                        default="configs/default_config.yaml")
    args = parser.parse_args()

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)

    gts = []
    prs = []
    dev = config["test"]["dev"]
    device = torch.device(dev)

    folds = config["test"]["folds"]
    print(folds)
    fold_ids = list(folds.keys())
    dataset = config["dataset"]["test_data_path"][0].split("/")[-1]
    if len(fold_ids) == 1:
        log_path = f'logs/test_{config["model"]["arch"]}_{str(datetime.now())}_{list(folds.keys())[0]}_{dataset}.log'
    else:
        log_path = f'logs/test_{config["model"]["arch"]}_{str(datetime.now())}_kfold.log'
    logger.add(log_path, rotation="10 MB")

    # Test data does not depend on the fold: build the loader once.
    # Each test directory holds one sub-directory per sample laid out as
    # <sample>/images/<sample>.png. The image path is also passed as the
    # "mask" path; the loaded gt is only used for its spatial shape below.
    test_img_paths = []
    for data_dir in config["dataset"]["test_data_path"]:
        for sample_dir in glob(os.path.join(data_dir, "*")):
            sample_id = os.path.basename(sample_dir)
            test_img_paths.append(
                os.path.join(sample_dir, "images", sample_id + ".png"))
    test_img_paths.sort()
    test_mask_paths = list(test_img_paths)

    test_transform = Augmenter(**config["test"]["augment"])
    test_loader = get_loader(
        test_img_paths,
        test_mask_paths,
        transform=test_transform,
        **config["test"]["dataloader"],
        type="test",
    )

    import network.models as models

    model_prams = config["model"]
    arch = model_prams["arch"]
    if "save_dir" not in model_prams:
        save_dir = os.path.join("snapshots", arch + "_kfold")
    else:
        save_dir = model_prams["save_dir"]

    if "visualize_dir" not in config["test"]:
        visualize_dir = "results"
    else:
        visualize_dir = config["test"]["visualize_dir"]

    for fold_id in fold_ids:
        # A scalar or one-element list selects the epoch 3*(e//3)+2; a
        # two-element list selects every epoch in [start, stop).
        epochs = folds[fold_id]
        if not isinstance(epochs, list):
            epochs = [3 * (epochs // 3) + 2]
        elif len(epochs) == 2:
            epochs = list(range(epochs[0], epochs[1]))
        elif len(epochs) == 1:
            epochs = [3 * (epochs[0] // 3) + 2]
        else:
            logger.debug("Model path must have 0 or 1 num")
            break

        for e in epochs:
            logger.info("Loading model")
            model = models.__dict__[arch]()  # Pranet
            model_path = os.path.join(
                save_dir,
                f"PraNetDG-fold{fold_id}-{e}.pth",
            )
            model.to(device)
            model.eval()

            logger.info(f"Loading from {model_path}")
            # A checkpoint is either a dict wrapping "model_state_dict" or a
            # bare state dict. Indexing a bare state dict raises KeyError,
            # which the old `except RuntimeError` never caught.
            checkpoint = torch.load(model_path, map_location=device)
            if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
                checkpoint = checkpoint["model_state_dict"]
            model.load_state_dict(checkpoint)

            logger.info(f"Start testing fold{fold_id} epoch {e}")
            logger.info(
                f"Start testing {len(test_loader)} images in {dataset} dataset"
            )

            for image, gt, filename, img in tqdm.tqdm(test_loader):
                name, ext = os.path.splitext(filename[0])
                gt = np.asarray(gt[0][0], np.float32)
                image = image.to(device)

                # Inference only: no autograd graph is needed.
                with torch.no_grad():
                    # The model returns four outputs; the last one is used.
                    _, _, _, res = model(image)
                    res = F.interpolate(res,
                                        size=gt.shape,
                                        mode="bilinear",
                                        align_corners=False)
                    res = res.sigmoid().cpu().numpy().squeeze()
                # Min-max normalise the probability map to [0, 1].
                res = (res - res.min()) / (res.max() - res.min() + 1e-8)

                # Keep soft values only where the probability exceeds 0.5.
                save_img(
                    os.path.join(
                        visualize_dir,
                        "PR_" + str(arch),
                        "Soft",
                        name + ext,
                    ),
                    1 * (res > 0.5) * res * 255,
                    "cv2",
                    True,
                )

    return gts, prs
Ejemplo n.º 7
0
def main():
    """Train a TransUNet (R50-ViT-B_16) segmentation model from a YAML config.

    The config path comes from ``-c/--config``. Train and test image/mask
    paths are globbed from ``<dir>/images/*`` and ``<dir>/masks/*`` for every
    directory in ``config["dataset"]["train_data_path"]`` and
    ``config["dataset"]["test_data_path"]``. The model is built from
    ``network.models`` with the ViT config and initialised from
    ``config_vit.pretrained_path``. Training is delegated to
    ``TransUnetTrainer.fit``.
    """
    parser = ArgumentParser()
    # required=False so the declared default config is actually usable.
    parser.add_argument("-c",
                        "--config",
                        required=False,
                        default="configs/default_config.yaml")
    args = parser.parse_args()

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)

    logger.add(f"logs/{str(datetime.now())}_train_log_file.log",
               rotation="10 MB")
    logger.info(f"Load config from {config_path}")
    logger.info(f"{config}")

    # GET_DATA_PATH
    logger.info("Getting datapath")
    train_img_paths = []
    train_mask_paths = []
    for data_dir in config["dataset"]["train_data_path"]:
        train_img_paths.extend(glob(os.path.join(data_dir, "images", "*")))
        train_mask_paths.extend(glob(os.path.join(data_dir, "masks", "*")))
    # Images and masks are paired positionally, so both lists are sorted.
    train_img_paths.sort()
    train_mask_paths.sort()

    test_img_paths = []
    test_mask_paths = []
    for data_dir in config["dataset"]["test_data_path"]:
        test_img_paths.extend(glob(os.path.join(data_dir, "images", "*")))
        test_mask_paths.extend(glob(os.path.join(data_dir, "masks", "*")))
    test_img_paths.sort()
    test_mask_paths.sort()

    # DATALOADER
    logger.info("Loading data")
    train_transform = Augmenter(**config["train"]["augment"])
    train_loader = get_loader(
        train_img_paths,
        train_mask_paths,
        transform=train_transform,
        **config["train"]["dataloader"],
        is_train=True,
    )

    test_transform = Augmenter(**config["test"]["augment"])
    test_loader = get_loader(
        test_img_paths,
        test_mask_paths,
        transform=test_transform,
        **config["test"]["dataloader"],
        is_train=False,
    )

    # USE MODEL
    logger.info("Loading model")
    model_prams = config["model"]
    save_dir = os.path.join(model_prams["save_dir"], model_prams["arch"])

    # TransUNet settings: R50 hybrid backbone with 16x16 patches.
    n_skip = 3
    vit_name = "R50-ViT-B_16"
    vit_patches_size = 16
    img_size = config["dataset"]["img_size"]
    from network.models.transunet.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
    import numpy as np

    config_vit = CONFIGS_ViT_seg[vit_name]
    config_vit.n_classes = 1
    config_vit.n_skip = n_skip
    if "R50" in vit_name:
        # The hybrid backbone needs the patch grid matching the input size.
        grid_size = int(img_size / vit_patches_size)
        config_vit.patches.grid = (grid_size, grid_size)

    import network.models as models

    model = models.__dict__[model_prams["arch"]](
        config_vit, img_size=img_size,
        num_classes=config_vit.n_classes)  # TransUnet
    model = model.cuda()

    # np.load returns an NpzFile that keeps the archive open for .npz files;
    # close it once the weights have been copied into the model.
    pretrained_weights = np.load(config_vit.pretrained_path)
    try:
        model.load_from(weights=pretrained_weights)
    finally:
        if hasattr(pretrained_weights, "close"):
            pretrained_weights.close()

    params = model.parameters()

    # USE OPTIMIZER
    opt_params = config["optimizer"]
    import network.optim.optimizers as optims

    lr = opt_params["lr"]
    optimizer = optims.__dict__[opt_params["name"].lower()](params, lr / 8)

    # USE SCHEDULE
    import network.optim.schedulers as schedulers

    scheduler = schedulers.__dict__[opt_params["scheduler"]](
        optimizer, lr, model_prams["num_epochs"],
        opt_params["num_warmup_epoch"])

    # USE LOSS
    import network.optim.losses as losses

    loss = losses.__dict__[opt_params["loss"]]()

    # TRAINER
    fold = config["dataset"]["fold"]
    # Build one message string: extra positional args would be treated as
    # format arguments and silently dropped (no placeholders in "#" * 20).
    logger.info("#" * 20 + f"Start Training Fold {fold}" + "#" * 20)
    from network.models import TransUnetTrainer

    trainer = TransUnetTrainer(model, optimizer, loss, scheduler, save_dir,
                               model_prams["save_from"], logger)
    trainer.fit(
        train_loader=train_loader,
        is_val=config["train"]["is_val"],
        test_loader=test_loader,
        img_size=config["train"]["dataloader"]["img_size"],
        start_from=model_prams["start_from"],
        num_epochs=model_prams["num_epochs"],
        batchsize=config["train"]["dataloader"]["batchsize"],
        fold=fold,
    )
Ejemplo n.º 8
0
def main():
    """Train a segmentation model (PraNet-style by default) from a YAML config.

    The config path comes from ``-c/--config``.
    ``config["dataset"]["train_data_path"]`` and ``val_data_path`` may each be
    either a list of directories with ``images/`` and ``masks/`` sub-folders,
    or a single directory holding ``training.tif`` /
    ``training_groundtruth.tif`` (``testing*.tif`` for validation). If
    ``config["model"]["start_from"]`` is non-zero, training resumes from
    ``PraNetDG-fold{fold}-{start_from}.pth`` in the save directory. Training
    is delegated to ``Trainer.fit``.
    """
    parser = ArgumentParser()
    parser.add_argument("-c",
                        "--config",
                        required=False,
                        default="configs/scwc_isic_bcdu.yaml")
    args = parser.parse_args()

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)
    logger.add(
        f'logs/train_{config["model"]["arch"]}_{str(datetime.now())}_{config["dataset"]["fold"]}.log',
        rotation="10 MB",
    )

    logger.info(f"Load config from {config_path}")
    logger.info(f"{config}")

    # GET_DATA_PATH
    logger.info("Getting datapath")

    train_data_path = config["dataset"]["train_data_path"]
    if isinstance(train_data_path, list):
        train_img_paths = []
        train_mask_paths = []
        for data_dir in train_data_path:
            train_img_paths.extend(glob(os.path.join(data_dir, "images", "*")))
            train_mask_paths.extend(glob(os.path.join(data_dir, "masks", "*")))
        # Images and masks are paired positionally, so both lists are sorted.
        train_img_paths.sort()
        train_mask_paths.sort()
        logger.info(f"There are {len(train_img_paths)} images to train")
    else:
        # Single directory: one image TIFF and one ground-truth TIFF are
        # passed to get_loader as plain paths.
        train_img_paths = os.path.join(train_data_path, "training.tif")
        train_mask_paths = os.path.join(train_data_path,
                                        "training_groundtruth.tif")

    val_data_path = config["dataset"]["val_data_path"]
    if isinstance(val_data_path, list):
        val_img_paths = []
        val_mask_paths = []
        for data_dir in val_data_path:
            val_img_paths.extend(glob(os.path.join(data_dir, "images", "*")))
            val_mask_paths.extend(glob(os.path.join(data_dir, "masks", "*")))
        val_img_paths.sort()
        val_mask_paths.sort()
        logger.info(f"There are {len(val_mask_paths)} images to val")
    else:
        val_img_paths = os.path.join(val_data_path, "testing.tif")
        val_mask_paths = os.path.join(val_data_path, "testing_groundtruth.tif")

    # DATALOADER
    logger.info("Loading data")
    train_transform = Augmenter(**config["train"]["augment"])
    train_loader = get_loader(
        train_img_paths,
        train_mask_paths,
        transform=train_transform,
        **config["train"]["dataloader"],
        type="train",
    )
    logger.info(f"{len(train_loader)} batches to train")

    val_transform = Augmenter(**config["test"]["augment"])
    val_loader = get_loader(
        val_img_paths,
        val_mask_paths,
        transform=val_transform,
        **config["test"]["dataloader"],
        type="val",
    )

    # USE MODEL
    logger.info("Loading model")
    model_prams = config["model"]
    if "save_dir" not in model_prams:
        save_dir = os.path.join("snapshots", model_prams["arch"] + "_kfold")
    else:
        save_dir = model_prams["save_dir"]

    import network.models as models

    model = models.__dict__[model_prams["arch"]]()  # Pranet
    model = model.cuda()

    # Resume from a saved epoch. The checkpoint is read from disk once;
    # strict=False tolerates missing/unexpected keys.
    if model_prams["start_from"] != 0:
        restore_from = os.path.join(
            save_dir,
            f'PraNetDG-fold{config["dataset"]["fold"]}-{model_prams["start_from"]}.pth',
        )
        checkpoint = torch.load(restore_from)
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)

    params = model.parameters()

    # USE OPTIMIZER
    opt_params = config["optimizer"]
    import network.optim.optimizers as optims

    # The learning rate always comes from the config, also when resuming.
    lr = opt_params["lr"]
    optimizer = optims.__dict__[opt_params["name"].lower()](params, lr / 8)

    # USE SCHEDULE
    import network.optim.schedulers as schedulers

    scheduler = schedulers.__dict__[opt_params["scheduler"]](
        optimizer, model_prams["num_epochs"], opt_params["num_warmup_epoch"])

    # USE LOSS
    import network.optim.losses as losses

    loss = losses.__dict__[opt_params["loss"]]()

    # TRAINER
    fold = config["dataset"]["fold"]
    logger.info("#" * 20 + f"Start Training Fold {fold}" + "#" * 20)
    # Alternative trainers in network.models: TransUnetTrainer,
    # TrainerGCPAGALD, TrainerSCWS, Trainer3D, TrainerOne.
    from network.models import Trainer

    trainer = Trainer(model, optimizer, loss, scheduler, save_dir,
                      model_prams["save_from"], logger)

    trainer.fit(
        train_loader=train_loader,
        is_val=config["train"]["is_val"],
        val_loader=val_loader,
        img_size=config["train"]["dataloader"]["img_size"],
        start_from=model_prams["start_from"],
        num_epochs=model_prams["num_epochs"],
        batchsize=config["train"]["dataloader"]["batchsize"],
        fold=fold,
        size_rates=config["train"]["size_rates"],
    )
Ejemplo n.º 9
0
def main():
    """Evaluate TransUNet snapshots on per-fold test splits and log scores.

    Reads a YAML config (``-c/--config``). For every fold id in
    ``config["test"]["folds"]`` the configured epoch spec is expanded into a
    list of epochs. Each snapshot
    ``<model.save_dir>/<model.arch>/PraNetDG-fold<id>-<epoch>.pth`` is then
    evaluated on ``<dataset.data_path>/fold_<id>/{images,masks}``. Scores are
    logged as "iou precision recall dice":

    * ver1 - per-image metrics on the rounded prediction, summed and divided
      by the number of loader batches;
    * ver2 - metrics derived from TP/FP/FN pixel counts pooled over the split.

    When more than one fold is configured, ``get_scores_v1`` and
    ``get_scores_v2`` are also run on everything collected.

    Epoch spec per fold:

    * a scalar is used as-is;
    * ``[start, end]`` expands to ``3k + 2`` for ``k`` in
      ``range(start // 3, (end + 1) // 3)``;
    * any other non-empty list uses only ``3 * (first // 3) + 2``;
    * an empty list skips the fold.

    Returns:
        tuple[list, list]: ground-truth masks and rounded predictions, one
        entry per evaluated image, accumulated over every fold and epoch.
    """
    import numpy as np

    parser = ArgumentParser()
    # With required=True the default could never be used; keep it optional.
    parser.add_argument(
        "-c", "--config", required=False, default="configs/default_config.yaml"
    )
    args = parser.parse_args()

    logger.add(f"logs/{str(datetime.now())}_test_log_file.log", rotation="10 MB")

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)

    gts = []
    prs = []

    folds = config["test"]["folds"]
    test_cfg = config["test"]
    model_prams = config["model"]
    arch = model_prams["arch"]
    data_path = config["dataset"]["data_path"]
    visualize = test_cfg["visualize"]
    overwrite = test_cfg["vis_overwrite"]
    vis_x = test_cfg["vis_x"]
    visualize_dir = "results"

    # MODEL: TransUNet (R50-ViT-B_16). The architecture does not depend on the
    # fold or the epoch, so build it once and only reload weights below.
    logger.info("Loading model")
    import network.models as models
    from network.models.transunet.vit_seg_modeling import (
        CONFIGS as CONFIGS_ViT_seg,
    )

    n_skip = 3
    vit_name = "R50-ViT-B_16"
    vit_patches_size = 16
    img_size = config["dataset"]["img_size"]
    config_vit = CONFIGS_ViT_seg[vit_name]
    config_vit.n_classes = 1
    config_vit.n_skip = n_skip
    if "R50" in vit_name:
        grid_size = int(img_size / vit_patches_size)
        config_vit.patches.grid = (grid_size, grid_size)
    model = models.__dict__[arch](
        config_vit, img_size=img_size, num_classes=config_vit.n_classes
    )  # TransUnet
    model.cuda()
    model.eval()

    for fold_id, epoch_spec in folds.items():
        # isinstance (rather than type() ==) also accepts list subclasses,
        # e.g. sequence types produced by some YAML loaders.
        if not isinstance(epoch_spec, list):
            epochs = [epoch_spec]
        elif len(epoch_spec) == 2:
            start, end = epoch_spec
            epochs = [3 * k + 2 for k in range(start // 3, (end + 1) // 3)]
        elif epoch_spec:
            logger.debug("Model path must have 0 or 1 num")
            epochs = [3 * (epoch_spec[0] // 3) + 2]
        else:
            logger.warning(f"No epochs configured for fold {fold_id}; skipping")
            continue

        # The test split depends only on the fold, not on the epoch.
        fold_dir = os.path.join(data_path, f"fold_{fold_id}")
        test_img_paths = sorted(glob(os.path.join(fold_dir, "images", "*")))
        test_mask_paths = sorted(glob(os.path.join(fold_dir, "masks", "*")))
        test_loader = get_loader(
            test_img_paths,
            test_mask_paths,
            transform=Augmenter(**test_cfg["augment"]),
            **test_cfg["dataloader"],
            is_train=False,
        )
        num_batches = len(test_loader)
        if num_batches == 0:
            # The averages below would divide by zero.
            logger.warning(f"No test images in {fold_dir}; skipping fold {fold_id}")
            continue

        test_fold = f"fold{fold_id}"
        for e in epochs:
            model_path = os.path.join(
                model_prams["save_dir"], arch, f"PraNetDG-fold{fold_id}-{e}.pth"
            )
            checkpoint = torch.load(model_path)
            # Snapshots may wrap the weights under "model_state_dict". Plain
            # state dicts are loaded as-is; indexing them would raise KeyError.
            if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
                checkpoint = checkpoint["model_state_dict"]
            model.load_state_dict(checkpoint)

            tp_all = fp_all = fn_all = 0
            mean_precision = mean_recall = mean_iou = mean_dice = 0

            logger.info(f"Start testing fold{fold_id} epoch {e}")
            for image, gt, filename, img in tqdm.tqdm(test_loader):
                name, ext = os.path.splitext(filename[0])
                gt = np.asarray(gt[0][0], np.float32)

                # Inference only: no autograd graph is needed.
                with torch.no_grad():
                    res = model(image.cuda())
                # Resize to the mask size, squash, and min-max rescale to
                # [0, 1]; the epsilon guards against constant maps.
                res = F.interpolate(
                    res, size=gt.shape, mode="bilinear", align_corners=False
                )
                res = res.sigmoid().data.cpu().numpy().squeeze()
                res = (res - res.min()) / (res.max() - res.min() + 1e-8)
                pr = res.round()

                if visualize:
                    vis_root = os.path.join(visualize_dir, test_fold)
                    save_img(
                        os.path.join(
                            vis_root, str(arch), name + "_pr" + str(arch) + ext
                        ),
                        pr * 255,
                        "cv2",
                        overwrite,
                    )
                    save_img(
                        os.path.join(
                            vis_root,
                            "soft_" + str(arch),
                            name + "_soft_pr" + str(arch) + ext,
                        ),
                        res * 255,
                        "cv2",
                        overwrite,
                    )
                    # Overlay: the prediction goes in channel 1 and the ground
                    # truth in channel 0, each scaled by vis_x. Channels are
                    # then reversed before writing with cv2.
                    pr_zeros = np.zeros_like(pr)
                    gt_zeros = np.zeros_like(gt)
                    mask_img = (
                        np.asarray(img[0])
                        + vis_x * np.stack((pr_zeros, pr, pr_zeros), axis=-1)
                        + vis_x * np.stack((gt, gt_zeros, gt_zeros), axis=-1)
                    )
                    mask_img = mask_img[:, :, ::-1]
                    save_img(
                        os.path.join(
                            vis_root,
                            "mask_" + str(arch),
                            name + "mask_pr" + str(arch) + ext,
                        ),
                        mask_img,
                        "cv2",
                        overwrite,
                    )

                prs.append(pr)
                gts.append(gt)

                # Pooled pixel counts for the "ver2" scores.
                tp = np.sum(gt * pr)
                tp_all += tp
                fp_all += np.sum(pr) - tp
                fn_all += np.sum(gt) - tp

                mean_precision += precision_m(gt, pr)
                mean_recall += recall_m(gt, pr)
                mean_iou += jaccard_m(gt, pr)
                mean_dice += dice_m(gt, pr)

            mean_precision /= num_batches
            mean_recall /= num_batches
            mean_iou /= num_batches
            mean_dice /= num_batches
            logger.info(
                "scores ver1: {:.3f} {:.3f} {:.3f} {:.3f}".format(
                    mean_iou, mean_precision, mean_recall, mean_dice
                )
            )

            precision_all = tp_all / (tp_all + fp_all + K.epsilon())
            recall_all = tp_all / (tp_all + fn_all + K.epsilon())
            dice_all = 2 * precision_all * recall_all / (precision_all + recall_all)
            iou_all = (
                recall_all
                * precision_all
                / (recall_all + precision_all - recall_all * precision_all)
            )
            logger.info(
                "scores ver2: {:.3f} {:.3f} {:.3f} {:.3f}".format(
                    iou_all, precision_all, recall_all, dice_all
                )
            )

    from utils.metrics import get_scores_v1, get_scores_v2

    if len(folds) > 1:
        get_scores_v1(gts, prs, logger)
        get_scores_v2(gts, prs, logger)

    return gts, prs
Ejemplo n.º 10
0
def main():
    """Train a student segmentation model by distillation from a teacher.

    Reads a YAML config (``-c/--config``) and then does the following:

    * builds a train loader from the ``images``/``masks`` folders under each
      root in ``dataset.train_data_path``, plus the soft labels found in
      ``dataset.distill_label_path``;
    * builds a validation loader from the roots in ``dataset.val_data_path``;
    * instantiates the student ``model.arch`` and the teacher
      ``model.arch_teacher``, loading the teacher from
      ``model.weight_teacher_path``;
    * when ``model.start_from`` is non-zero, resumes the student from
      ``PraNetDG-fold<fold>-<start_from>.pth`` in the snapshot directory;
    * runs ``TrainerDistillation.fit``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        required=False,
        default="configs/gcpa_gald_net_config.yaml",
    )
    args = parser.parse_args()

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)
    dataset_cfg = config["dataset"]
    model_prams = config["model"]
    fold = dataset_cfg["fold"]
    logger.add(
        f'logs/train_{model_prams["arch"]}_{str(datetime.now())}_{fold}.log',
        rotation="10 MB",
    )

    logger.info(f"Load config from {config_path}")
    logger.info(f"{config}")

    def collect_paths(roots):
        """Return sorted image and mask paths from each ``<root>/{images,masks}``."""
        img_paths, mask_paths = [], []
        for root in roots:
            img_paths.extend(glob(os.path.join(root, "images", "*")))
            mask_paths.extend(glob(os.path.join(root, "masks", "*")))
        return sorted(img_paths), sorted(mask_paths)

    def warn_ignored_keys(tag, result):
        """Log keys that a ``strict=False`` ``load_state_dict`` skipped."""
        missing = getattr(result, "missing_keys", [])
        unexpected = getattr(result, "unexpected_keys", [])
        if missing or unexpected:
            logger.warning(
                f"{tag}: {len(missing)} missing and {len(unexpected)} unexpected "
                "keys ignored while loading weights (strict=False)"
            )

    # GET_DATA_PATH
    logger.info("Getting datapath")
    train_data_path = dataset_cfg["train_data_path"]
    train_img_paths, train_mask_paths = collect_paths(train_data_path)
    softlabel_pattern = os.path.join(dataset_cfg["distill_label_path"], "*")
    train_softlabel_paths = sorted(glob(softlabel_pattern))
    logger.info(f"Train roots: {train_data_path}; soft labels: {softlabel_pattern}")
    logger.info(f"First soft labels: {train_softlabel_paths[:5]}")
    logger.info(f"First masks: {train_mask_paths[:5]}")
    # The three lists are sorted independently and passed side by side, so
    # get_loader presumably pairs them by index. A count mismatch therefore
    # points at misaligned training samples.
    if not (len(train_img_paths) == len(train_mask_paths) == len(train_softlabel_paths)):
        logger.warning(
            f"Path count mismatch: {len(train_img_paths)} images, "
            f"{len(train_mask_paths)} masks, "
            f"{len(train_softlabel_paths)} soft labels"
        )
    logger.info(f"There are {len(train_img_paths)} images to train")

    val_img_paths, val_mask_paths = collect_paths(dataset_cfg["val_data_path"])
    logger.info(f"There are {len(val_mask_paths)} images to val")

    # DATALOADER
    logger.info("Loading data")
    train_loader = get_loader(
        train_img_paths,
        train_mask_paths,
        softlabel_paths=train_softlabel_paths,
        transform=Augmenter(**config["train"]["augment"]),
        **config["train"]["dataloader"],
        type="train",
    )
    logger.info(f"{len(train_loader)} batches to train")

    val_loader = get_loader(
        val_img_paths,
        val_mask_paths,
        transform=Augmenter(**config["test"]["augment"]),
        **config["test"]["dataloader"],
        type="val",
    )

    # USE MODEL
    logger.info("Loading model")
    if "save_dir" not in model_prams:
        save_dir = os.path.join("snapshots", model_prams["arch"] + "_kfold")
    else:
        save_dir = model_prams["save_dir"]

    import network.models as models

    model = models.__dict__[model_prams["arch"]]().cuda()  # Pranet (student)

    teacher = models.__dict__[model_prams["arch_teacher"]]().cuda()
    teacher_ckpt = torch.load(model_prams["weight_teacher_path"])
    warn_ignored_keys(
        "teacher",
        teacher.load_state_dict(teacher_ckpt["model_state_dict"], strict=False),
    )
    # NOTE(review): the teacher is not switched to eval() here; confirm that
    # TrainerDistillation does so and keeps it frozen.

    if model_prams["start_from"] != 0:
        restore_from = os.path.join(
            save_dir,
            f'PraNetDG-fold{fold}-{model_prams["start_from"]}.pth',
        )
        checkpoint = torch.load(restore_from)  # read the snapshot only once
        warn_ignored_keys(
            "student",
            model.load_state_dict(checkpoint["model_state_dict"], strict=False),
        )
        # NOTE(review): any "lr" stored in the snapshot is not used. The
        # optimizer below always starts from the configured lr.

    # USE OPTIMIZER / SCHEDULE / LOSS
    opt_params = config["optimizer"]
    import network.optim.optimizers as optims
    import network.optim.schedulers as schedulers
    import network.optim.losses as losses

    lr = opt_params["lr"]
    # The optimizer is created with one eighth of the configured lr.
    optimizer = optims.__dict__[opt_params["name"].lower()](model.parameters(), lr / 8)
    scheduler = schedulers.__dict__[opt_params["scheduler"]](
        optimizer, model_prams["num_epochs"], opt_params["num_warmup_epoch"]
    )
    loss = losses.__dict__[opt_params["loss"]]()

    # TRAINER
    logger.info("#" * 20 + f"Start Training Fold {fold}" + "#" * 20)
    from network.models import TrainerDistillation

    trainer = TrainerDistillation(
        model,
        teacher,
        optimizer,
        loss,
        scheduler,
        save_dir,
        model_prams["save_from"],
        logger,
    )

    trainer.fit(
        train_loader=train_loader,
        is_val=config["train"]["is_val"],
        val_loader=val_loader,
        img_size=config["train"]["dataloader"]["img_size"],
        start_from=model_prams["start_from"],
        num_epochs=model_prams["num_epochs"],
        batchsize=config["train"]["dataloader"]["batchsize"],
        fold=fold,
        size_rates=config["train"]["size_rates"],
    )
Ejemplo n.º 11
0
def main():
    """Evaluate one saved snapshot per fold on that fold's test split.

    Reads a YAML config (``-c/--config``). Each entry ``fold_id: epoch`` of
    ``config["test"]["folds"]`` loads
    ``<model.save_dir>/<model.arch>/PraNetDG-fold<fold_id>-<epoch>.pth`` and
    evaluates it on ``<dataset.data_path>/fold_<fold_id>/{images,masks}``.
    Scores are logged as "iou precision recall dice":

    * ver1 - per-image metrics on the rounded prediction, summed and divided
      by the number of loader batches;
    * ver2 - metrics derived from TP/FP/FN pixel counts pooled over the split.

    Returns:
        tuple[list, list]: ground-truth masks and rounded predictions, one
        entry per evaluated image, accumulated over all folds.
    """
    parser = ArgumentParser()
    # With required=True the default could never be used; keep it optional.
    parser.add_argument(
        "-c", "--config", required=False, default="configs/default_config.yaml"
    )
    args = parser.parse_args()

    logger.add(f"logs/{str(datetime.now())}_test_log_file.log", rotation="10 MB")

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)

    gts = []
    prs = []

    folds = config["test"]["folds"]
    test_cfg = config["test"]
    model_prams = config["model"]
    arch = model_prams["arch"]
    data_path = config["dataset"]["data_path"]
    visualize = test_cfg["visualize"]
    overwrite = test_cfg["vis_overwrite"]
    vis_x = test_cfg["vis_x"]
    visualize_dir = "results"

    # MODEL: the architecture is the same for every fold, so build it once.
    logger.info("Loading model")
    import network.models as models

    model = models.__dict__[arch]()
    model.cuda()
    model.eval()

    for fold_id, epoch in folds.items():
        fold_dir = os.path.join(data_path, f"fold_{fold_id}")
        test_img_paths = sorted(glob(os.path.join(fold_dir, "images", "*")))
        test_mask_paths = sorted(glob(os.path.join(fold_dir, "masks", "*")))
        test_loader = get_loader(
            test_img_paths,
            test_mask_paths,
            transform=Augmenter(**test_cfg["augment"]),
            **test_cfg["dataloader"],
        )
        num_batches = len(test_loader)
        if num_batches == 0:
            # The averages below would divide by zero.
            logger.warning(f"No test images in {fold_dir}; skipping fold {fold_id}")
            continue

        # Score each fold with the snapshot trained for that same fold.
        model_path = os.path.join(
            model_prams["save_dir"], arch, f"PraNetDG-fold{fold_id}-{epoch}.pth"
        )
        checkpoint = torch.load(model_path)
        # Snapshots may wrap the weights under "model_state_dict". Plain state
        # dicts are loaded as-is; indexing them would raise KeyError.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            checkpoint = checkpoint["model_state_dict"]
        model.load_state_dict(checkpoint)

        tp_all = fp_all = fn_all = 0
        mean_precision = mean_recall = mean_iou = mean_dice = 0

        test_fold = f"fold{fold_id}"
        logger.info("Start testing")
        for image, gt, filename, img in tqdm.tqdm(test_loader):
            name, ext = os.path.splitext(filename[0])
            gt = np.asarray(gt[0][0], np.float32)

            # Inference only. The model returns four maps; only the last is scored.
            with torch.no_grad():
                _, _, _, res = model(image.cuda())
            # Resize to the mask size, squash, and min-max rescale to [0, 1];
            # the epsilon guards against constant maps.
            res = F.interpolate(res, size=gt.shape, mode="bilinear", align_corners=False)
            res = res.sigmoid().data.cpu().numpy().squeeze()
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)
            pr = res.round()

            if visualize:
                vis_root = os.path.join(visualize_dir, test_fold)
                save_img(
                    os.path.join(vis_root, str(arch), name + "_pr" + str(arch) + ext),
                    pr * 255,
                    "cv2",
                    overwrite,
                )
                save_img(
                    os.path.join(
                        vis_root, "soft_" + str(arch), name + "_soft_pr" + str(arch) + ext
                    ),
                    res * 255,
                    "cv2",
                    overwrite,
                )
                # Overlay: the prediction goes in channel 1 and the ground truth
                # in channel 0, each scaled by vis_x. Channels are then reversed
                # before writing with cv2.
                pr_zeros = np.zeros_like(pr)
                gt_zeros = np.zeros_like(gt)
                mask_img = (
                    np.asarray(img[0])
                    + vis_x * np.stack((pr_zeros, pr, pr_zeros), axis=-1)
                    + vis_x * np.stack((gt, gt_zeros, gt_zeros), axis=-1)
                )
                mask_img = mask_img[:, :, ::-1]
                save_img(
                    os.path.join(
                        vis_root, "mask_" + str(arch), name + "mask_pr" + str(arch) + ext
                    ),
                    mask_img,
                    "cv2",
                    overwrite,
                )

            prs.append(pr)
            gts.append(gt)

            # Pooled pixel counts for the "ver2" scores.
            tp = np.sum(gt * pr)
            tp_all += tp
            fp_all += np.sum(pr) - tp
            fn_all += np.sum(gt) - tp

            mean_precision += precision_m(gt, pr)
            mean_recall += recall_m(gt, pr)
            mean_iou += jaccard_m(gt, pr)
            mean_dice += dice_m(gt, pr)

        mean_precision /= num_batches
        mean_recall /= num_batches
        mean_iou /= num_batches
        mean_dice /= num_batches
        logger.info("scores ver1: {:.3f} {:.3f} {:.3f} {:.3f}".format(
            mean_iou, mean_precision, mean_recall, mean_dice))

        precision_all = tp_all / (tp_all + fp_all + K.epsilon())
        recall_all = tp_all / (tp_all + fn_all + K.epsilon())
        dice_all = 2 * precision_all * recall_all / (precision_all + recall_all)
        iou_all = recall_all * precision_all / (
            recall_all + precision_all - recall_all * precision_all)
        logger.info("scores ver2: {:.3f} {:.3f} {:.3f} {:.3f}".format(
            iou_all, precision_all, recall_all, dice_all))

    return gts, prs
Ejemplo n.º 12
0
def main():
    """Evaluate the pretrained PraNet checkpoint on the configured test sets.

    Reads a YAML config (``-c/--config``) and loads
    ``pretrained/PraNet-19.pth`` into ``config["model"]["arch"]``. Once per
    key of ``config["test"]["folds"]``, it then evaluates the model on every
    ``<test_data_path>/{images,masks}`` directory. Scores are logged as
    "iou precision recall dice":

    * ver0 - per-image metrics averaged over 256 evenly spaced thresholds in
      [0, 1), then averaged over loader batches;
    * ver1 - per-image metrics on the rounded prediction, averaged;
    * ver2 - metrics derived from TP/FP/FN pixel counts pooled over the split.

    When more than one fold key is configured, ``get_scores_v1`` and
    ``get_scores_v2`` are also run on everything collected.

    Returns:
        tuple[list, list]: ground-truth masks and rounded predictions, one
        entry per evaluated image, accumulated over all fold keys.
    """
    parser = ArgumentParser()
    # With required=True the default could never be used; keep it optional.
    parser.add_argument("-c",
                        "--config",
                        required=False,
                        default="configs/default_config.yaml")
    args = parser.parse_args()

    logger.info("Loading config")
    config_path = args.config
    config = load_cfg(config_path)

    gts = []
    prs = []

    folds = config["test"]["folds"]
    print(folds)
    test_cfg = config["test"]
    test_data_path = config["dataset"]["test_data_path"]
    # normpath drops a trailing separator, which would otherwise yield "".
    dataset = os.path.basename(os.path.normpath(test_data_path[0]))
    arch = config["model"]["arch"]
    if len(folds) == 1:
        log_file = (f"logs/test_{arch}_{str(datetime.now())}_"
                    f"{next(iter(folds))}_{dataset}.log")
    else:
        log_file = f"logs/test_{arch}_{str(datetime.now())}_kfold.log"
    logger.add(log_file, rotation="10 MB")

    visualize = test_cfg["visualize"]
    overwrite = test_cfg["vis_overwrite"]
    vis_x = test_cfg["vis_x"]
    visualize_dir = (test_cfg["visualize_dir"]
                     if "visualize_dir" in test_cfg else "results")
    # Thresholds for the "ver0" sweep: 256 steps of 1/256 over [0, 1).
    thresholds = np.arange(0, 1, 1 / 256)

    # MODEL: the same pretrained weights serve every fold key, so load once.
    logger.info("Loading model")
    import network.models as models

    model_path = "pretrained/PraNet-19.pth"
    model = models.__dict__[arch]()  # Pranet
    model.cuda()
    model.eval()
    logger.info(f"Loading from {model_path}")
    state_dict = torch.load(model_path)
    # Accept both plain state dicts and {"model_state_dict": ...} snapshots.
    if isinstance(state_dict, dict) and "model_state_dict" in state_dict:
        state_dict = state_dict["model_state_dict"]
    model.load_state_dict(state_dict)

    for fold_id in folds:
        # FOR ORIDATASET
        test_img_paths = []
        test_mask_paths = []
        for root in test_data_path:
            test_img_paths.extend(glob(os.path.join(root, "images", "*")))
            test_mask_paths.extend(glob(os.path.join(root, "masks", "*")))
        test_img_paths.sort()
        test_mask_paths.sort()

        test_loader = get_loader(
            test_img_paths,
            test_mask_paths,
            transform=Augmenter(**test_cfg["augment"]),
            **test_cfg["dataloader"],
            type="test",
        )
        num_batches = len(test_loader)
        if num_batches == 0:
            # The averages below would divide by zero.
            logger.warning(
                f"No test images in {dataset} dataset; skipping fold{fold_id}")
            continue

        tp_all = fp_all = fn_all = 0
        mean_precision = mean_recall = mean_iou = mean_dice = 0
        mean_F2 = 0  # accumulated for reference; not currently logged
        mean_precision_np = mean_recall_np = mean_iou_np = mean_dice_np = 0

        test_fold = f"fold{fold_id}"
        logger.info(f"Start testing fold{fold_id} with {model_path}")
        logger.info(f"Start testing {num_batches} images in {dataset} dataset")

        for image, gt, filename, img in tqdm.tqdm(test_loader):
            name, ext = os.path.splitext(filename[0])
            gt = np.asarray(gt[0][0], np.float32)
            # Rescale the mask to [0, 1] whatever its stored maximum.
            gt /= gt.max() + 1e-8

            # Inference only. The model returns four maps; only the last is scored.
            with torch.no_grad():
                _, _, _, res = model(image.cuda())
            res = F.interpolate(res,
                                size=gt.shape,
                                mode="bilinear",
                                align_corners=False)
            res = res.sigmoid().data.cpu().numpy().squeeze()
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)
            pr = res.round()

            if visualize:
                vis_root = os.path.join(visualize_dir, test_fold)
                save_img(
                    os.path.join(vis_root, str(arch),
                                 name + "_pr" + str(arch) + ext),
                    pr * 255,
                    "cv2",
                    overwrite,
                )
                save_img(
                    os.path.join(vis_root, "soft_" + str(arch),
                                 name + "_soft_pr" + str(arch) + ext),
                    res * 255,
                    "cv2",
                    overwrite,
                )
                # Overlay: the prediction goes in channel 1 and the ground truth
                # in channel 0, each scaled by vis_x. Channels are then reversed
                # before writing with cv2.
                pr_zeros = np.zeros_like(pr)
                gt_zeros = np.zeros_like(gt)
                mask_img = (np.asarray(img[0])
                            + vis_x * np.stack((pr_zeros, pr, pr_zeros), axis=-1)
                            + vis_x * np.stack((gt, gt_zeros, gt_zeros), axis=-1))
                mask_img = mask_img[:, :, ::-1]
                save_img(
                    os.path.join(vis_root, "mask_" + str(arch),
                                 name + "mask_pr" + str(arch) + ext),
                    mask_img,
                    "cv2",
                    overwrite,
                )

            prs.append(pr)
            gts.append(gt)

            # Pooled pixel counts for the "ver2" scores.
            tp = np.sum(gt * pr)
            tp_all += tp
            fp_all += np.sum(pr) - tp
            fn_all += np.sum(gt) - tp

            precision = precision_m(gt, pr)
            recall = recall_m(gt, pr)
            mean_precision += precision
            mean_recall += recall
            mean_iou += jaccard_m(gt, pr)
            mean_dice += dice_m(gt, pr)
            # F2 = 5PR / (4P + R); guard the 0/0 case (no true positives).
            f2_denominator = 4 * precision + recall
            mean_F2 += (5 * precision * recall / f2_denominator
                        if f2_denominator else 0)

            # "ver0": average each metric over the binarisations res >= t.
            thresh_precision = thresh_recall = thresh_iou = thresh_dice = 0
            for thresh in thresholds:
                out = (res >= thresh).astype(res.dtype)
                thresh_precision += precision_m(gt, out)
                thresh_recall += recall_m(gt, out)
                thresh_iou += jaccard_m(gt, out)
                thresh_dice += dice_m(gt, out)
            mean_precision_np += thresh_precision / len(thresholds)
            mean_recall_np += thresh_recall / len(thresholds)
            mean_iou_np += thresh_iou / len(thresholds)
            mean_dice_np += thresh_dice / len(thresholds)

        mean_precision_np /= num_batches
        mean_recall_np /= num_batches
        mean_iou_np /= num_batches
        mean_dice_np /= num_batches
        logger.info("scores ver0: {:.3f} {:.3f} {:.3f} {:.3f}".format(
            mean_iou_np, mean_precision_np, mean_recall_np, mean_dice_np))

        mean_precision /= num_batches
        mean_recall /= num_batches
        mean_iou /= num_batches
        mean_dice /= num_batches
        mean_F2 /= num_batches
        logger.info("scores ver1: {:.3f} {:.3f} {:.3f} {:.3f}".format(
            mean_iou, mean_precision, mean_recall, mean_dice))

        precision_all = tp_all / (tp_all + fp_all + 1e-07)
        recall_all = tp_all / (tp_all + fn_all + 1e-07)
        dice_all = 2 * precision_all * recall_all / (precision_all +
                                                     recall_all)
        iou_all = (recall_all * precision_all /
                   (recall_all + precision_all - recall_all * precision_all))
        logger.info("scores ver2: {:.3f} {:.3f} {:.3f} {:.3f}".format(
            iou_all, precision_all, recall_all, dice_all))

    from utils.metrics import get_scores_v1, get_scores_v2

    if len(folds) > 1:
        get_scores_v1(gts, prs, logger)
        get_scores_v2(gts, prs, logger)

    return gts, prs