Exemple #1
0
def do_glas():
    """
    Create the folds of the GlaS dataset, then collect image-size statistics
    over the train subset of split 0, fold 0.

    The validation part is 20% of the available train samples, which gives
    ceil(100 / 20) = 5 folds. The k-folds are performed once (`nbr_splits`).

    :return: None.
    """
    # Reproducibility: seed before anything else runs.
    reproducibility.init_seed()

    announce_msg("Processing dataset: {}".format(constants.GLAS))

    root = get_rootpath_2_dataset(Dict2Obj({'dataset': constants.GLAS}))
    args = dict(
        baseurl=root,
        # 20% for validation, the remaining 80% for train.
        folding=dict(vl=20),
        dataset="glas",
        fold_folder="folds/glas",
        img_extension="bmp",
        # how many times the k-folds are performed over the train samples.
        nbr_splits=1,
    )
    # number of folds implied by the validation percentage.
    args["nbr_folds"] = math.ceil(100. / args["folding"]["vl"])

    # seed again right before splitting.
    reproducibility.init_seed()
    al_split_glas(Dict2Obj(args))
    get_stats(Dict2Obj(args), split=0, fold=0, subset='train')
Exemple #2
0
def do_Oxford_flowers_102():
    """
    Create the splits of Oxford-flowers-102, then collect image-size
    statistics over the train subset of split 0, fold 0.

    The train/valid/test sets are already provided with the dataset.

    :return: None.
    """
    # Reproducibility: seed before anything else runs.
    reproducibility.init_seed()

    announce_msg("Processing dataset: {}".format(constants.OXF))

    root = get_rootpath_2_dataset(Dict2Obj({'dataset': constants.OXF}))
    args = dict(
        baseurl=root,
        dataset="Oxford-flowers-102",
        fold_folder="folds/Oxford-flowers-102",
        img_extension="jpg",
        path_encoding="folds/Oxford-flowers-102/encoding-origine.yaml",
    )

    # The conversion of the masks into binary masks has already been done,
    # so create_bin_mask_Oxford_flowers_102(...) is not called here.

    # seed again right before splitting.
    reproducibility.init_seed()
    al_split_Oxford_flowers_102(Dict2Obj(args))
    get_stats(Dict2Obj(args), split=0, fold=0, subset='train')
Exemple #3
0
def do_camelyon16():
    """
    Create the splits of camelyon16.

    The train/valid/test sets are already provided with the dataset.

    :return: None.
    """
    # Reproducibility: seed before anything else runs.
    reproducibility.init_seed()

    name = constants.CAM16
    announce_msg("Processing dataset: {}".format(name))

    root = get_rootpath_2_dataset(Dict2Obj({'dataset': name}))
    folds_dir = "folds/{}".format(name)
    encoding_file = "folds/{}/encoding-origine.yaml".format(name)
    args = dict(
        baseurl=root,
        dataset=name,
        fold_folder=folds_dir,
        img_extension="jpg",
        path_encoding=encoding_file,
    )

    # seed again right before splitting.
    reproducibility.init_seed()
    al_split_camelyon16(Dict2Obj(args))
def test__LossExtendedLB():
    """
    Smoke test of `_LossExtendedLB`: one forward pass over random inputs,
    followed by 10 updates of its `t` parameter.

    Runs on GPU 1 when CUDA is available, otherwise on the CPU.

    :return: None.
    """
    force_seed(0, check_cudnn=False)
    loss_fn = _LossExtendedLB(init_t=1., max_t=10., mulcoef=1.01)
    announce_msg("Testing {}".format(loss_fn))

    gpu_id = 1
    has_cuda = torch.cuda.is_available()
    device_name = "cuda:{}".format(gpu_id) if has_cuda else "cpu"
    device = torch.device(device_name)
    if has_cuda:
        torch.cuda.set_device(int(gpu_id))
    loss_fn.to(device)

    batch_size = 16
    inputs = torch.rand(batch_size).to(device)

    # the loss is computed once, before `t` is updated.
    loss = loss_fn(inputs)
    for epoch in range(10):
        loss_fn.update_t()
        print("epoch {}. t: {}.".format(epoch, loss_fn.t_lb))
    print("Loss ELB.sum(): {}".format(loss))
Exemple #5
0
def get_stats(args, split, fold, subset):
    """
    Get some stats on the image sizes of a specific dataset, split, fold.

    The samples are read from the csv file
    `<fold_folder>/split_<split>/fold_<fold>/<subset>_s_<split>_f_<fold>.csv`.
    Two files are written in `args.fold_folder`:
      - `log-stats-ds-<tag>.txt`: min/max of the heights and of the widths.
      - `size-stats-<tag>.png`: histograms of the heights and of the widths.

    :param args: object with the attributes `fold_folder` and `dataset`. It
        is also forwarded to `get_rootpath_2_dataset`.
    :param split: split index. used to build the paths and the file tag.
    :param fold: fold index. used to build the paths and the file tag.
    :param subset: name of the subset. it is the prefix of the csv file
        (e.g. 'train').
    :return: None.
    :raises ValueError: if the csv file lists no sample.
    """
    os.makedirs(args.fold_folder, exist_ok=True)

    tag = "ds-{}-s-{}-f-{}-subset-{}".format(args.dataset, split, fold, subset)
    log_path = join(args.fold_folder, "log-stats-ds-{}.txt".format(tag))

    # `with` guarantees the log is closed even if a sample fails to load.
    with open(log_path, 'w') as log:
        announce_msg("Going to check {}".format(args.dataset.upper()))

        relative_fold_path = join(args.fold_folder, "split_{}".format(split),
                                  "fold_{}".format(fold))

        subset_csv = join(relative_fold_path,
                          "{}_s_{}_f_{}.csv".format(subset, split, fold))
        rootpath = get_rootpath_2_dataset(args)
        samples = csv_loader(subset_csv, rootpath)

        lh, lw = [], []
        for el in samples:
            # el[1] is what is handed to PIL; presumably the image path.
            # The size is available without decoding the pixels, and the
            # context manager releases the file handle right away.
            with Image.open(el[1], 'r') as img:
                w, h = img.size
            lh.append(h)
            lw.append(w)

        if not lh:
            # min()/max() would fail with an opaque message otherwise.
            raise ValueError("No samples found in {}".format(subset_csv))

        msg = "min h {}, \t max h {}".format(min(lh), max(lh))
        show_msg(msg, log)
        msg = "min w {}, \t max w {}".format(min(lw), max(lw))
        show_msg(msg, log)

    fig, axes = plt.subplots(nrows=1, ncols=2)
    try:
        axes[0].hist(lh)
        axes[0].set_title('Heights')
        axes[1].hist(lw)
        axes[1].set_title('Widths')
        fig.tight_layout()
        fig.savefig(join(args.fold_folder, "size-stats-{}.png".format(tag)))
    finally:
        # figures stay registered in pyplot until closed explicitly.
        plt.close(fig)
def pair_samples(args,
                 train_samples,
                 tr_leftovers,
                 SIMS,
                 previous_pairs,
                 fd_p_msks
                 ):
    """
    Pair samples.

    Pairing is done only when `args.al_type` is `constants.AL_LP`; otherwise
    the returned pairs are empty. Pairs that already exist in
    `previous_pairs` are dropped, so the result holds only newly paired
    samples or samples whose source has changed.

    :param args: object with the attributes `al_type`, `task` and `knn`.
    :param train_samples: forwarded to the pair maker. a deep copy of it,
        taken before pairing, is returned as well.
    :param tr_leftovers: forwarded to the pair maker.
    :param SIMS: forwarded to the pair maker.
    :param previous_pairs: container tested for membership with tuples
        `(key, value)` of the new pairs (presumably the pairs of the
        previous round -- confirm against the caller).
    :param fd_p_msks: not used by this function. kept so that existing
        callers keep working.
    :return: tuple (pairs, acc_new_samples, nbrx,
        train_samples_before_merging). `acc_new_samples` is always 0. and
        `nbrx` is always 0 here.
    """
    acc_new_samples = 0.
    nbrx = 0
    train_samples_before_merging = deepcopy(train_samples)
    pairs = dict()

    if args.al_type == constants.AL_LP:
        set_default_seed()
        pairmaker = PairSamples(task=args.task,
                                knn=args.knn
                                )
        set_default_seed()
        announce_msg("starts pairing...")
        pairs = pairmaker(train_samples, tr_leftovers, SIMS)
        announce_msg("finishes pairing...")
        set_default_seed()

        # if pair exists in the previous round, delete it.
        # pairs must contain only newly paired samples or samples that have
        # been paired but changed their source.
        # Iterate over a snapshot of the keys: removing entries while
        # iterating over the live view raises RuntimeError.
        for k in list(pairs.keys()):
            if (k, pairs[k]) in previous_pairs:  # same pair exists.
                pairs.pop(k)

    return pairs, acc_new_samples, nbrx, train_samples_before_merging
Exemple #7
0
def do_Caltech_UCSD_Birds_200_2011():
    """
    Create the folds of Caltech-UCSD-Birds-200-2011, then collect image-size
    statistics over the train subset of split 0, fold 0.

    The validation part is 20% of the available train samples, which gives
    ceil(100 / 20) = 5 folds. The k-folds are performed once (`nbr_splits`).

    :return: None.
    """
    # Reproducibility: seed before anything else runs.
    reproducibility.init_seed()

    announce_msg("Processing dataset: {}".format(constants.CUB))

    root = get_rootpath_2_dataset(Dict2Obj({'dataset': constants.CUB}))
    args = dict(
        baseurl=root,
        # 20% for validation, the remaining 80% for train.
        folding=dict(vl=20),
        dataset="Caltech-UCSD-Birds-200-2011",
        fold_folder="folds/Caltech-UCSD-Birds-200-2011",
        img_extension="bmp",
        # how many times the k-folds are performed over the train samples.
        nbr_splits=1,
        path_encoding=(
            "folds/Caltech-UCSD-Birds-200-2011/encoding-origine.yaml"),
        # None: use the entire dataset. Set it to a number to keep only
        # that many random classes.
        nbr_classes=None,
    )
    # number of folds implied by the validation percentage.
    args["nbr_folds"] = math.ceil(100. / args["folding"]["vl"])

    # seed again right before splitting.
    reproducibility.init_seed()
    al_split_Caltech_UCSD_Birds_200_2011(Dict2Obj(args))
    get_stats(Dict2Obj(args), split=0, fold=0, subset='train')
Exemple #8
0
def show_msg(ms, lg):
    """
    Announce a message and append it, followed by a newline, to a log.

    :param ms: str. the message.
    :param lg: object with a `write` method (e.g. a file opened for
        writing).
    :return: None.
    """
    announce_msg(ms)
    line = ms + "\n"
    lg.write(line)
Exemple #9
0
def test_HybridModel():
    """
    Test `HybridModel`.

    For every combination of backbone dropout, pretraining flag, classifier
    freezing flag and backbone in `constants.backbones`, this builds the
    model through `hybrid_model`, reports its number of parameters, runs
    forward passes over a random batch and checks that the predicted masks
    have the same height and width as the input.

    Runs on GPU 6 when CUDA is available, otherwise on the CPU.

    :return: None.
    :raises AssertionError: if the spatial size of the masks differs from
        the spatial size of the input.
    """
    cuda = "6"
    print("cuda:{}".format(cuda))
    has_cuda = torch.cuda.is_available()

    def _current_device():
        # torch.cuda.current_device() fails on hosts without CUDA, while
        # this test is meant to fall back to the CPU.
        return torch.cuda.current_device() if has_cuda else "cpu"

    print("DEVICE BEFORE: ", _current_device())
    DEVICE = torch.device("cuda:{}".format(cuda) if has_cuda else "cpu")
    # The current CUDA device is not changed here: the call
    # torch.cuda.set_device(int(cuda)) is disabled.
    print("DEVICE AFTER: ", _current_device())

    num_classes = 10
    num_masks = 1
    batch = 3
    x = torch.randn(batch, 3, 256, 253)
    x = x.to(DEVICE)
    for backbone_dropout in [0., 0.1]:
        for pretrained in [False, True]:
            for freeze_cl in [False, True]:
                for backbone in constants.backbones:
                    model = hybrid_model(num_classes,
                                         num_masks=num_masks,
                                         backbone=backbone,
                                         pretrained=pretrained,
                                         modalities=4,
                                         kmax=0.5,
                                         kmin=None,
                                         alpha=0.6,
                                         dropout=0.0,
                                         backbone_dropout=backbone_dropout,
                                         freeze_classifier=freeze_cl)
                    model.eval()
                    nbr_params_total = model.get_nbr_params()
                    cl_params = model.classifier.get_nbr_params()
                    announce_msg("{} (pret: {}). "
                                 "Total-params: {}. "
                                 "CL-params: {} ({:.2f}%) ".format(
                                     backbone,
                                     pretrained,
                                     nbr_params_total,
                                     cl_params,
                                     100. * cl_params /
                                     float(nbr_params_total),
                                 ))
                    model.to(DEVICE)
                    t0 = dt.datetime.now()
                    scores, masks, maps = model(x)
                    print("Time forward {} of {} samples".format(
                        dt.datetime.now() - t0, batch))
                    print("in: ", x.shape, "scores: ", scores.shape, "masks: ",
                          masks.shape, "maps: ", maps.shape)
                    print("Model: {}".format(model))
                    print("Min-max output-masks: {}, {}".format(
                        masks.min(), masks.max()))
                    # report the shapes, not a slice of the masks tensor.
                    msg = "h, w mismatch: x {}  masks {}".format(
                        x.shape[2:], masks.shape[2:])
                    assert x.shape[2:] == masks.shape[2:], msg

                    # with dropout switched back to train mode, two
                    # forward passes are printed so they can be compared.
                    if backbone_dropout != 0.:
                        model.set_dropout_to_train_mode()
                    print("backbone dropout: {}".format(backbone_dropout))
                    scores, masks, maps = model(x)
                    print("Scores first forward:")
                    print(scores)
                    scores, masks, maps = model(x)
                    print("Scores second forward:")
                    print(scores)

                    # test freezing the classifier
                    if freeze_cl:
                        print("going to freeze the classifier:")
                        model.freeze_cl()
                        scores, masks, maps = model(x)
                        print("Scores after freezing the classifier:")
                        print(scores)
def validate(model,
             dataset,
             dataloader,
             criterion,
             device,
             stats,
             args,
             epoch=0,
             cycle=0,
             log_file=None
             ):
    """
    Perform a validation over a set.
    Dataset passed here must be in `validation` mode.
    Task: SEG.

    The per-batch metrics and losses are accumulated (they are requested
    with `avg=False`), then divided by the number of samples at the end.

    :param model: called as `model(x=data, seed=None)`. must return
        (scores, masks_pred, maps).
    :param dataset: must provide `turnback_mask_tensor_into_original_size`.
        used when predicted and target masks differ in shape.
    :param dataloader: yields (ids, data, mask, label, tag, crop_cord).
    :param criterion: the loss. its outputs are tracked, in order, as
        "total_loss", "cl_loss", "seg_l_loss", "seg_lp_loss".
    :param device: device on which the computation takes place.
    :param stats: dict of lists (one list per tracked key) to which the new
        values are appended, or None to create a new one.
    :param args: object with the attributes `seg_threshold`, `dataset` and
        `al_type`.
    :param epoch: int. used only in the log message.
    :param cycle: int. used only in the log message.
    :param log_file: if set, the summary line is also written through `log`.
    :return: dict of lists: `stats` updated in place, or a new dict when
        `stats` is None. keys: "acc", "dice_idx" and the four losses.
    """
    set_default_seed()

    model.eval()
    criterion.eval()
    metrics = Metrics(threshold=args.seg_threshold).to(device)
    metrics.eval()

    length = len(dataloader)
    # to track stats.
    keys = ["acc", "dice_idx"]
    losses_kz = ["total_loss", "cl_loss", "seg_l_loss", "seg_lp_loss"]
    keys = keys + losses_kz
    tracker = dict()
    for k in keys:
        tracker[k] = 0.

    n_samples = 0.
    n_sam_dice = 0.

    # ignoring samples is on only: 1. cam16 dataset. 2. wsl method.
    # for the rest of the methods, normal samples are completely dropped before
    # starting training.
    cnd_dice = (args.dataset == constants.CAM16)
    cnd_dice &= (args.al_type == constants.AL_WSL)

    t0 = dt.datetime.now()

    with torch.no_grad():
        for i, (ids, data, mask, label, tag, crop_cord) in tqdm.tqdm(
                enumerate(dataloader), ncols=80, total=length):

            # same seed before every batch.
            reset_seed(int(os.environ["MYSEED"]))

            targets = None  # not needed for SEG task.
            masks_trg = None
            bsz = data.size()[0]

            data = data.to(device)
            labels = label.to(device)
            # NOTE(review): the code below reads `masks_trg.shape`, so a
            # batch without mask is not actually supported.
            if mask is not None:
                masks_trg = mask.to(device)

            scores, masks_pred, maps = model(x=data, seed=None)

            check_nans(maps, "fresh-maps")

            # change to the predicted mask to be the maps if WSL: a pixel is
            # on when the argmax over dim 1 of the maps equals the label.
            if args.al_type == constants.AL_WSL:
                lb_ = labels.view(-1, 1, 1, 1)
                masks_pred = (maps.argmax(dim=1, keepdim=True) == lb_).float()

            # resize pred mask if sizes mismatch.
            if masks_pred.shape != masks_trg.shape:
                _, _, oh, ow = masks_trg.shape
                masks_pred = dataset.turnback_mask_tensor_into_original_size(
                    pred_masks=masks_pred, oh=oh, ow=ow
                )
                check_nans(masks_pred, "mask-pred-back-to-normal-size")

            if args.al_type != constants.AL_WSL:
                # TODO: change normalization location. (future)
                masks_pred = torch.sigmoid(masks_pred)  # binary segmentation

            losses = criterion(scores=scores,
                               labels=labels,
                               targets=targets,
                               masks_pred=masks_pred.view(bsz, -1),
                               masks_trg=masks_trg.view(bsz, -1),
                               tags=tag,
                               weights=None,
                               avg=False
                               )

            # metrics
            ignore_dice = None
            if cnd_dice:
                # ignore samples with label 'normal' (0) when computing dice.
                ignore_dice = (labels == 0).float().view(-1)

            metrx = metrics(scores=scores,
                            labels=labels,
                            tr_loss=criterion,
                            masks_pred=masks_pred.view(bsz, -1),
                            masks_trg=masks_trg.view(bsz, -1),
                            avg=False,
                            ignore_dice=ignore_dice
                            )

            tracker["acc"] += metrx[0].item()
            tracker["dice_idx"] += metrx[1].item()

            for j, los in enumerate(losses):
                tracker[losses_kz[j]] += los.item()

            n_samples += bsz  # nbr samples.

            if cnd_dice:
                # .item() keeps the counter a python float; otherwise it
                # becomes a tensor and is logged as a tensor repr.
                n_sam_dice += ignore_dice.numel() - ignore_dice.sum().item()
            else:
                n_sam_dice += bsz

            # clear the memory
            del scores
            del masks_pred
            del maps

    # average.
    for k in keys:
        if (k == 'dice_idx') and cnd_dice:
            announce_msg("Ignore normal samples mode. "
                         "Total samples left: {}/Total: {}.".format(
                n_sam_dice, n_samples))
            if n_sam_dice == 0:
                tracker[k] = 0.
            else:
                tracker[k] /= float(n_sam_dice)
        else:
            tracker[k] /= float(n_samples)

    to_write = "VL {:>2d}-{:>2d}: ACC: {:.4f}, DICE: {:.4f}, time:{}".format(
                cycle, epoch, tracker["acc"] * 100., tracker["dice_idx"] * 100.,
                dt.datetime.now() - t0
                 )
    print(to_write)
    if log_file:
        log(log_file, to_write)

    # Update stats.
    if stats is not None:
        for k in keys:
            stats[k].append(tracker[k])
    else:
        # convert each element into a list
        for k in keys:
            tracker[k] = [tracker[k]]
        stats = deepcopy(tracker)

    set_default_seed()

    return stats
sys.path.append("..")
sys.path.append("../..")

from shared import check_if_allow_multgpu_mode, announce_msg
from deeplearning.utils import initialize_weights

# Whether synchronized batch-norm is requested. Off by default.
ACTIVATE_SYNC_BN = False
# The default can be overridden with an environment variable in Bash:
# $ export ACTIVATE_SYNC_BN="True"   ----> Activate
# $ export ACTIVATE_SYNC_BN="False"   ----> Deactivate
# Only the exact string "True" activates it; any other value deactivates it.

if "ACTIVATE_SYNC_BN" in os.environ.keys():
    ACTIVATE_SYNC_BN = (os.environ['ACTIVATE_SYNC_BN'] == "True")

announce_msg("ACTIVATE_SYNC_BN was set to {}".format(ACTIVATE_SYNC_BN))

if check_if_allow_multgpu_mode() and ACTIVATE_SYNC_BN:  # Activate Synch-BN.
    from deeplearning.syncbn import nn as NN_Sync_BN
    BatchNorm2d = NN_Sync_BN.BatchNorm2d
    announce_msg("Synchronized BN has been activated. \n"
                 "MultiGPU mode has been activated. "
                 "{} GPUs".format(torch.cuda.device_count()))
else:
    BatchNorm2d = nn.BatchNorm2d
    if check_if_allow_multgpu_mode():
        announce_msg("Synchronized BN has been deactivated.\n"
                     "MultiGPU mode has been activated. "
                     "{} GPUs".format(torch.cuda.device_count()))
    else:
        announce_msg("Synchronized BN has been deactivated.\n"
def estimate_best_seg_thres(model,
                            dataset,
                            dataloader,
                            criterion,
                            device,
                            args,
                            epoch=0,
                            cycle=0,
                            log_file=None
                            ):
    """
    Perform a validation over a set.
    Allows estimating a segmentation threshold based on this evaluation.

    This is intended for the `validationset` where we can estimate the best
    threshold since the samples are pixel-wise labeled.

    We specify a set of thresholds (from 0.05 up to, but excluding, 1. with
    a step of 0.01), and pick the one that has the best mean IOU score.
    MIOU is better than Dice index.

    Dataset passed here must be in `validation` mode.
    Task: SEG.

    :param model: called as `model(x=data, seed=None)`. must return
        (scores, masks_pred, maps).
    :param dataset: must provide `turnback_mask_tensor_into_original_size`.
        used when predicted and target masks differ in shape.
    :param dataloader: yields (ids, data, mask, label, tag, crop_cord).
    :param criterion: forwarded to the metrics as `tr_loss`.
    :param device: device on which the computation takes place.
    :param args: object with the attributes `al_type`, `seg_threshold` and
        `dataset`. `al_type` must not be `constants.AL_WSL`.
    :param epoch: int. used only in the log message.
    :param cycle: int. used only in the log message.
    :param log_file: if set, the messages are also written through `log`.
    :return: dict with the keys "best_threshold", "best_dice", "best_miou".
    """
    msg = "Can't use/estimate threshold over AL_WSL. masks are already binary."
    assert args.al_type != constants.AL_WSL, msg

    set_default_seed()

    announce_msg("Estimating threshold on validation set.")
    set_default_seed()

    model.eval()
    # the threshold specified here does not matter. different ones are
    # passed explicitly to each call of the metrics below.
    metrics = Metrics(threshold=args.seg_threshold).to(device)
    metrics.eval()

    length = len(dataloader)
    l_thress = np.arange(start=0.05, stop=1., step=0.01)
    # to track stats: one accumulator per threshold.
    avg_dice = np.zeros(l_thress.shape)
    avg_miou = np.zeros(l_thress.shape)

    nbr_ths = l_thress.size

    n_samples = 0.
    n_sam_dice = 0.
    # ignoring samples is on only: 1. cam16 dataset. 2. wsl method.
    # for the rest of the methods, normal samples are completely dropped before
    # starting training.
    # NOTE(review): the assert above rejects AL_WSL, so `cnd_dice` can be
    # True only when assertions are disabled -- confirm this is intended.
    cnd_dice = (args.dataset == constants.CAM16)
    cnd_dice &= (args.al_type == constants.AL_WSL)

    t0 = dt.datetime.now()

    with torch.no_grad():
        for i, (ids, data, mask, label, tag, crop_cord) in tqdm.tqdm(
                enumerate(dataloader), ncols=80, total=length):

            # same seed before every batch.
            reset_seed(int(os.environ["MYSEED"]))

            targets = None  # not needed for SEG task.
            masks_trg = None
            bsz = data.size()[0]
            # assert bsz == 1, "batch size must be 1. found {}. ".format(bsz)

            data = data.to(device)
            labels = label.to(device)
            # NOTE(review): the code below reads `masks_trg.shape`, so a
            # batch without mask is not actually supported.
            if mask is not None:
                masks_trg = mask.to(device)

            scores, masks_pred, maps = model(x=data, seed=None)

            # resize pred mask if sizes mismatch.
            if masks_pred.shape != masks_trg.shape:
                _, _, oh, ow = masks_trg.shape
                masks_pred = dataset.turnback_mask_tensor_into_original_size(
                    pred_masks=masks_pred, oh=oh, ow=ow
                )
            # TODO: change normalization location. (future)
            masks_pred = torch.sigmoid(masks_pred)

            ignore_dice = None
            if cnd_dice:
                # ignore samples with label 'normal' (0) when computing dice.
                ignore_dice = (labels == 0).float().view(-1)

            # evaluate the same predictions under every candidate threshold.
            for ii, tt in enumerate(l_thress):
                metrx = metrics(scores=scores,
                                labels=labels,
                                tr_loss=criterion,
                                masks_pred=masks_pred.view(bsz, -1),
                                masks_trg=masks_trg.view(bsz, -1),
                                avg=False,
                                threshold=tt,
                                ignore_dice=ignore_dice
                                )

                avg_dice[ii] += metrx[1].item()
                avg_miou[ii] += metrx[2].item()

            n_samples += bsz  # nbr samples.

            if cnd_dice:
                n_sam_dice += ignore_dice.numel() - ignore_dice.sum()
            else:
                n_sam_dice += bsz

            # clear the memory
            del scores
            del masks_pred

    # the accumulators share the same denominator, so the argmax can be
    # taken before averaging.
    idx = avg_miou.argmax()
    best_threshold = l_thress[idx]
    if cnd_dice and (n_sam_dice != 0):
        best_dice = avg_dice[idx] / float(n_sam_dice)
        best_miou = avg_miou[idx] / float(n_sam_dice)
    else:
        best_dice = avg_dice[idx] / float(n_samples)
        best_miou = avg_miou[idx] / float(n_samples)

    out = {
        "best_threshold": best_threshold,
        "best_dice": best_dice,
        "best_miou": best_miou
    }

    to_write = "VL [ESTIM-THRESH] {:>2d}-{:>2d}. BEST-DICE: {:.2f}, " \
               "BEST-MIOU: {:.2f}. " \
               "time:{}".format(cycle,
                                epoch,
                                best_dice * 100.,
                                best_miou * 100.,
                                dt.datetime.now() - t0
    )
    print(to_write)
    if log_file:
        log(log_file, to_write)

    to_write =  "best threshold: {:.3f}. " \
                 "BEST-DICE-VL: {:.2f}%, " \
                 "BEST-MIOU-VL: {:.2f}%.".format(
        best_threshold, best_dice * 100., best_miou * 100.
    )
    announce_msg(to_write)
    if log_file:
        log(log_file, to_write)

    set_default_seed()

    return out
    def __call__(self, data, args, device, outd, label=None):
        """
        Compute the pairwise proximity between all the samples in `data`,
        based on the histograms of the images, and store the results in
        `outd`.

        Steps:
          1. compute the normalized histograms of every batch and store
             them in temporary files `histj_<j><tag>.pkl` in `outd`.
          2. for every sample, compute its proximity to all the samples
             using `self.hist_prox` and save it in `outd/<id>.pt`.
          3. delete the temporary files, and write `shared-stats<tag>.pkl`
             (keys: 'idss', 'labelss', 'acc', 'z').

        :param data: list of str-path to samples (per the original
            documentation). it is handed to `PhotoDataset`.
        :param args: object containing the main file input arguments.
        :param device: device on where the computation will take place.
        :param outd: path where we write the similarities.
        :param label: str or None. used only in the case where self.task
            is SEG, to build the tag of the files.
        :return: None. (the early `return 0` is never reached.)
        """
        # NOTE(review): `already_done` is hard-coded to False, so the early
        # return below is dead code.
        already_done = False
        if already_done:  # if we already computed this sim. nothing to do.
            return 0

        histc = None
        epsilon = 1e-8  # added to the histograms before normalizing them.
        if args.use_dist_global_hist:
            histc = SoftHistogram(bins=args.nbr_bins_histc,
                                  min=args.min_histc,
                                  max=args.max_histc,
                                  sigma=args.sigma_histc).to(device)
        # dataloader
        # NOTE(review): `transform_tensor` is not used below.
        transform_tensor = get_transforms_tensor(args)
        set_default_seed()
        dataset = PhotoDataset(
            data,
            args.dataset,
            args.name_classes,
            transforms.Compose([transforms.ToTensor()]),
            set_for_eval=False,
            transform_img=None,
            resize=None,
            crop_size=None,
            padding_size=(None, None),
            padding_mode=args.padding_mode,
            up_scale_small_dim_to=None,
            do_not_save_samples=True,
            ratio_scale_patch=1.,
            for_eval_flag=True,
            scale_algo=args.scale_algo,
            resize_h_to=args.resize_h_to_opt_mask,
            resize_mask=args.resize_mask_opt_mask,  # not important.
            enhance_color=args.enhance_color,
            enhance_color_fact=args.enhance_color_fact)
        set_default_seed()
        data_loader = DataLoader(dataset,
                                 batch_size=args.pair_w_batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 pin_memory=True,
                                 collate_fn=default_collate,
                                 worker_init_fn=_init_fn)
        set_default_seed()

        # created lazily, once the number of channels is known.
        gaussian_smoother = None

        # loop! on GPU.
        nbr_samples = len(dataset)
        nbr_batches = len(data_loader)
        acc_label_prop = 0.
        z = 0.
        # project all data and store them on disc in batches.
        idss = []
        labelss = []
        list_projections = []
        tag = ""
        # for the task SEG, the tag is helpful to avoid mixing files.
        if self.task == constants.SEG:
            tag = "_{}_{}".format(self.task, label)

        for j, (ids, imgs, masks, labels, tags, _) in enumerate(data_loader):
            with torch.no_grad():
                imgs = imgs.to(device)

                # 2. compute the histograms for matching.
                if args.use_dist_global_hist:
                    nbrs, c, h, w = imgs.shape

                    if args.smooth_img:
                        if gaussian_smoother is None:
                            gaussian_smoother = GaussianSmoothing(
                                channels=c,
                                kernel_size=args.smooth_img_ksz,
                                sigma=args.smooth_img_sigma,
                                dim=2,
                                exact_conv=True,
                                padding_mode='reflect').to(device)
                        # smooth the image.
                        imgs = gaussian_smoother(imgs)

                    # one row per (sample, channel) pair.
                    re_imgs = imgs.view(nbrs * c, h * w)
                    hists_j = histc(re_imgs)  # nbrs * c, nbr_bins
                    # normalize to prob. dist
                    hists_j = hists_j + epsilon
                    hists_j = hists_j / hists_j.sum(dim=-1).unsqueeze(1)
                    hists_j = hists_j.view(nbrs, c, -1).cpu()

                    with open(join(outd, "histj_{}{}.pkl".format(j, tag)),
                              "wb") as fhist:
                        pkl.dump(hists_j, fhist, protocol=pkl.HIGHEST_PROTOCOL)

                # store some stuff.
                idss.extend(ids)
                labelss.extend(labels.numpy().tolist())
                # can't fit in memory
                # list_projections.append(copy.deepcopy(pr_j))

        # list_projections = [pr.to(device) for pr in list_projections]
        # compute sim.
        for i in tqdm.tqdm(range(nbr_samples), ncols=80, total=nbr_samples):
            id_sr, img_sr, mask_sr, label_sr, tag_sr, _ = dataset[i]
            # add the batch dim (and the channel dim for 2d images).
            if img_sr.ndim == 2:
                img_sr = img_sr.view(1, 1, img_sr.size()[0], img_sr.size()[1])
            elif img_sr.ndim == 3:
                img_sr = img_sr.view(1,
                                     img_sr.size()[0],
                                     img_sr.size()[1],
                                     img_sr.size()[2])
            else:
                raise ValueError('Unexpected dim: {}.'.format(img_sr.ndim))

            img_sr = img_sr.to(device)
            # histo
            histo_trg = None
            if args.use_dist_global_hist:
                if args.smooth_img:
                    img_sr = gaussian_smoother(img_sr)

                nbrs, c, h, w = img_sr.shape  # only one image.-> nbrs=1
                histo_trg = histc(img_sr.view(nbrs * c, h * w))  # c,
                # nbrbins.
                # normalize to prob. dist
                histo_trg = histo_trg + epsilon
                histo_trg = histo_trg / histo_trg.sum(dim=-1).unsqueeze(1)

            # NOTE(review): `dists` is never assigned anything but None in
            # this method, so the `dists is not None` branch below never
            # runs.
            dists = None
            histo_prox = None
            for j in range(nbr_batches):
                with torch.no_grad():

                    # 2. histo proximity =======================================
                    if args.use_dist_global_hist:
                        # the histograms of batch j are read back from disc
                        # for every sample i.
                        with open(join(outd, "histj_{}{}.pkl".format(j, tag)),
                                  "rb") as fhisto:
                            hists_j = pkl.load(fhisto).to(device)  # bsize,
                            # c, nbrbins.

                        bs_sr, c_sr, nbr_bn_sr = hists_j.shape
                        tmp = self.hist_prox(
                            trg_his=histo_trg.repeat(bs_sr, 1),
                            src_his=hists_j.view(bs_sr * c_sr, -1))  # =>
                        # bs_sr * sr_c.
                        tmp = tmp.view(bs_sr, c_sr)

                        # tmp = self.sim(x, pr_j.to(device))
                        if tmp.ndim == 0:  # case of grey images with batch
                            # size of 1.
                            tmp = tmp.view(1, 1)

                        if histo_prox is None:
                            histo_prox = tmp
                        else:
                            histo_prox = torch.cat((histo_prox, tmp), dim=0)

            proximity = None
            if dists is not None:
                dists = dists.squeeze()  # remove the 1 dim. it happens when
                # batch_size == 1.
                dists = dists.cpu()
                proximity = dists.view(-1, 1)

            if histo_prox is not None:
                histo_prox = histo_prox.cpu()
            # shapes: dists: n. histo_prox: n, c where c is the number of
            # plans in the images.

            if args.use_dist_global_hist:
                # proximity = [l2 dist, r, g, b] or [l2 dist, grey]
                if proximity is not None:
                    proximity = torch.cat((proximity, histo_prox), dim=1)
                else:
                    proximity = histo_prox

            # NOTE(review): when `args.use_dist_global_hist` is off,
            # `proximity` is still None here and the next line fails.
            z += proximity.sum(dim=0)
            # store sims.
            srt, idx = torch.sort(proximity.sum(dim=1).squeeze(),
                                  descending=False)

            msg = "ERROR: {}".format(proximity[idx[0]].sum())
            # floating point issue: 1.1920928955078125e-07.
            # assert proximity[idx[0]].sum() == 0., msg

            # NOTE(review): idx[1] requires at least two samples.
            label_pred = labelss[idx[1]]  # take the second because the first
            # is 0.
            # it is ok to overload the disc to avoid runtime cost.
            # NOTE(review): `stats` is only read to get 'nearest_id'; it is
            # not saved. only `proximity` is written to disc.
            stats = {
                'id_sr': id_sr,  # id source
                'label_sr': label_sr,  # label source
                'label_pred': label_pred,
                'nearest_id': idss[idx[1]],  # closest sample.
                'proximity': proximity,
                'index_sort': idx  # so we do not have to sort again. [ok]
            }
            # name of the file: id_idNearest. this allows to get the id of
            # the nearest sample without reading the file. this speeds up the
            # pairing by avoiding disc access.
            # NOTE(review): the file name used below contains only `id_sr`;
            # `id_nearest` is not used.
            id_nearest = stats['nearest_id']

            torch.save(proximity, join(outd, '{}.pt'.format(id_sr)))
            acc_label_prop += (label_sr == label_pred) * 1.

            if args.task == constants.SEG:
                msg = 'for weakly.sup.seg, all samples of the data provided' \
                      'to this function must have the same label. it does ' \
                      'not seem the case.W'
                assert label_sr == label_pred, msg

        # Cleaning: delete the temporary histogram files.
        # NOTE(review): path1 and path2 are identical.
        for j in range(nbr_batches):
            path1 = join(outd, "histj_{}{}.pkl".format(j, tag))
            path2 = join(outd, "histj_{}{}.pkl".format(j, tag))
            for path in [path1, path2]:
                if os.path.isfile(path):
                    os.remove(path)

        # store accuracy: the upper bound perf (when every sample is labeled
        # except one). this is useful for the classification task only.
        shared_stats = {
            'idss': idss,
            'labelss': labelss,
            'acc': 100. * acc_label_prop / nbr_samples,
            'z': z.cpu()
        }
        with open(join(outd, 'shared-stats{}.pkl'.format(tag)), 'wb') as fout:
            pkl.dump(shared_stats, fout, protocol=pkl.HIGHEST_PROTOCOL)

        if args.task == constants.SEG:
            msg = 'for weakly.sup.seg, accuracy is expected to be 100%. but' \
                  'found {}'.format(shared_stats['acc'])
            assert shared_stats['acc'] == 100., msg

        announce_msg('Upper bound classification accuracy: {}%'.format(
            shared_stats['acc']))
        announce_msg('Z: {}'.format(z))
Exemple #14
0
def get_dataset(dataset_name):
    """
    Build the combinations of hyper-parameters of the experiments for ONE
    dataset.

    The following hyper-parameters are set. Each one is stored as a list of
    candidate values:
    1. dataset: the name of the dataset.
    2. p_init_samples, p_samples, max_al_its: dataset-dependent values.
    3. task: constants.SEG for every supported dataset.
    4. split, fold: always 0.

    :param dataset_name: str, name of the dataset. Supported datasets:
           constants.GLAS, constants.CUB, constants.OXF, constants.CAM16.
    :return: the output of get_combos() called over the values and the
             names of the hyper-parameters above.
    :raises ValueError: if the dataset is not one of the supported datasets.
    """
    keys = dict()
    # constants.GLAS
    # constants.CUB
    # constants.OXF

    # NOTE: the insertion order of the keys is kept as is, since it fixes
    # the order of the lists handed to get_combos().
    keys["dataset"] = [dataset_name]

    announce_msg("Generate configs for {} dataset".format(keys["dataset"]))

    assert len(keys["dataset"]) == 1, "We work with only one dataset." \
                                      "....[NOT OK]"

    name = keys['dataset'][0]

    # In the following, 't' is the total number of train samples
    # (split 0, fold 0). The percentages are expressed with respect to 't'.
    if name == constants.GLAS:
        t = 67.  # 29 benign, and 38 malignant.
        keys['p_init_samples'] = [100 * 8 / t]  # 8 samples in total
        # --> 4 per class.

        keys['p_samples'] = [100 * 2 / t]  # this will add 2 samples in
        # total each round --> 1 sample per class.

        keys['max_al_its'] = [25]  # todo. compute it.

    if name == constants.CUB:
        t = 4794.
        keys['p_init_samples'] = [100 * 200 * 1 / t]  # 200 * 1 samples in
        # total --> 1 per class.

        keys['p_samples'] = [100 * 200 / t]  # this will add 200 samples in
        # total each round --> 1 sample per class.

        keys['max_al_its'] = [20]  # todo. compute it.

    if name == constants.OXF:
        t = 1020.
        keys['p_init_samples'] = [100 * 102 * 1 / t]  # 102 * 1 samples in
        # total --> 1 per class.

        keys['p_samples'] = [100 * 102 / t]  # this will add 102 samples in
        # total each round --> 1 sample per class.

        keys['max_al_its'] = [9]  # todo. compute it.

    if name == constants.CAM16:
        t = 12174.  # trainset: tumor.
        keys['p_init_samples'] = [100 * 30 * 1 / t]  # 30 samples in total.

        keys['p_samples'] = [100 * 1 / t]  # this will add 1 sample in
        # total each round.

        keys['max_al_its'] = [30]  # todo. compute it.

    if name in [
            constants.GLAS, constants.CUB, constants.OXF, constants.CAM16
    ]:  # segmentation
        keys['task'] = [constants.SEG]
    else:
        # the key is 'dataset' (not 'datasets'): using the wrong key raised
        # a KeyError that was hiding this ValueError.
        raise ValueError('Dataset {} with unknown task.'.format(
            keys['dataset']))

    # for everyone
    keys["split"] = [0]
    keys["fold"] = [0]

    llists = list(keys.values())
    namekeys = list(keys)

    return get_combos(llists, namekeys)
def final_validate(model,
                   dataloader,
                   criterion,
                   device,
                   dataset,
                   outd,
                   args,
                   log_file=None,
                   name_set="",
                   pseudo_labeled=False,
                   store_results=False,
                   apply_selection_tech=False
                   ):
    """
    Perform a final evaluation of a set.
    (images do not have the same size, so we can't stack them in one tensor).
    Validation samples may be large to fit all in the GPU at once.

    This is similar to validate() but it can operate on all types of
    datasets (train, valid, test). It selects only L. Also, it does some other
    operations related to drawing and final training tasks.

    Side effects: writes the pickled performance/predictions (and,
    if requested, the selection stats: entropy, mc-dropout variance) into
    'outd', then compresses and deletes the folder 'outd/prediction'.

    :param model: the model to evaluate. it is called as
           model(x=data, seed=None) and must return (scores, masks_pred,
           maps).
    :param dataloader: iterable over batches of the form
           (ids, data, mask, label, tag, crop_cord).
    :param criterion: the loss. called with avg=False. the losses it
           returns are accumulated in this order: total_loss, cl_loss,
           seg_l_loss, seg_lp_loss.
    :param device: device where the data and the metrics are moved.
    :param dataset: the dataset behind the dataloader. used to resize the
           predicted masks and to get the original images/labels to draw.
    :param outd: str, output directory of this dataset.
    :param args: object holding the configuration. the following attributes
           are read: height_tag, seg_threshold, dataset, al_type,
           clustering, task, subtask, mcdropout_t.
    :param log_file: a logfile or None. if not None, the summary is logged.
    :param name_set: str, name to indicate which set is being processed. e.g.:
           trainset, validset, testset.
    :param pseudo_labeled: bool. if true, the dataloader is loading the set
           of samples that has been pseudo-labeled. this is the case only for
           our method. if false, the samples are fully labeled. this allows to
          measure the performance over the pseudo-labeled samples separately.
    :param store_results: bool (default: False). if true, results (
           predictions) are stored. Useful for test set to store predictions.
           can be disabled for train/valid sets.
    :param apply_selection_tech: bool. if true, we compute stats that will be
           used for sample selection for the next AL round. useful on the
           leftovers.
    :return: None.
    """
    set_default_seed()

    outd_data = join(outd, "prediction")
    # exist_ok: no need to check the existence first (and no race between
    # the check and the creation).
    os.makedirs(outd_data, exist_ok=True)

    # TODO: rescale depending on the dataset.
    visualisor = VisualsePredSegmentation(height_tag=args.height_tag,
                                          show_tags=True,
                                          scale=0.5
                                          )

    model.eval()
    criterion.eval()
    metrics = Metrics(threshold=args.seg_threshold).to(device)
    metrics.eval()

    length = len(dataloader)
    # to track stats.
    keys_perf = ["acc", "dice_idx"]
    losses_kz = ["total_loss", "cl_loss", "seg_l_loss", "seg_lp_loss"]
    per_sample_kz = ["pred_labels", "pred_masks", "ids"]
    keys = keys_perf + losses_kz + per_sample_kz
    tracker = dict()
    for k in keys:
        # scalars are accumulated; per-sample entries are collected in lists.
        if k not in per_sample_kz:
            tracker[k] = 0.
        else:
            tracker[k] = []

    n_samples = 0.
    n_sam_dice = 0.
    # ignoring samples is on only: 1. cam16 dataset. 2. wsl method.
    # for the rest of the methods, normal samples are completely dropped before
    # starting training.
    cnd_dice = (args.dataset == constants.CAM16)
    cnd_dice &= (args.al_type == constants.AL_WSL)

    # Selection criteria.
    cp_entropy = Entropy()  # to compute entropy.
    entropy = []
    mcdropout_var = []

    t0 = dt.datetime.now()

    # conditioning
    entropy_cnd = (args.al_type == constants.AL_ENTROPY)
    entropy_cnd = entropy_cnd or ((args.al_type == constants.AL_LP) and (
            args.clustering == constants.CLUSTER_ENTROPY))
    mcdropout_cnd = (args.al_type == constants.AL_MCDROPOUT)

    # cnd_asrt = (args.task == constants.SEG)
    # a batch size of 1 is required whenever per-sample results are stored or
    # per-sample selection stats are computed.
    cnd_asrt = store_results
    cnd_asrt |= apply_selection_tech

    with torch.no_grad():
        for i, (ids, data, mask, label, tag, crop_cord) in tqdm.tqdm(
                enumerate(dataloader), ncols=80, total=length):

            reset_seed(int(os.environ["MYSEED"]))

            targets = None
            bsz = data.size()[0]

            if cnd_asrt:
                msg = "batchsize must be 1 for segmentation task. " \
                      "found {}.".format(bsz)
                assert bsz == 1, msg

            data = data.to(device)
            labels = label.to(device)
            masks_trg = mask.to(device)

            scores, masks_pred, maps = model(x=data, seed=None)

            # change to the predicted mask to be the maps if WSL.
            if args.al_type == constants.AL_WSL:
                lb_ = labels.view(-1, 1, 1, 1)
                masks_pred = (maps.argmax(dim=1, keepdim=True) == lb_).float()

            # resize pred mask if sizes mismatch.
            if masks_pred.shape != masks_trg.shape:
                _, _, oh, ow = masks_trg.shape
                masks_pred = dataset.turnback_mask_tensor_into_original_size(
                    pred_masks=masks_pred,
                    oh=oh,
                    ow=ow
                )

            if args.al_type != constants.AL_WSL:
                # TODO: change normalization location. (future)
                masks_pred = torch.sigmoid(masks_pred)  # binary segmentation

            losses_ = criterion(scores=scores,
                                labels=labels,
                                targets=targets,
                                masks_pred=masks_pred.view(bsz, -1),
                                masks_trg=masks_trg.view(bsz, -1),
                                tags=tag,
                                weights=None,
                                avg=False
                                )

            # metrics
            ignore_dice = None
            if cnd_dice:
                # ignore samples with label 'normal' (0) when computing dice.
                ignore_dice = (labels == 0).float().view(-1)

            metrx = metrics(scores=scores,
                            labels=labels,
                            tr_loss=criterion,
                            masks_pred=masks_pred.view(bsz, -1),
                            masks_trg=masks_trg.view(bsz, -1),
                            avg=False,
                            ignore_dice=ignore_dice
                            )
            tracker["acc"] += metrx[0].item()
            tracker["dice_idx"] += metrx[1].item()

            for j, los in enumerate(losses_):
                tracker[losses_kz[j]] += los.item()

            # stored things
            # always store ids.
            tracker["ids"].extend(ids)

            if store_results:
                tracker["pred_labels"].extend(
                    scores.argmax(dim=1, keepdim=False).cpu().numpy().tolist()
                )
                # store binary masks
                tracker["pred_masks"].extend(
                    [(metrics.get_binary_mask(
                        masks_pred[kk]).detach().cpu().numpy() > 0.) for kk in
                     range(bsz)]
                )

            n_samples += bsz
            if cnd_dice:
                # number of samples that are NOT ignored.
                n_sam_dice += ignore_dice.numel() - ignore_dice.sum()
            else:
                n_sam_dice += bsz

            # ==================================================================
            #                    START: COMPUTE INFO. FOR SELECTION CRITERIA.
            # ==================================================================
            # 1. Entropy
            if entropy_cnd and apply_selection_tech:
                if args.task == constants.CL:
                    entropy.extend(
                        cp_entropy(
                            F.softmax(scores, dim=1)).cpu().numpy().tolist()
                    )
                elif args.task == constants.SEG:
                    assert bsz == 1, "batchsize must be 1. " \
                                     "found {}.".format(bsz)
                    # nbr_pixels, 2 (forg, backg): the background
                    # probability is the complement of the foreground one.
                    # (the foreground column used to be duplicated, so each
                    # row was not a distribution.)
                    # NOTE(review): confirm that Entropy() copes with
                    # probabilities saturated at exactly 0 or 1.
                    forg_probs = masks_pred.view(-1, 1)
                    pixel_probs = torch.cat((forg_probs, 1. - forg_probs),
                                            dim=1)
                    avg_pixel_entropy = cp_entropy(pixel_probs).mean()
                    entropy.append(avg_pixel_entropy.cpu().item())
                else:
                    raise ValueError("Unknown task {}.".format(args.task))

            # 2. MC-Dropout
            if mcdropout_cnd and apply_selection_tech:

                reset_seed(int(os.environ["MYSEED"]))

                if args.task == constants.CL:
                    # todo
                    raise NotImplementedError
                elif args.task == constants.SEG:
                    assert bsz == 1, "batchsize must be 1. " \
                                     "found {}.".format(bsz)
                    mc_flat_masks = []
                    # turn on dropout
                    model.set_dropout_to_train_mode()
                    for it_mc in range(args.mcdropout_t):
                        # dedicated names: the stochastic passes must not
                        # overwrite the prediction of the evaluation pass
                        # (scores, masks_pred) which is drawn below.
                        _, mc_masks_pred, _ = model(x=data, seed=None)

                        # resize pred mask if sizes mismatch.
                        if mc_masks_pred.shape != masks_trg.shape:
                            _, _, oh, ow = masks_trg.shape
                            mc_masks_pred = \
                                dataset.turnback_mask_tensor_into_original_size(
                                    pred_masks=mc_masks_pred, oh=oh, ow=ow
                                )

                        # TODO: change normalization location. (future)
                        mc_masks_pred = torch.sigmoid(mc_masks_pred)
                        mc_flat_masks.append(mc_masks_pred.view(-1, 1))

                    # stack flatten masks horizontally: (nbr_pixels, T).
                    stacked_masks = torch.cat(mc_flat_masks, dim=1)
                    # compute variance per pixel (over the T stochastic
                    # passes, dim=1), then average over the
                    # image. images have different sizes. if it is not the
                    # case, it is fine since we divide all sum_var with a
                    # constant. the unbiased variance requires
                    # args.mcdropout_t >= 2.
                    variance = stacked_masks.var(
                        dim=1, unbiased=True).mean()
                    mcdropout_var.append(variance.cpu().item())

                    # turn off dropout
                    model.set_dropout_to_eval_mode()
                else:
                    raise ValueError(
                        "Unknown task {}.".format(args.task))
            # ==================================================================
            #                    END: COMPUTE INFO. FOR SELECTION CRITERIA.
            # ==================================================================

            # Visualize
            cnd = (args.task == constants.SEG)
            cnd &= (args.subtask == constants.SUBCLSEG)
            cnd &= store_results

            # todo: separate storing stats from plotting figure + storing it.
            # todo: add plot_figure var.
            if cnd:
                output_file = join(outd_data, "{}.jpeg".format(ids[0]))

                visualisor(
                    img_in=dataset.get_original_input_img(i),
                    mask_pred=metrics.get_binary_mask(
                        masks_pred).detach().cpu().squeeze().numpy(),
                    true_label=dataset.get_original_input_label_int(i),
                    label_pred=scores.argmax(dim=1, keepdim=False).cpu().numpy(
                    ).tolist()[0],
                    id_sample=ids[0],
                    name_classes=dataset.name_classes,
                    true_mask=masks_trg.cpu().squeeze().numpy(),
                    dice=metrx[1].item(),
                    output_file=output_file,
                    scale=None,
                    binarize_pred_mask=False,
                    cont_pred_msk=masks_pred.detach().cpu().squeeze().numpy()
                )

            # clear memory
            del scores
            del masks_pred
            del maps

    # avg: acc, dice_idx
    for k in keys_perf:
        if (k == 'dice_idx') and cnd_dice:
            announce_msg("Ignore normal samples mode. "
                         "Total samples left: {}/Total: {}.".format(
                n_sam_dice, n_samples))
            if n_sam_dice != 0:
                tracker[k] /= float(n_sam_dice)
            else:
                tracker[k] = 0.

        else:
            tracker[k] /= float(n_samples)

    # compress, then delete files to prevent overloading the disc quota of
    # number of files.
    def compress_del(src):
        """
        Compress a folder with name 'src' into 'src.zip' or 'src.tar.gz'.
        Then, delete the folder 'src'.
        'tar' is used only when the 'zip' command fails.
        :param src: str, absolute path to the folder to compress.
        :return:
        """
        try:
            cmd_compress = 'zip -rjq {}.zip {}'.format(src, src)
            print("Run: `{}`".format(cmd_compress))
            subprocess.run(cmd_compress, shell=True, check=True)
        except subprocess.CalledProcessError:
            cmd_compress = 'tar -zcf {}.tar.gz -C {} .'.format(src, src)
            print("Run: `{}`".format(cmd_compress))
            subprocess.run(cmd_compress, shell=True, check=True)

        cmd_del = 'rm -r {}'.format(src)
        print("Run: `{}`".format(cmd_del))
        os.system(cmd_del)

    compress_del(outd_data)

    to_write = "EVAL.FINAL {} -- pseudo-labeled {}: ACC: {:.3f}%," \
               " DICE: {:.3f}%, time:{}".format(
                name_set, pseudo_labeled, tracker["acc"] * 100.,
                tracker["dice_idx"] * 100., dt.datetime.now() - t0
               )
    to_write = "{} \n{} \n{}".format(10 * "=", to_write, 10 * "=")
    print(to_write)
    if log_file:
        log(log_file, to_write)

    # store the stats in pickle.
    final_stats = dict()  # for a quick access: perf.
    for k in keys_perf:
        final_stats[k] = tracker[k] * 100.

    pred_stats = dict()  # it is heavy.
    for k in losses_kz + per_sample_kz:
        pred_stats[k] = tracker[k]

    pred_stats["pseudo-labeled"] = pseudo_labeled

    if pseudo_labeled:
        outfile_tracker = join(outd, 'final-tracker-{}-pseudoL-{}.pkl'.format(
            name_set, pseudo_labeled))
        outfile_pred = join(outd, 'final-pred-{}-pseudoL-{}.pkl'.format(
            name_set, pseudo_labeled))
    else:
        outfile_tracker = join(outd, 'final-tracker-{}.pkl'.format(name_set))
        outfile_pred = join(outd, 'final-pred-{}.pkl'.format(name_set))

    with open(outfile_tracker, 'wb') as fout:
        pkl.dump(final_stats, fout, protocol=pkl.HIGHEST_PROTOCOL)
    with open(outfile_pred, 'wb') as fout:
        pkl.dump(pred_stats, fout, protocol=pkl.HIGHEST_PROTOCOL)

    # store info. for selection techniques.
    # 1. Entropy.
    if entropy_cnd and apply_selection_tech:
        with open(join(outd, 'entropy-{}.pkl'.format(name_set)), 'wb') as fout:
            pkl.dump({'entropy': entropy,
                      'ids': tracker["ids"]}, fout,
                     protocol=pkl.HIGHEST_PROTOCOL)

    # 2. MC-Dropout.
    if mcdropout_cnd and apply_selection_tech:
        with open(join(outd, 'mc-dropout-{}.pkl'.format(name_set)),
                  'wb') as fout:
            pkl.dump({'mc-dropout-var': mcdropout_var,
                      'ids': tracker["ids"]}, fout,
                     protocol=pkl.HIGHEST_PROTOCOL)

    set_default_seed()
def compute_similarities(args,
                         tag_sims,
                         train_csv,
                         rootpath,
                         DEVICE,
                         SIMS,
                         training_log,
                         placement_node,
                         parent
                         ):
    """
    Compute the pairwise similarities between the train samples and archive
    them; or, when the archive 'pairwise_sims/<tag_sims>.tar.gz' already
    exists, extract it if the extracted folder is missing or empty.

    Nothing is done unless args.al_type is constants.AL_LP.

    :param args: object holding the configuration. the attributes dataset,
           al_type, task, and name_classes are read.
    :param tag_sims: str, name of the folder/archive of the similarities.
    :param train_csv: csv file of the train samples (given to csv_loader).
    :param rootpath: root path of the dataset (given to csv_loader).
    :param DEVICE: device given to the similarity computation.
    :param SIMS: output directory of the similarities.
    :param training_log: logfile where the runtimes are logged.
    :param placement_node: used only when the environment variable
           CC_CLUSTER is set: root where the archive is copied/extracted.
    :param parent: used only when the environment variable CC_CLUSTER is
           set: folder below 'placement_node'.
    :return: 0.
    """
    def timed_shell(cmd):
        """
        Show, then run, a shell command. Return the elapsed time.
        """
        start = dt.datetime.now()
        print("Running bash-cmds: \n{}".format(cmd.replace("&& ", "\n")))
        subprocess.run(cmd, shell=True, check=True)
        return dt.datetime.now() - start

    def missing_or_empty(path):
        """
        True if the folder does not exist or it holds no entry.
        """
        return (not os.path.exists(path)) or (len(os.listdir(path)) == 0)

    def chain(steps):
        """
        Join shell steps with ' && '. A trailing space is kept.
        """
        return " && ".join(steps) + " "

    # drop normal samples and keep metastatic if: 1. dataset=CAM16. 2.
    # al_type != AL_WSL.
    drop_normal = (args.dataset == constants.CAM16) & (
        args.al_type != constants.AL_WSL)

    # only the AL_LP method needs the similarities.
    if args.al_type != constants.AL_LP:
        return 0

    here = dirname(abspath(__file__))
    archive_name = '{}.tar.gz'.format(tag_sims)

    if os.path.exists(join("pairwise_sims", archive_name)):
        # ------------------------------------------------------------------
        # the archive exists: unzip if necessary.
        # ------------------------------------------------------------------
        unpack_cmd = None
        if "CC_CLUSTER" in os.environ:  # if CC, copy to node.
            node_dir = join(placement_node, parent, "pairwise_sims")
            if missing_or_empty(join(node_dir, tag_sims)):
                unpack_cmd = chain([
                    "cp {}/{} {}".format(
                        "./pairwise_sims", archive_name, node_dir),
                    "cd {}".format(node_dir),
                    "tar -xf {}".format(archive_name),
                    "cd {}".format(here)
                ])
        elif missing_or_empty(join('./pairwise_sims', tag_sims)):
            unpack_cmd = chain([
                "cd {}".format("./pairwise_sims"),
                "tar -xf {}".format(archive_name),
                "cd {}".format(here)
            ])

        if unpack_cmd is not None:
            report = "runtime of ALL the bash-cmds: {}".format(
                timed_shell(unpack_cmd))
            print(report)
            log(training_log, report)

        return 0

    # ----------------------------------------------------------------------
    # no archive: compute the proximity, then compress it.
    # ----------------------------------------------------------------------
    announce_msg("Going to project samples, and compute apirwise "
                 "similarities")

    train_samples = csv_loader(train_csv, rootpath, drop_normal=drop_normal)
    for sample in train_samples:
        # just for the loader consistency. masks are not used when
        # computing the pairwise similarity.
        sample[4] = constants.L

    set_default_seed()
    compute_sim = PairwiseSimilarity(task=args.task)
    set_default_seed()
    started = dt.datetime.now()

    if args.task == constants.CL:
        set_default_seed()
        compute_sim(
            data=train_samples, args=args, device=DEVICE, outd=SIMS)
        set_default_seed()
    elif args.task == constants.SEG:
        # it has to be done differently. the similarity is measured
        # only between samples within the same class.
        for cls in args.name_classes.keys():
            members = [smp for smp in train_samples if smp[3] == cls]
            print("Computing similarities for class {}:".format(cls))
            set_default_seed()
            compute_sim(data=members, args=args, device=DEVICE, outd=SIMS,
                        label=cls)
            set_default_seed()

    report = "Time to compute sims {}: {}".format(
        tag_sims, dt.datetime.now() - started)
    print(report)
    log(training_log, report)

    # compress, move files.
    pack_step = "tar -cf {} {}".format(archive_name, tag_sims)
    if "CC_CLUSTER" in os.environ:  # if CC
        pack_cmd = chain([
            "cd {}".format(SIMS),
            "cd ..",
            pack_step,
            "cp {} {}".format(archive_name, join(here, "pairwise_sims")),
            "cd {}".format(here)
        ])
    else:
        pack_cmd = chain([
            "cd {}".format("./pairwise_sims"),
            pack_step,
            "cd {}".format(here)
        ])

    report += "\n time to run the command {}: {}".format(
        pack_cmd, timed_shell(pack_cmd))
    print(report)
    log(training_log, report)

    return 0
def train_one_epoch(model,
                    optimizer,
                    dataloader,
                    criterion,
                    device,
                    tr_stats,
                    args,
                    dataset,
                    epoch=0,
                    cycle=0,
                    log_file=None,
                    ALLOW_MULTIGPUS=False,
                    NBRGPUS=1
                    ):
    """
    Perform one epoch of training: for each batch, do a forward pass,
    compute the losses, backpropagate the total loss, update the
    parameters, and accumulate the metrics (accuracy, dice) and the losses.

    The seed is reset (reset_seed) around almost every step. Do not reorder
    the statements of this function: the order of these calls is part of
    the reproducibility scheme.

    :param model: instance of a model. it is called as
           model(x=, seed=, prngs_cuda=) and must return
           (scores, masks_pred, maps).
    :param optimizer: instance of an optimizer.
    :param dataloader: instance of a dataloader. it is iterated once, and
           it yields batches of the form
           (ids, data, mask, label, tag, crop_cord).
    :param criterion: instance of a learning criterion. the losses it
           returns are tracked in this order: total_loss, cl_loss,
           seg_l_loss, seg_lp_loss. the first one is backpropagated.
    :param device: a device.
    :param tr_stats: dict that holds the states of the training (one list
           per tracked key). or None.
    :param args: args of the main.py. the following attributes are read:
           seg_threshold, dataset, al_type, task.
    :param dataset: the dataset. it is only forwarded to build_target_cl().
    :param epoch: int, the current epoch.
    :param cycle: int, the current cycle.
    :param log_file: a logfile.
    :param ALLOW_MULTIGPUS: bool. If True, we are in multiGPU mode.
    :param NBRGPUS: int, number of GPUs.
    :return: tr_stats: the input dict where the stats of this epoch have
             been appended to each list; or a new dict (key -> list of one
             value) if the input tr_stats is None.
    """
    reset_seed(seedj(epoch, 3, cycle, CONST1))

    model.train()
    criterion.train()
    metrics = Metrics(threshold=args.seg_threshold).to(device)
    metrics.eval()

    length = len(dataloader)
    t0 = dt.datetime.now()
    # to track stats.
    keys = ["acc", "dice_idx"]
    losses_kz = ["total_loss", "cl_loss", "seg_l_loss", "seg_lp_loss"]
    keys = keys + losses_kz
    tracker = dict()
    for k in keys:
        tracker[k] = 0.

    n_samples = 0.
    n_sam_dice = 0.
    # ignoring samples is on only: 1. cam16 dataset. 2. wsl method.
    # for the rest of the methods, normal samples are completely dropped before
    # starting training.
    cnd_dice = (args.dataset == constants.CAM16)
    cnd_dice &= (args.al_type == constants.AL_WSL)

    for i, (ids, data, mask, label, tag, crop_cord) in tqdm.tqdm(
            enumerate(dataloader), ncols=80, total=length):
        # seed of this batch: it depends on the epoch and the cycle. the
        # index of the batch is added later (see: seedx += i).
        seedx = int(os.environ["MYSEED"]) + epoch + (cycle + 1) * 10
        reset_seed(seedx)

        targets = None
        masks_pred, masks_trg = None, None
        # NOTE(review): if mask is None, masks_trg stays None and the calls
        # masks_trg.view(...) below will fail. presumably the loader always
        # provides a mask -- confirm.
        if mask is not None:
            masks_trg = mask.to(device)

        if (args.al_type == constants.AL_LP) and (args.task == constants.CL):
            targets = build_target_cl(args=args, labels=label, tags=tag,
                                      dataset=dataset, device=device)
            reset_seed(seedx)

        seedx += i

        data = data.to(device)
        labels = label.to(device)

        if targets is not None:
            targets = targets.to(device)

        bsz = data.size()[0]

        model.zero_grad()
        prngs_cuda = None

        # Optimization:
        if not ALLOW_MULTIGPUS:
            if "CC_CLUSTER" in os.environ.keys():
                msg = "Something wrong. You deactivated multigpu mode, " \
                      "but we find {} GPUs. This will not guarantee " \
                      "reproducibility. We do not know why you did that. " \
                      "Exiting ... [NOT OK]".format(NBRGPUS)
                assert NBRGPUS <= 1, msg
            seeds_threads = None
        else:
            msg = "Something is wrong. You asked for multigpu mode. But, " \
                  "we found {} GPUs. Exiting .... [NOT OK]".format(NBRGPUS)
            assert NBRGPUS > 1, msg
            # The seeds are generated randomly before calling the threads.
            reset_seed(seedx)
            # one seed per GPU.
            seeds_threads = torch.randint(
                0, np.iinfo(np.uint32).max + 1, (NBRGPUS, )).to(device)
            reset_seed(seedx)
            prngs_cuda = []
            # Create different prng states of cuda before forking.
            for seed in seeds_threads:
                # get the corresponding state of the cuda prng with respect to
                # the seed.
                inter_seed = seed.cpu().item()
                # change the internal state of the prng to a random one using
                # the random seed so to capture it.
                torch.manual_seed(inter_seed)
                torch.cuda.manual_seed(inter_seed)
                # capture the prng state.
                prngs_cuda.append(torch.cuda.get_rng_state())
            reset_seed(seedx)

        # stack the captured states into one tensor (multigpu mode only).
        if prngs_cuda is not None and prngs_cuda != []:
            prngs_cuda = torch.stack(prngs_cuda)

        reset_seed(seedx)
        scores, masks_pred, maps = model(x=data,
                                         seed=seeds_threads,
                                         prngs_cuda=prngs_cuda
                                        )
        reset_seed(seedx)

        # change to the predicted mask to be the maps if WSL.
        if args.al_type == constants.AL_WSL:
            lb_ = labels.view(-1, 1, 1, 1)
            masks_pred = (maps.argmax(dim=1, keepdim=True) == lb_).float()
            # masks_pred contains binary values with float format.
            # shape (bsz, 1, h, w)
        else:
            # TODO: change normalization location. (future)
            masks_pred = torch.sigmoid(masks_pred)  # binary segmentation.


        losses = criterion(scores=scores,
                           labels=labels,
                           targets=targets,
                           masks_pred=masks_pred.view(bsz, -1),
                           masks_trg=masks_trg.view(bsz, -1),
                           tags=tag,
                           weights=None,
                           avg=True
                           )
        reset_seed(seedx)

        losses[0].backward()  # total loss.
        reset_seed(seedx)
        # Update params.
        optimizer.step()
        reset_seed(seedx)
        # End optimization.

        # metrics

        ignore_dice = None
        if cnd_dice:
            # ignore samples with label 'normal' (0) when computing dice.
            ignore_dice = (labels == 0).float().view(-1)


        # avg=False: the metrics are accumulated over the batches here, and
        # divided by the number of samples after the loop.
        metrx = metrics(scores=scores,
                        labels=labels,
                        tr_loss=criterion,
                        masks_pred=masks_pred.view(bsz, -1),
                        masks_trg=masks_trg.view(bsz, -1),
                        avg=False,
                        ignore_dice=ignore_dice
                        )

        # Update the tracker.
        tracker["acc"] += metrx[0].item()
        tracker["dice_idx"] += metrx[1].item()

        for j, los in enumerate(losses):
            tracker[losses_kz[j]] += los.item()

        n_samples += bsz
        if cnd_dice:
            # count only the samples that are NOT ignored.
            n_sam_dice += ignore_dice.numel() - ignore_dice.sum()
        else:
            n_sam_dice += bsz

        # clear the memory
        del scores
        del masks_pred
        del maps

    # average.
    for k in keys:
        if k in losses_kz:
            # the losses were computed with avg=True: average them over the
            # number of batches.
            tracker[k] /= float(length)
        elif (k == 'dice_idx') and cnd_dice:
            announce_msg("Ignore normal samples mode. "
                         "Total samples left: {}/Total: {}.".format(
                n_sam_dice, n_samples))
            if n_sam_dice == 0:
                tracker[k] = 0.
            else:
                tracker[k] /= float(n_sam_dice)
        else:
            tracker[k] /= float(n_samples)

    to_write = "Tr.Ep {:>2d}-{:>2d}: ACC: {:.2f}%, DICE: {:.2f}%, LR: {}, " \
               "time:{}".format(
        cycle,
        epoch,
        tracker["acc"] * 100.,
        tracker["dice_idx"] * 100.,
        ['{:.2e}'.format(group["lr"]) for group in optimizer.param_groups],
        dt.datetime.now() - t0
    )
    print(to_write)
    if log_file:
        log(log_file, to_write)

    # Update stats:
    if tr_stats is not None:
        for k in keys:
            tr_stats[k].append(tracker[k])
    else:
        # convert each element into a list
        for k in keys:
            tracker[k] = [tracker[k]]
        tr_stats = deepcopy(tracker)

    reset_seed(seedj(epoch, 4, cycle, CONST1))

    return tr_stats
    log(training_log, "\n\n ########### Training #########\n\n")
    log(results_log, "\n\n ########### Results #########\n\n")

    # ==========================================================
    # Data transformations: on PIL.Image.Image and torch.tensor.
    # ==========================================================

    train_transform_img = get_train_transforms_img(args)
    transform_tensor = get_transforms_tensor(args)

    # ==========================================================================
    # Datasets: load csv, datasets: train, valid, test.
    # ==========================================================================

    announce_msg("SPLIT: {} \t FOLD: {}".format(args.split, args.fold))

    train_csv, valid_csv, test_csv = get_csv_files(args)

    rootpath = get_rootpath_2_dataset(args)

    # drop normal samples and keep metastatic if: 1. dataset=CAM16. 2.
    # al_type != AL_WSL.
    cnd_drop_n = (args.dataset == constants.CAM16)
    cnd_drop_n &= (args.al_type != constants.AL_WSL)

    train_samples = csv_loader(train_csv, rootpath, drop_normal=cnd_drop_n)
    valid_samples = csv_loader(valid_csv, rootpath, drop_normal=cnd_drop_n)
    test_samples = csv_loader(test_csv, rootpath, drop_normal=cnd_drop_n)

    # remove normal from name classes.
Exemple #19
0
import csv

import constants
from vision import PlotActiveLearningRounds
from shared import announce_msg
from shared import compute_auc

# Matplotlib style applied to the figures of this script. Style references:
# https://matplotlib.org/3.1.1/tutorials/introductory/customizing.html
# https://matplotlib.org/3.3.0/gallery/style_sheets/style_sheets_reference.html
mpl.style.use('seaborn')

# Choose the dataset.
dataset = constants.CUB
announce_msg("Processing dataset: {}".format(dataset))

# If True, the xlabels will be percentages instead of numbers of samples.
show_percentage = True

TAG = 'paper_label_prop/{}'.format(dataset)

# (method key, CSS4 color name) pairs; the names are resolved below through
# mcolors.CSS4_COLORS. Named colors:
# https://matplotlib.org/3.1.0/gallery/color/named_colors.html
_method_color_names = (
    (constants.AL_FULL_SUP, 'black'),
    (constants.AL_LP, 'red'),
    (constants.AL_ENTROPY, 'sienna'),
    (constants.AL_RANDOM, 'blue'),
    (constants.AL_MCDROPOUT, 'green'),
    (constants.AL_WSL, 'orange'),
)
colors = {
    method: mcolors.CSS4_COLORS[color_name]
    for method, color_name in _method_color_names
}
Exemple #20
0
# them
thread_lock = threading.Lock()
# thread-safe.

import reproducibility
import constants

# Whether synchronized batch-norm is requested. Defaults to True when the
# ACTIVATE_SYNC_BN environment variable is not set. Override it in Bash:
# $ export ACTIVATE_SYNC_BN="True"   ----> Activate
# $ export ACTIVATE_SYNC_BN="False"   ----> Deactivate
# Only the exact string "True" activates it; any other value deactivates it.
ACTIVATE_SYNC_BN = os.environ.get("ACTIVATE_SYNC_BN", "True") == "True"

announce_msg("ACTIVATE_SYNC_BN was set to {}".format(ACTIVATE_SYNC_BN))

if check_if_allow_multgpu_mode() and ACTIVATE_SYNC_BN:  # Activate Synch-BN.
    from deeplearning.syncbn import nn as NN_Sync_BN
    BatchNorm2d = NN_Sync_BN.BatchNorm2d
    announce_msg("Synchronized BN has been activated. \n"
                 "MultiGPU mode has been activated. "
                 "{} GPUs".format(torch.cuda.device_count()))
else:
    BatchNorm2d = nn.BatchNorm2d
    if check_if_allow_multgpu_mode():
        announce_msg("Synchronized BN has been deactivated.\n"
                     "MultiGPU mode has been activated. "
                     "{} GPUs".format(torch.cuda.device_count()))
    else:
        announce_msg("Synchronized BN has been deactivated.\n"
from shared import check_if_allow_multgpu_mode, announce_msg
from deeplearning.utils import initialize_weights
from deeplearning.aspp import aspp

import reproducibility
import constants

# Whether synchronized batch-norm is requested. Defaults to False when the
# ACTIVATE_SYNC_BN environment variable is not set. Override it in Bash:
# $ export ACTIVATE_SYNC_BN="True"   ----> Activate
# $ export ACTIVATE_SYNC_BN="False"   ----> Deactivate
# Only the exact string "True" activates it; any other value deactivates it.
ACTIVATE_SYNC_BN = os.environ.get("ACTIVATE_SYNC_BN", "False") == "True"

announce_msg("ACTIVATE_SYNC_BN was set to {}".format(ACTIVATE_SYNC_BN))

if check_if_allow_multgpu_mode() and ACTIVATE_SYNC_BN:  # Activate Synch-BN.
    from deeplearning.syncbn import nn as NN_Sync_BN
    BatchNorm2d = NN_Sync_BN.BatchNorm2d
    announce_msg("Synchronized BN has been activated. \n"
                 "MultiGPU mode has been activated. "
                 "{} GPUs".format(torch.cuda.device_count()))
else:
    BatchNorm2d = nn.BatchNorm2d
    if check_if_allow_multgpu_mode():
        announce_msg("Synchronized BN has been deactivated.\n"
                     "MultiGPU mode has been activated. "
                     "{} GPUs".format(torch.cuda.device_count()))
    else:
        announce_msg("Synchronized BN has been deactivated.\n"