def get_trainset(args,
                 train_samples,
                 transform_tensor,
                 train_transform_img,
                 check_ps_msk_path,
                 previous_pairs,
                 fd_p_msks
                 ):
    """
    Build the training dataset and its shuffled data loader.

    `set_default_seed()` is called before building the dataset, before
    building the loader and before returning (presumably to keep the
    construction reproducible and the global RNG state predictable).

    :param args: run configuration. Attributes read here: `dataset`,
           `name_classes`, `crop_size`, `padding_ratio`, `padding_mode`,
           `up_scale_small_dim_to`, `ratio_scale_patch`, `scale_algo`,
           `batch_size`, `num_workers`.
    :param train_samples: training samples forwarded to `PhotoDataset`.
    :param transform_tensor: tensor transform forwarded to `PhotoDataset`.
    :param train_transform_img: image transform for the training samples
           (`transform_img` of `PhotoDataset`).
    :param check_ps_msk_path: forwarded to `PhotoDataset`.
    :param previous_pairs: forwarded to `PhotoDataset`.
    :param fd_p_msks: forwarded to `PhotoDataset`.
    :return: tuple (trainset, train_loader).
    """
    set_default_seed()

    trainset = PhotoDataset(
        train_samples,
        args.dataset,
        args.name_classes,
        transform_tensor,
        set_for_eval=False,
        transform_img=train_transform_img,
        resize=None,
        resize_h_to=None,
        resize_mask=False,
        crop_size=args.crop_size,
        # the same ratio is used for both padding directions.
        padding_size=(args.padding_ratio, args.padding_ratio),
        padding_mode=args.padding_mode,
        up_scale_small_dim_to=args.up_scale_small_dim_to,
        do_not_save_samples=True,
        ratio_scale_patch=args.ratio_scale_patch,
        for_eval_flag=False,
        scale_algo=args.scale_algo,
        enhance_color=False,
        enhance_color_fact=1.,
        check_ps_msk_path=check_ps_msk_path,
        previous_pairs=previous_pairs,
        fd_p_msks=fd_p_msks
    )

    set_default_seed()

    train_loader = DataLoader(
        trainset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        worker_init_fn=_init_fn,
        collate_fn=default_collate
    )

    set_default_seed()

    return trainset, train_loader
def get_validationset(args,
                      valid_samples,
                      transform_tensor,
                      padding_size_eval,
                      batch_size=None
                      ):
    """
    Build the validation dataset and its (non-shuffled) data loader.

    :param args: run configuration. Attributes read here: `dataset`,
           `name_classes`, `padding_mode`, `up_scale_small_dim_to`,
           `ratio_scale_patch`, `scale_algo`, `valid_batch_size`,
           `num_workers`.
    :param valid_samples: validation samples forwarded to `PhotoDataset`.
    :param transform_tensor: tensor transform forwarded to `PhotoDataset`.
    :param padding_size_eval: padding size used at evaluation. When it equals
           `(None, None)` no padding mode is set.
    :param batch_size: int or None. batch size. if None, the value defined in
           `args.valid_batch_size` will be used.
    :return: tuple (validset, valid_loader).
    """
    set_default_seed()
    validset = PhotoDataset(
        valid_samples,
        args.dataset,
        args.name_classes,
        transform_tensor,
        set_for_eval=False,
        transform_img=None,
        resize=None,
        resize_h_to=None,
        resize_mask=False,
        crop_size=None,
        padding_size=padding_size_eval,
        padding_mode=None if (padding_size_eval == (None, None)) else
        args.padding_mode,
        up_scale_small_dim_to=args.up_scale_small_dim_to,
        do_not_save_samples=True,
        ratio_scale_patch=args.ratio_scale_patch,
        for_eval_flag=True,
        scale_algo=args.scale_algo,
        enhance_color=False,
        enhance_color_fact=1.,
        check_ps_msk_path=False,
        fd_p_msks=None
    )
    set_default_seed()

    # set_for_eval is False, so samples are prepared when loaded (slower);
    # `args.num_workers` worker processes are used. Order is kept
    # (shuffle=False).
    valid_loader = DataLoader(
        validset,
        batch_size=args.valid_batch_size if batch_size is None else batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=default_collate,
        worker_init_fn=_init_fn
    )
    set_default_seed()

    return validset, valid_loader
def merge_pairs_tr_samples(args,
                           some_pairs,
                           tr_original,
                           ids_org,
                           train_samples
                           ):
    """
    Merge paired (pseudo-labeled) samples into the trainset.

    Only active when `args.al_type == constants.AL_LP`. Otherwise a deep copy
    of `train_samples` is returned untouched.

    For every key `k` of `some_pairs` (id of a sample paired with the sample
    `some_pairs[k]`), a deep copy of sample `k` is taken from `tr_original`.
    It is flagged as pseudo-labeled (`constants.PL`, used to filter samples
    in the loss) and receives the image-level label of its partner (label
    propagation; propagated labels are not perfect).

    :param args: run configuration (`al_type`, `task` are read).
    :param some_pairs: dict {id_to_add: id_of_partner}.
    :param tr_original: list of samples, index-aligned with `ids_org`. Item 3
           of a sample is its image-level label; item 4 its labeling status.
    :param ids_org: list of sample ids, aligned with `tr_original`.
    :param train_samples: current trainset. Not modified (deep-copied).
    :return: tuple (train_samples, acc_new_samples). `acc_new_samples` is a
             float counting the pairs whose propagated label equals the
             sample's original label.
    :raises AssertionError: for `constants.SEG`, when a paired sample's
             original label differs from its partner's.
    """
    set_default_seed()

    acc_new_samples = 0.
    train_samples = deepcopy(train_samples)

    if args.al_type != constants.AL_LP:
        return train_samples, acc_new_samples

    # add the paired samples to the trainset: previous only.
    new_samples = []

    for id_u in list(some_pairs.keys()):
        stoadd = deepcopy(tr_original[ids_org.index(id_u)])
        propagated_label = tr_original[ids_org.index(some_pairs[id_u])][3]
        same_label = (propagated_label == stoadd[3])

        if args.task == constants.SEG:
            # must be checked BEFORE the label is overwritten below; checking
            # afterwards compares the propagated label with itself.
            msg = "in weak.sup. setup, paired samples must have same " \
                  "image level label."
            assert same_label, msg

        # previous pairs: accuracy of the propagated label.
        acc_new_samples += same_label * 1.
        stoadd[4] = constants.PL  # to filter samples in the loss.
        # image-level label propagation. propagated labels are not perfect.
        stoadd[3] = propagated_label

        new_samples.append(stoadd)

    train_samples.extend(new_samples)  # add samples.
    # shuffle very well to mix pairs (new, old) with full sup.
    set_default_seed()
    for _ in range(1000):
        random.shuffle(train_samples)
    set_default_seed()

    return train_samples, acc_new_samples
def clean_shared_masks(args,
                       SHARED_OPT_MASKS,
                       metrics_fd,
                       all_pairs
                       ):
    """
    Clean the folder of shared masks (and their metrics) from stale files.

    Every `*.bmp` in `SHARED_OPT_MASKS` and every `*.pkl` in `metrics_fd` whose
    name is not `<id_u>-<all_pairs[id_u]>` for some key of `all_pairs` is
    deleted (e.g. files of samples that are no longer paired because they
    have been labeled). Does nothing when `args.share_masks` is false.

    :param args: run configuration (`share_masks` is read).
    :param SHARED_OPT_MASKS: str, folder holding the `*.bmp` masks.
    :param metrics_fd: str, folder holding the `*.pkl` metrics.
    :param all_pairs: dict {id_u: id_l} of the pairs still needed.
    :return: 0.
    """
    if not args.share_masks:
        return 0

    names_needed = ["{}-{}".format(id_u, id_l)
                    for id_u, id_l in all_pairs.items()]

    l_pm_needed = {join(SHARED_OPT_MASKS, "{}.bmp".format(nm))
                   for nm in names_needed}
    l_pm_exist = find_files_pattern(fd_in_=SHARED_OPT_MASKS, pattern_="*.bmp")
    set_default_seed()
    l_pm_del = list(set(l_pm_exist) - l_pm_needed)
    set_default_seed()

    l_mtr_needed = {join(metrics_fd, "{}.pkl".format(nm))
                    for nm in names_needed}
    l_mtr_exist = find_files_pattern(fd_in_=metrics_fd, pattern_="*.pkl")
    set_default_seed()
    l_mtr_del = list(set(l_mtr_exist) - l_mtr_needed)
    set_default_seed()

    # plain loop: the deletions are side effects, not a list to build.
    for path in l_pm_del + l_mtr_del:
        os.remove(path)

    return 0
def get_dataset_for_pseudo_anno(args,
                                new_samples,
                                transform_tensor,
                                padding_size_eval
                                ):
    """
    Build the dataset and loader used to pseudo-annotate `new_samples`.

    Samples are served in evaluation mode (`for_eval_flag=True`), without
    cropping nor image transform, in their original order (`shuffle=False`)
    and in the main process (`num_workers=0`). A batch holds
    `args.valid_batch_size` samples for the classification task
    (`constants.CL`) and a single sample otherwise.

    :param padding_size_eval: evaluation padding size. `(None, None)` means
           no padding, in which case no padding mode is passed.
    :return: tuple (trainset_eval, train_eval_loader).
    """
    set_default_seed()

    if padding_size_eval == (None, None):
        pad_mode = None
    else:
        pad_mode = args.padding_mode

    eval_kwargs = {
        "set_for_eval": False,
        "transform_img": None,
        "resize": None,
        "resize_h_to": None,
        "resize_mask": False,
        "crop_size": None,
        "padding_size": padding_size_eval,
        "padding_mode": pad_mode,
        "up_scale_small_dim_to": args.up_scale_small_dim_to,
        "do_not_save_samples": True,
        "ratio_scale_patch": args.ratio_scale_patch,
        "for_eval_flag": True,
        "scale_algo": args.scale_algo,
        "enhance_color": False,
        "enhance_color_fact": 1.,
        "check_ps_msk_path": False,
        "fd_p_msks": None,
    }
    trainset_eval = PhotoDataset(new_samples,
                                 args.dataset,
                                 args.name_classes,
                                 transform_tensor,
                                 **eval_kwargs)

    set_default_seed()

    if args.task == constants.CL:
        eval_bsz = args.valid_batch_size
    else:
        eval_bsz = 1

    train_eval_loader = DataLoader(trainset_eval,
                                   batch_size=eval_bsz,
                                   shuffle=False,
                                   num_workers=0,
                                   pin_memory=True,
                                   collate_fn=default_collate,
                                   worker_init_fn=_init_fn)

    set_default_seed()

    return trainset_eval, train_eval_loader
def pair_samples(args,
                 train_samples,
                 tr_leftovers,
                 SIMS,
                 previous_pairs,
                 fd_p_msks
                 ):
    """
    Pair samples (AL_LP only; other AL types get an empty dict of pairs).

    Pairs are produced by `PairSamples(task=args.task, knn=args.knn)` from
    `train_samples`, `tr_leftovers` and the similarities in `SIMS`. A pair
    that already existed in the previous round is removed. As a result,
    `pairs` only holds newly paired samples or samples that have been paired
    but changed their source.

    :param args: run configuration (`al_type`, `task`, `knn` are read).
    :param train_samples: current training samples. A deep copy is returned
           as the 4th value.
    :param tr_leftovers: leftover samples (presumably the unlabeled pool;
           TODO confirm) handed to the pair maker.
    :param SIMS: folder of the precomputed similarities.
    :param previous_pairs: pairs of the previous round. Presumably a dict
           {id: paired_id} as loaded in `get_init_sup_samples` (TODO confirm).
           A collection of (id, paired_id) tuples also works. `None` or empty
           means no previous pairs.
    :param fd_p_msks: unused here; kept for interface compatibility.
    :return: tuple (pairs, acc_new_samples, nbrx,
             train_samples_before_merging). This function always returns
             `acc_new_samples == 0.` and `nbrx == 0`.
    """
    acc_new_samples = 0.
    nbrx = 0
    train_samples_before_merging = deepcopy(train_samples)
    pairs = dict()

    if args.al_type != constants.AL_LP:
        return pairs, acc_new_samples, nbrx, train_samples_before_merging

    set_default_seed()
    pairmaker = PairSamples(task=args.task,
                            knn=args.knn
                            )
    set_default_seed()
    announce_msg("starts pairing...")
    pairs = pairmaker(train_samples, tr_leftovers, SIMS)
    announce_msg("finishes pairing...")
    set_default_seed()

    # if pair exists in the previous round, delete it.
    if previous_pairs:
        is_mapping = isinstance(previous_pairs, dict)
        # iterate over a snapshot: popping from `pairs` while iterating its
        # live key view raises "dictionary changed size during iteration".
        for k, v in list(pairs.items()):
            # tuple key, or collection of (k, v) tuples.
            same_pair = (k, v) in previous_pairs
            if not same_pair and is_mapping:
                # {k: v} mapping: `(k, v) in dict` only tests for a tuple
                # KEY, so compare against the mapping's items instead.
                same_pair = (k, v) in previous_pairs.items()
            if same_pair:
                pairs.pop(k)

    return pairs, acc_new_samples, nbrx, train_samples_before_merging
            os.path.isfile(valid_csv),
            os.path.isfile(test_csv)
    ]):
        raise ValueError("Missing *.cvs files ({}[{}], {}[{}], {}[{}])".format(
            train_csv, os.path.isfile(train_csv), valid_csv,
            os.path.isfile(valid_csv), test_csv, os.path.isfile(test_csv)))

    rootpath = get_rootpath_2_dataset(args)

    train_samples = csv_loader(train_csv, rootpath)
    valid_samples = csv_loader(valid_csv, rootpath)
    test_samples = csv_loader(test_csv, rootpath)

    # Just for debug to go fast.
    if DEBUG_MODE:
        set_default_seed()
        warnings.warn("YOU ARE IN DEBUG MODE!!!!")
        if args.dataset == "Caltech-UCSD-Birds-200-2011":
            nbrx_tr, nbrx_vl, nbrx_tst = 100, 5, 20
        elif args.dataset == "Oxford-flowers-102":
            nbrx_tr, nbrx_vl, nbrx_tst = 100, 5, 20
        elif args.dataset == "glas":
            nbrx_tr, nbrx_vl, nbrx_tst = 20, 5, 20
        elif args.dataset == "bach-part-a-2018":
            nbrx_tr, nbrx_vl, nbrx_tst = 20, 5, 20
        elif args.dataset == "fgnet":
            nbrx_tr, nbrx_vl, nbrx_tst = 20, 5, 20
        elif args.dataset == "afad-lite":
            nbrx_tr, nbrx_vl, nbrx_tst = 200, 5000, 200
        elif args.dataset == "afad-full":
            nbrx_tr, nbrx_vl, nbrx_tst = 200, 5000, 200
def get_init_sup_samples(args,
                         sampler,
                         COMMON,
                         train_samples,
                         OUTD
                         ):
    """
    Get the fully supervised (labeled) training samples of the current round.

    * AL_FULL_SUP / AL_WSL: `train_samples` is returned as given.
    * Round 0 (`args.al_it == 0`): draw the initial samples with
      `sampler.sample_init_random_samples`. Store them as `train_0.csv` in
      COMMON with the root path removed (host-independent), then copy that
      file into OUTD.
    * Round > 0: reload every `train_{t}.csv` (t = 0..al_it) from COMMON with
      the current host root path, force every sample to `constants.L` and
      shuffle them. With AL_LP + SEG, the pairs of the previous round are
      also loaded from `train_pairs_{al_it - 1}.pkl`.

    On CAM16 (and al_type != AL_WSL), normal samples are dropped when the
    csv files are reloaded.

    :return: tuple (train_samples, previous_pairs, previous_errors).
    """
    previous_pairs = dict()
    previous_errors = False

    drop_normal = (args.dataset == constants.CAM16) and \
                  (args.al_type != constants.AL_WSL)

    if args.al_type in [constants.AL_FULL_SUP, constants.AL_WSL]:
        return train_samples, previous_pairs, previous_errors

    if args.al_it == 0:
        # deterministic with respect to the original seed.
        set_default_seed()
        train_samples = sampler.sample_init_random_samples(train_samples)
        set_default_seed()
        # relative paths only, so the file does not depend on the host.
        csv_name = 'train_{}.csv'.format(args.al_it)
        common_csv = join(COMMON, csv_name)
        csv_writer(clear_rootpath(train_samples, args), common_csv)
        shutil.copyfile(common_csv, join(OUTD, csv_name))
        return train_samples, previous_pairs, previous_errors

    if args.al_it > 0:
        # 'train_{t}.csv' holds the samples selected at round t.
        round_csvs = [join(COMMON, 'train_{}.csv'.format(t))
                      for t in range(args.al_it + 1)]

        if args.al_type == constants.AL_LP and args.task == constants.SEG:
            # pairs pseudo-labeled in the previous round; reused as-is.
            pairs_file = join(COMMON,
                              'train_pairs_{}.pkl'.format(args.al_it - 1))
            with open(pairs_file, 'rb') as fp:
                previous_pairs = pkl.load(fp)

        rootpath = get_rootpath_2_dataset(args)
        train_samples = []
        for csv_path in round_csvs:
            train_samples.extend(
                csv_loader(csv_path, rootpath, drop_normal=drop_normal))

        for sample in train_samples:
            sample[4] = constants.L

        # TODO: legacy block (restart on CC nodes). Paths are relative now,
        # so any rewrite here is reported as an error by the assert below.
        if "CC_CLUSTER" in os.environ.keys():
            for sample in train_samples:
                old_root = os.sep.join(sample[1].split(os.sep)[:3])
                if old_root == os.environ["SLURM_TMPDIR"]:
                    continue
                sample[1] = sample[1].replace(
                    old_root, os.environ["SLURM_TMPDIR"])
                if args.task == constants.SEG:
                    sample[2] = sample[2].replace(
                        old_root, os.environ["SLURM_TMPDIR"])
                previous_errors = True

            assert not previous_errors, "ERROR."

        set_default_seed()
        for _ in range(100):
            random.shuffle(train_samples)
        set_default_seed()

    return train_samples, previous_pairs, previous_errors
def compute_similarities(args,
                         tag_sims,
                         train_csv,
                         rootpath,
                         DEVICE,
                         SIMS,
                         training_log,
                         placement_node,
                         parent
                         ):
    """
    Compute (or restore from an archive) the pairwise similarities.

    Only active for `args.al_type == constants.AL_LP`.

    If `pairwise_sims/<tag_sims>.tar.gz` does not exist, the similarities are
    computed with `PairwiseSimilarity` into `SIMS`:
    - CL: over all training samples;
    - SEG: only between samples of the same class.
    The result is then archived. Otherwise, the archive is extracted when
    the target folder is missing or empty.

    :param args: run configuration.
    :param tag_sims: str, name of the similarity folder / archive.
    :param train_csv: str, csv listing the training samples.
    :param rootpath: str, dataset root path for `csv_loader`.
    :param DEVICE: device used by the similarity computation.
    :param SIMS: str, output folder of the similarities.
    :param training_log: log file used by `log`.
    :param placement_node: str, root folder on the compute node (CC only).
    :param parent: str, folder under `placement_node` that contains
           `pairwise_sims` (CC only).
    :return: 0.
    """
    if args.al_type != constants.AL_LP:  # get out.
        return 0

    # drop normal samples and keep metastatic if: 1. dataset=CAM16. 2.
    # al_type != AL_WSL.
    cnd_drop_n = (args.dataset == constants.CAM16)
    cnd_drop_n &= (args.al_type != constants.AL_WSL)

    current_dir = dirname(abspath(__file__))

    # NOTE(review): this check and the "./pairwise_sims" paths below are
    # relative to the CWD, whereas the CC branch copies the archive into
    # `<current_dir>/pairwise_sims`; this assumes CWD == current_dir.
    # The trailing `cd` in each command runs in a child shell and does not
    # change this process' CWD.
    if not os.path.exists(
            join("pairwise_sims", '{}.tar.gz'.format(tag_sims))):
        announce_msg("Going to project samples, and compute apirwise "
                     "similarities")

        all_train_samples = csv_loader(train_csv,
                                       rootpath,
                                       drop_normal=cnd_drop_n
                                       )
        for el in all_train_samples:
            el[4] = constants.L  # just for the loader consistency.
            # masks are not used when computing the pairwise similarity.

        set_default_seed()
        compute_sim = PairwiseSimilarity(task=args.task)
        set_default_seed()
        t0 = dt.datetime.now()
        if args.task == constants.CL:
            set_default_seed()
            compute_sim(data=all_train_samples, args=args, device=DEVICE,
                        outd=SIMS)
            set_default_seed()
        elif args.task == constants.SEG:
            # the similarity is measured only between samples within the
            # same class.
            for k in args.name_classes.keys():
                samples_in_same_class = [
                    sx for sx in all_train_samples if sx[3] == k]
                print("Computing similarities for class {}:".format(k))
                set_default_seed()
                compute_sim(data=samples_in_same_class, args=args,
                            device=DEVICE, outd=SIMS, label=k)
                set_default_seed()

        msg = "Time to compute sims {}: {}".format(
            tag_sims, dt.datetime.now() - t0
        )
        print(msg)
        log(training_log, msg)

        # archive, move files.
        # NOTE: `tar -cf` (no `-z`) writes an uncompressed tar despite the
        # `.tar.gz` name; `tar -xf` below reads it either way.
        if "CC_CLUSTER" in os.environ.keys():  # if CC
            cmdx = "cd {} && " \
                   "cd .. && " \
                   "tar -cf {}.tar.gz {} && " \
                   "cp {}.tar.gz {} && " \
                   "cd {} ".format(
                    SIMS,
                    tag_sims,
                    tag_sims,
                    tag_sims,
                    join(current_dir, "pairwise_sims"),
                    current_dir
                    )
        else:
            cmdx = "cd {} && " \
                   "tar -cf {}.tar.gz {} && " \
                   "cd {} ".format(
                    "./pairwise_sims",
                    tag_sims,
                    tag_sims,
                    current_dir
                    )

        tt = dt.datetime.now()
        print("Running bash-cmds: \n{}".format(cmdx.replace("&& ", "\n")))
        subprocess.run(cmdx, shell=True, check=True)
        msg += "\n time to run the command {}: {}".format(
            cmdx, dt.datetime.now() - tt)
        print(msg)
        log(training_log, msg)

    else:  # unzip if necessary.
        cmdx = None
        if "CC_CLUSTER" in os.environ.keys():  # if CC, copy to node.
            pr = join(placement_node, parent, "pairwise_sims")
            folder = join(pr, tag_sims)
            uncomp = (not os.path.exists(folder)) or \
                     (len(os.listdir(folder)) == 0)
            if uncomp:
                # `mkdir -p` first: if `pr` does not exist, `cp` would create
                # a FILE named `pr` and the following `cd` would fail.
                cmdx = "mkdir -p {} && " \
                       "cp {}/{}.tar.gz {} && " \
                       "cd {} && " \
                       "tar -xf {}.tar.gz && " \
                       "cd {} ".format(
                            pr,
                            "./pairwise_sims",
                            tag_sims,
                            pr,
                            pr,
                            tag_sims,
                            current_dir
                            )

        else:
            folder = join('./pairwise_sims', tag_sims)
            uncomp = (not os.path.exists(folder)) or \
                     (len(os.listdir(folder)) == 0)
            if uncomp:
                cmdx = "cd {} && " \
                       "tar -xf {}.tar.gz && " \
                       "cd {} ".format(
                            "./pairwise_sims",
                            tag_sims,
                            current_dir
                            )

        if cmdx is not None:
            tt = dt.datetime.now()
            print("Running bash-cmds: \n{}".format(cmdx.replace("&& ", "\n")))
            subprocess.run(cmdx, shell=True, check=True)
            msg = "runtime of ALL the bash-cmds: {}".format(
                dt.datetime.now() - tt)
            print(msg)
            log(training_log, msg)

    return 0
def final_validate(model,
                   dataloader,
                   criterion,
                   device,
                   dataset,
                   outd,
                   args,
                   log_file=None,
                   name_set="",
                   pseudo_labeled=False,
                   store_results=False,
                   apply_selection_tech=False
                   ):
    """
    Perform a final evaluation of a set.
    (images do not have the same size, so we can't stack them in one tensor).
    Validation samples may be too large to fit all in the GPU at once.

    This is similar to validate() but it can operate on all types of
    datasets (train, valid, test). It selects only L. Also, it does some other
    operations related to drawing and final training tasks.

    :param model: called as `model(x=data, seed=None)`; returns
           (scores, masks_pred, maps).
    :param dataloader: yields (ids, data, mask, label, tag, crop_cord).
    :param criterion: loss, called with `avg=False`. Returns an iterable of
           losses ordered as `losses_kz` below.
    :param device: device the batches are moved to.
    :param dataset: dataset behind `dataloader`. Used to resize predicted
           masks to the original size and to fetch original images/labels for
           visualization.
    :param outd: str, output directory of this dataset.
    :param args: run configuration.
    :param log_file: optional log file; the final summary is appended to it.
    :param name_set: str, name to indicate which set is being processed. e.g.:
           trainset, validset, testset.
    :param pseudo_labeled: bool. if true, the dataloader is loading the set
           of samples that has been pseudo-labeled. this is the case only for
           our method. if false, the samples are fully labeled. this allows to
           measure the performance over the pseudo-labeled samples separately.
    :param store_results: bool (default: False). if true, results (
           predictions) are stored. Useful for test set to store predictions.
           can be disabled for train/valid sets.
    :param apply_selection_tech: bool. if true, we compute stats that will be
           used for sample selection for the next AL round. useful on the
           leftovers.
    :return: None. Writes `final-tracker-*.pkl` (acc/dice in %) and
             `final-pred-*.pkl` (losses, per-sample predictions, ids) into
             `outd`. It also writes `entropy-<name_set>.pkl` or
             `mc-dropout-<name_set>.pkl` when the matching selection
             criterion is active.
    """
    set_default_seed()

    outd_data = join(outd, "prediction")
    os.makedirs(outd_data, exist_ok=True)

    # TODO: rescale depending on the dataset.
    visualisor = VisualsePredSegmentation(height_tag=args.height_tag,
                                          show_tags=True,
                                          scale=0.5
                                          )

    model.eval()
    criterion.eval()
    metrics = Metrics(threshold=args.seg_threshold).to(device)
    metrics.eval()

    length = len(dataloader)
    # to track stats.
    keys_perf = ["acc", "dice_idx"]
    losses_kz = ["total_loss", "cl_loss", "seg_l_loss", "seg_lp_loss"]
    per_sample_kz = ["pred_labels", "pred_masks", "ids"]
    keys = keys_perf + losses_kz + per_sample_kz
    tracker = dict()
    for k in keys:
        if k not in per_sample_kz:
            tracker[k] = 0.
        else:
            tracker[k] = []

    n_samples = 0.
    n_sam_dice = 0.
    # ignoring samples is on only: 1. cam16 dataset. 2. wsl method.
    # for the rest of the methods, normal samples are completely dropped before
    # starting training.
    cnd_dice = (args.dataset == constants.CAM16)
    cnd_dice &= (args.al_type == constants.AL_WSL)

    # Selection criteria.
    cp_entropy = Entropy()  # to compute entropy.
    entropy = []
    mcdropout_var = []

    t0 = dt.datetime.now()

    # conditioning
    entropy_cnd = (args.al_type == constants.AL_ENTROPY)
    entropy_cnd = entropy_cnd or ((args.al_type == constants.AL_LP) and (
            args.clustering == constants.CLUSTER_ENTROPY))
    mcdropout_cnd = (args.al_type == constants.AL_MCDROPOUT)

    cnd_asrt = store_results
    cnd_asrt |= apply_selection_tech

    with torch.no_grad():
        for i, (ids, data, mask, label, tag, crop_cord) in tqdm.tqdm(
                enumerate(dataloader), ncols=80, total=length):

            reset_seed(int(os.environ["MYSEED"]))

            targets = None
            bsz = data.size()[0]

            if cnd_asrt:
                msg = "batchsize must be 1 for segmentation task. " \
                      "found {}.".format(bsz)
                assert bsz == 1, msg

            data = data.to(device)
            labels = label.to(device)
            masks_trg = mask.to(device)

            scores, masks_pred, maps = model(x=data, seed=None)

            # change to the predicted mask to be the maps if WSL.
            if args.al_type == constants.AL_WSL:
                lb_ = labels.view(-1, 1, 1, 1)
                masks_pred = (maps.argmax(dim=1, keepdim=True) == lb_).float()

            # resize pred mask if sizes mismatch.
            if masks_pred.shape != masks_trg.shape:
                _, _, oh, ow = masks_trg.shape
                masks_pred = dataset.turnback_mask_tensor_into_original_size(
                    pred_masks=masks_pred,
                    oh=oh,
                    ow=ow
                )

            if args.al_type != constants.AL_WSL:
                # TODO: change normalization location. (future)
                masks_pred = torch.sigmoid(masks_pred)  # binary segmentation

            losses_ = criterion(scores=scores,
                                labels=labels,
                                targets=targets,
                                masks_pred=masks_pred.view(bsz, -1),
                                masks_trg=masks_trg.view(bsz, -1),
                                tags=tag,
                                weights=None,
                                avg=False
                                )

            # metrics
            ignore_dice = None
            if cnd_dice:
                # ignore samples with label 'normal' (0) when computing dice.
                ignore_dice = (labels == 0).float().view(-1)

            metrx = metrics(scores=scores,
                            labels=labels,
                            tr_loss=criterion,
                            masks_pred=masks_pred.view(bsz, -1),
                            masks_trg=masks_trg.view(bsz, -1),
                            avg=False,
                            ignore_dice=ignore_dice
                            )
            tracker["acc"] += metrx[0].item()
            tracker["dice_idx"] += metrx[1].item()

            for j, los in enumerate(losses_):
                tracker[losses_kz[j]] += los.item()

            # stored things
            # always store ids.
            tracker["ids"].extend(ids)

            if store_results:
                tracker["pred_labels"].extend(
                    scores.argmax(dim=1, keepdim=False).cpu().numpy().tolist()
                )
                # store binary masks
                tracker["pred_masks"].extend(
                    [(metrics.get_binary_mask(
                        masks_pred[kk]).detach().cpu().numpy() > 0.) for kk in
                     range(bsz)]
                )

            n_samples += bsz
            if cnd_dice:
                # number of samples NOT ignored; kept as a python number.
                n_sam_dice += (ignore_dice.numel() - ignore_dice.sum()).item()
            else:
                n_sam_dice += bsz

            # ==================================================================
            #                    START: COMPUTE INFO. FOR SELECTION CRITERIA.
            # ==================================================================
            # 1. Entropy
            if entropy_cnd and apply_selection_tech:
                if args.task == constants.CL:
                    entropy.extend(
                        cp_entropy(
                            F.softmax(scores, dim=1)).cpu().numpy().tolist()
                    )
                elif args.task == constants.SEG:
                    assert bsz == 1, "batchsize must be 1. " \
                                     "found {}.".format(bsz)
                    # (nbr_pixels, 2): per-pixel (foreground, background)
                    # probabilities. masks_pred is the sigmoid foreground
                    # probability here, so the background column is 1 - p.
                    fg_probs = masks_pred.view(-1, 1)
                    pixel_probs = torch.cat((fg_probs, 1. - fg_probs), dim=1)
                    avg_pixel_entropy = cp_entropy(pixel_probs).mean()
                    entropy.append(avg_pixel_entropy.cpu().item())
                else:
                    raise ValueError("Unknown task {}.".format(args.task))

            # 2. MC-Dropout
            if mcdropout_cnd and apply_selection_tech:

                reset_seed(int(os.environ["MYSEED"]))

                if args.task == constants.CL:
                    # todo
                    raise NotImplementedError
                elif args.task == constants.SEG:
                    assert bsz == 1, "batchsize must be 1. " \
                                     "found {}.".format(bsz)
                    # separate names: the deterministic `scores`/`masks_pred`
                    # above must not be replaced by a stochastic pass (they
                    # are used for visualization below).
                    mc_columns = []
                    # turn on dropout
                    model.set_dropout_to_train_mode()
                    for _ in range(args.mcdropout_t):
                        _, mc_masks, _ = model(x=data, seed=None)

                        # resize pred mask if sizes mismatch.
                        if mc_masks.shape != masks_trg.shape:
                            _, _, oh, ow = masks_trg.shape
                            mc_masks = \
                                dataset.turnback_mask_tensor_into_original_size(
                                    pred_masks=mc_masks, oh=oh, ow=ow
                                )

                        # TODO: change normalization location. (future)
                        mc_masks = torch.sigmoid(mc_masks)
                        mc_columns.append(mc_masks.view(-1, 1))
                    # stack flattened masks horizontally:
                    # (nbr_pixels, mcdropout_t), one column per pass.
                    stacked_masks = torch.cat(mc_columns, dim=1)
                    # variance per pixel across the passes, then average over
                    # the image (images may have different sizes). the
                    # unbiased estimate needs args.mcdropout_t >= 2.
                    variance = stacked_masks.var(
                        dim=1, unbiased=True).mean()
                    mcdropout_var.append(variance.cpu().item())

                    # turn off dropout
                    model.set_dropout_to_eval_mode()
                else:
                    raise ValueError(
                        "Unknown task {}.".format(args.task))
            # ==================================================================
            #                    END: COMPUTE INFO. FOR SELECTION CRITERIA.
            # ==================================================================

            # Visualize
            cnd = (args.task == constants.SEG)
            cnd &= (args.subtask == constants.SUBCLSEG)
            cnd &= store_results

            # todo: separate storing stats from plotting figure + storing it.
            # todo: add plot_figure var.
            if cnd:
                output_file = join(outd_data, "{}.jpeg".format(ids[0]))

                visualisor(
                    img_in=dataset.get_original_input_img(i),
                    mask_pred=metrics.get_binary_mask(
                        masks_pred).detach().cpu().squeeze().numpy(),
                    true_label=dataset.get_original_input_label_int(i),
                    label_pred=scores.argmax(dim=1, keepdim=False).cpu().numpy(
                    ).tolist()[0],
                    id_sample=ids[0],
                    name_classes=dataset.name_classes,
                    true_mask=masks_trg.cpu().squeeze().numpy(),
                    dice=metrx[1].item(),
                    output_file=output_file,
                    scale=None,
                    binarize_pred_mask=False,
                    cont_pred_msk=masks_pred.detach().cpu().squeeze().numpy()
                )

            # clear memory
            del scores
            del masks_pred
            del maps

    # avg: acc, dice_idx
    for k in keys_perf:
        if (k == 'dice_idx') and cnd_dice:
            announce_msg("Ignore normal samples mode. "
                         "Total samples left: {}/Total: {}.".format(
                n_sam_dice, n_samples))
            if n_sam_dice != 0:
                tracker[k] /= float(n_sam_dice)
            else:
                tracker[k] = 0.

        elif n_samples != 0:
            # an empty loader leaves the (zero) sums as they are instead of
            # dividing by zero.
            tracker[k] /= float(n_samples)

    # compress, then delete files to prevent overloading the disc quota of
    # number of files.
    def compress_del(src):
        """
        Compress a folder with name 'src' into 'src.zip' or 'src.tar.gz'.
        Then, delete the folder 'src'.
        :param src: str, absolute path to the folder to compress.
        :return:
        """
        try:
            cmd_compress = 'zip -rjq {}.zip {}'.format(src, src)
            print("Run: `{}`".format(cmd_compress))
            subprocess.run(cmd_compress, shell=True, check=True)
        except subprocess.CalledProcessError:
            # zip failed (e.g. missing or empty folder): fall back to tar.
            cmd_compress = 'tar -zcf {}.tar.gz -C {} .'.format(src, src)
            print("Run: `{}`".format(cmd_compress))
            subprocess.run(cmd_compress, shell=True, check=True)

        cmd_del = 'rm -r {}'.format(src)
        print("Run: `{}`".format(cmd_del))
        os.system(cmd_del)

    compress_del(outd_data)

    to_write = "EVAL.FINAL {} -- pseudo-labeled {}: ACC: {:.3f}%," \
               " DICE: {:.3f}%, time:{}".format(
                name_set, pseudo_labeled, tracker["acc"] * 100.,
                tracker["dice_idx"] * 100., dt.datetime.now() - t0
               )
    to_write = "{} \n{} \n{}".format(10 * "=", to_write, 10 * "=")
    print(to_write)
    if log_file:
        log(log_file, to_write)

    # store the stats in pickle.
    final_stats = dict()  # for a quick access: perf.
    for k in keys_perf:
        final_stats[k] = tracker[k] * 100.

    pred_stats = dict()  # it is heavy.
    for k in losses_kz + per_sample_kz:
        pred_stats[k] = tracker[k]

    pred_stats["pseudo-labeled"] = pseudo_labeled

    if pseudo_labeled:
        outfile_tracker = join(outd, 'final-tracker-{}-pseudoL-{}.pkl'.format(
            name_set, pseudo_labeled))
        outfile_pred = join(outd, 'final-pred-{}-pseudoL-{}.pkl'.format(
            name_set, pseudo_labeled))
    else:
        outfile_tracker = join(outd, 'final-tracker-{}.pkl'.format(name_set))
        outfile_pred = join(outd, 'final-pred-{}.pkl'.format(name_set))

    with open(outfile_tracker, 'wb') as fout:
        pkl.dump(final_stats, fout, protocol=pkl.HIGHEST_PROTOCOL)
    with open(outfile_pred, 'wb') as fout:
        pkl.dump(pred_stats, fout, protocol=pkl.HIGHEST_PROTOCOL)

    # store info. for selection techniques.
    # 1. Entropy.
    if entropy_cnd and apply_selection_tech:
        with open(join(outd, 'entropy-{}.pkl'.format(name_set)), 'wb') as fout:
            pkl.dump({'entropy': entropy,
                      'ids': tracker["ids"]}, fout,
                     protocol=pkl.HIGHEST_PROTOCOL)

    # 2. MC-Dropout.
    if mcdropout_cnd and apply_selection_tech:
        with open(join(outd, 'mc-dropout-{}.pkl'.format(name_set)),
                  'wb') as fout:
            pkl.dump({'mc-dropout-var': mcdropout_var,
                      'ids': tracker["ids"]}, fout,
                     protocol=pkl.HIGHEST_PROTOCOL)

    set_default_seed()
def _pseudo_annotate(model,
                     dataset,
                     dataloader,
                     criterion,
                     device,
                     pairs,
                     metrics_fd,
                     fd_p_msks,
                     args,
                     SHARED_OPT_MASKS,
                     threshold=0.5
                     ):
    """
    Predict and store a binary pseudo-mask for every sample of `dataloader`.

    For each sample `id_u` (the loader must yield batches of size 1), the
    sigmoid of the predicted mask is binarized with `threshold` and saved as
    `<fd_p_msks>/<id_u>-<pairs[id_u]>.bmp`. Its Dice index is pickled as
    `{"dice_u": <float>}` into `<metrics_fd>/<id_u>-<pairs[id_u]>.pkl`.
    After the loop, `clean_shared_masks` is called over all `pairs`.

    The dataset must be in `validation` mode. Task: SEG.

    :param dataset: dataset behind `dataloader`; used to resize predicted
        masks back to the target size.
    :param pairs: indexable by sample id; `pairs[id_u]` gives the id paired
        with `id_u` (used in the output file names).
    :param metrics_fd: folder receiving the per-sample metric pickles.
    :param fd_p_msks: folder receiving the pseudo-masks.
    :param threshold: float. threshold given to `Metrics` and used to
        binarize the predicted masks.
    :return: tuple (sum of the per-sample Dice indices, number of samples).
    """
    set_default_seed()

    model.eval()
    metrics = Metrics(threshold=threshold).to(device)
    metrics.eval()

    total_batches = len(dataloader)
    sum_dice = 0.
    n_samples = 0.

    with torch.no_grad():
        for ids, data, mask, label, tag, crop_cord in tqdm.tqdm(
                dataloader, ncols=80, total=total_batches):
            # same seed before every forward pass.
            reset_seed(int(os.environ["MYSEED"]))

            batch_sz = data.size()[0]
            assert batch_sz == 1, "batch size must be 1. found {}. ".format(
                batch_sz)

            uid = ids[0]
            stem = "{}-{}".format(uid, pairs[uid])
            mask_path = join(fd_p_msks, "{}.bmp".format(stem))
            metric_path = join(metrics_fd, "{}.pkl".format(stem))

            data = data.to(device)
            labels = label.to(device)
            masks_trg = None if mask is None else mask.to(device)

            scores, masks_pred, _ = model(x=data, seed=None)

            # bring the prediction back to the target size when they differ.
            if masks_pred.shape != masks_trg.shape:
                _, _, oh, ow = masks_trg.shape
                masks_pred = dataset.turnback_mask_tensor_into_original_size(
                    pred_masks=masks_pred, oh=oh, ow=ow
                )
            # TODO: change normalization location. (future)
            masks_pred = torch.sigmoid(masks_pred)

            metrx = metrics(scores=scores,
                            labels=labels,
                            tr_loss=criterion,
                            masks_pred=masks_pred.view(batch_sz, -1),
                            masks_trg=masks_trg.view(batch_sz, -1),
                            avg=False
                            )
            dice_u = metrx[1].item()
            sum_dice += dice_u
            n_samples += batch_sz

            # binarize, then save through mode 'L' converted to '1' because
            # of an issue with mode '1' when built from an array. see:
            # https://stackoverflow.com/questions/32159076/python-pil-bitmap-png-
            # from-array-with-mode-1
            bin_mask = metrics.get_binary_mask(masks_pred, threshold=threshold)
            bin_mask = bin_mask.detach().cpu().squeeze().numpy()
            Image.fromarray(bin_mask.astype(np.uint8) * 255,
                            mode='L').convert('1').save(mask_path)

            with open(metric_path, "wb") as fout:
                pkl.dump({"dice_u": dice_u}, fout,
                         protocol=pkl.HIGHEST_PROTOCOL)

            # release the large tensors before the next batch.
            del scores, masks_pred

    # clean: remove duplicates.
    clean_shared_masks(args=args,
                       SHARED_OPT_MASKS=SHARED_OPT_MASKS,
                       metrics_fd=metrics_fd,
                       all_pairs=pairs
                       )

    set_default_seed()

    return sum_dice, n_samples
def pseudo_annotate_some_pairs(model,
                               some_pairs,
                               samples,
                               args,
                               criterion,
                               transform_tensor,
                               padding_size_eval,
                               device,
                               fd_p_msks,
                               SHARED_OPT_MASKS,
                               threshold=0.5,
                               results_log=None,
                               txt=""
                               ):
    """
    Pseudo-annotate some pairs and optionally report their average Dice.

    An evaluation set with a batch size of 1 is built over `samples` with
    `get_validationset`. `_pseudo_annotate` then predicts and stores the
    pseudo-masks into `fd_p_msks` and the per-sample metrics into
    `fd_p_msks/metrics`.

    :param some_pairs: pairs forwarded to `_pseudo_annotate` (as `pairs`).
    :param samples: samples the evaluation set is built from.
    :param threshold: float. segmentation threshold.
    :param results_log: log destination or None. When set, the average Dice
        (x100, or 0 when there are no samples) is passed to `log` and printed.
    :param txt: str. prefix of the reported message.
    :return: dict with keys "sum_dice" and "n_samples".
    """
    set_default_seed()

    metrics_dir = join(fd_p_msks, "metrics")
    eval_set, loader = get_validationset(args,
                                         samples,
                                         transform_tensor,
                                         padding_size_eval,
                                         batch_size=1
                                         )

    set_default_seed()
    sum_dice, n_samples = _pseudo_annotate(
        model=model,
        dataset=eval_set,
        dataloader=loader,
        criterion=criterion,
        device=device,
        pairs=some_pairs,
        metrics_fd=metrics_dir,
        fd_p_msks=fd_p_msks,
        args=args,
        SHARED_OPT_MASKS=SHARED_OPT_MASKS,
        threshold=threshold
    )

    if results_log is not None:
        if n_samples == 0:
            avg_dice = 0.
        else:
            avg_dice = sum_dice / float(n_samples)
        msg = "{}: {} [scale_seg_u: {}] ".format(
            txt, avg_dice * 100., args.scale_seg_u)
        log(results_log, msg)
        print(msg)

    return deepcopy({"sum_dice": sum_dice, "n_samples": n_samples})
def estimate_best_seg_thres(model,
                            dataset,
                            dataloader,
                            criterion,
                            device,
                            args,
                            epoch=0,
                            cycle=0,
                            log_file=None
                            ):
    """
    Pick the segmentation threshold that maximizes the mean IoU on a set.

    Meant for the validation set, whose samples are pixel-wise labeled. Every
    threshold of `np.arange(0.05, 1., 0.01)` is evaluated over the whole set.
    The one with the highest accumulated mIoU wins; mIoU is preferred over
    the Dice index for this choice.

    The dataset must be in `validation` mode. Task: SEG.
    Not allowed with `constants.AL_WSL` (masks are already binary there).

    :param epoch: int. only used in the logged message.
    :param cycle: int. only used in the logged message.
    :param log_file: log destination or None. When set, the messages are also
        passed to `log`.
    :return: dict with keys "best_threshold", "best_dice" and "best_miou"
        (Dice and mIoU of the best threshold, averaged over the samples).
    """
    msg = "Can't use/estimate threshold over AL_WSL. masks are already binary."
    assert args.al_type != constants.AL_WSL, msg

    set_default_seed()

    announce_msg("Estimating threshold on validation set.")
    set_default_seed()

    model.eval()
    # the threshold given here is irrelevant: each call below passes its own.
    metrics = Metrics(threshold=args.seg_threshold).to(device)
    metrics.eval()

    total_batches = len(dataloader)
    candidates = np.arange(start=0.05, stop=1., step=0.01)
    dice_per_thres = np.zeros(candidates.shape)
    miou_per_thres = np.zeros(candidates.shape)

    n_samples = 0.
    n_sam_dice = 0.
    # normal samples are ignored only with: 1. the cam16 dataset and 2. the
    # wsl method. Other methods drop them before training starts.
    # NOTE: while the assertion above is active, this is always False here.
    skip_normal = (args.dataset == constants.CAM16) & (
        args.al_type == constants.AL_WSL)

    t0 = dt.datetime.now()

    with torch.no_grad():
        for ids, data, mask, label, tag, crop_cord in tqdm.tqdm(
                dataloader, ncols=80, total=total_batches):

            reset_seed(int(os.environ["MYSEED"]))

            batch_sz = data.size()[0]

            data = data.to(device)
            labels = label.to(device)
            masks_trg = None if mask is None else mask.to(device)

            scores, masks_pred, _ = model(x=data, seed=None)

            # bring the prediction back to the target size when they differ.
            if masks_pred.shape != masks_trg.shape:
                _, _, oh, ow = masks_trg.shape
                masks_pred = dataset.turnback_mask_tensor_into_original_size(
                    pred_masks=masks_pred, oh=oh, ow=ow
                )
            # TODO: change normalization location. (future)
            masks_pred = torch.sigmoid(masks_pred)

            # when set, flags (with 1.) samples labeled 'normal' (0) so they
            # are left out of the Dice index.
            ignore_dice = (labels == 0).float().view(-1) if skip_normal \
                else None

            for pos, thres in enumerate(candidates):
                metrx = metrics(scores=scores,
                                labels=labels,
                                tr_loss=criterion,
                                masks_pred=masks_pred.view(batch_sz, -1),
                                masks_trg=masks_trg.view(batch_sz, -1),
                                avg=False,
                                threshold=thres,
                                ignore_dice=ignore_dice
                                )
                dice_per_thres[pos] += metrx[1].item()
                miou_per_thres[pos] += metrx[2].item()

            n_samples += batch_sz
            if skip_normal:
                n_sam_dice += ignore_dice.numel() - ignore_dice.sum()
            else:
                n_sam_dice += batch_sz

            del scores, masks_pred

    best_pos = miou_per_thres.argmax()
    best_threshold = candidates[best_pos]
    use_dice_count = skip_normal and (n_sam_dice != 0)
    denominator = float(n_sam_dice) if use_dice_count else float(n_samples)
    best_dice = dice_per_thres[best_pos] / denominator
    best_miou = miou_per_thres[best_pos] / denominator

    out = {
        "best_threshold": best_threshold,
        "best_dice": best_dice,
        "best_miou": best_miou
    }

    to_write = "VL [ESTIM-THRESH] {:>2d}-{:>2d}. BEST-DICE: {:.2f}, " \
               "BEST-MIOU: {:.2f}. " \
               "time:{}".format(cycle, epoch, best_dice * 100.,
                                best_miou * 100., dt.datetime.now() - t0)
    print(to_write)
    if log_file:
        log(log_file, to_write)

    to_write = "best threshold: {:.3f}. BEST-DICE-VL: {:.2f}%, " \
               "BEST-MIOU-VL: {:.2f}%.".format(
                   best_threshold, best_dice * 100., best_miou * 100.)
    announce_msg(to_write)
    if log_file:
        log(log_file, to_write)

    set_default_seed()

    return out
def validate(model,
             dataset,
             dataloader,
             criterion,
             device,
             stats,
             args,
             epoch=0,
             cycle=0,
             log_file=None
             ):
    """
    Evaluate the model over a set and record the averaged losses/metrics.

    For each batch, it sums the classification accuracy, the Dice index
    (computed with `args.seg_threshold`) and the criterion's losses. The
    losses are stored in order under "total_loss", "cl_loss", "seg_l_loss"
    and "seg_lp_loss". Each sum is then averaged over the number of samples.

    With `constants.AL_WSL`, the predicted mask is the set of pixels whose
    arg-max class in the maps equals the label, instead of
    `sigmoid(masks_pred)`. With the cam16 dataset + `AL_WSL`, samples labeled
    0 are ignored when averaging the Dice index.

    The dataset must be in `validation` mode. Task: SEG.

    :param stats: dict of lists or None. When a dict, the new averages are
        appended to its lists. Otherwise a new dict of one-element lists is
        created.
    :param epoch: int. only used in the logged message.
    :param cycle: int. only used in the logged message.
    :param log_file: log destination or None.
    :return: the updated (or newly created) `stats`.
    """
    set_default_seed()

    model.eval()
    criterion.eval()
    metrics = Metrics(threshold=args.seg_threshold).to(device)
    metrics.eval()

    total_batches = len(dataloader)
    # running sums, turned into averages after the loop.
    losses_kz = ["total_loss", "cl_loss", "seg_l_loss", "seg_lp_loss"]
    keys = ["acc", "dice_idx"] + losses_kz
    tracker = {key: 0. for key in keys}

    n_samples = 0.
    n_sam_dice = 0.

    is_wsl = args.al_type == constants.AL_WSL
    # normal samples are ignored only with: 1. the cam16 dataset and 2. the
    # wsl method. Other methods drop them before training starts.
    cnd_dice = (args.dataset == constants.CAM16) & is_wsl

    t0 = dt.datetime.now()

    with torch.no_grad():
        for ids, data, mask, label, tag, crop_cord in tqdm.tqdm(
                dataloader, ncols=80, total=total_batches):

            reset_seed(int(os.environ["MYSEED"]))

            batch_sz = data.size()[0]

            data = data.to(device)
            labels = label.to(device)
            masks_trg = None if mask is None else mask.to(device)

            scores, masks_pred, maps = model(x=data, seed=None)
            check_nans(maps, "fresh-maps")

            if is_wsl:
                # WSL: keep the pixels whose arg-max class in `maps` is the
                # image label.
                masks_pred = (maps.argmax(dim=1, keepdim=True) ==
                              labels.view(-1, 1, 1, 1)).float()

            # bring the prediction back to the target size when they differ.
            if masks_pred.shape != masks_trg.shape:
                _, _, oh, ow = masks_trg.shape
                masks_pred = dataset.turnback_mask_tensor_into_original_size(
                    pred_masks=masks_pred, oh=oh, ow=ow
                )
                check_nans(masks_pred, "mask-pred-back-to-normal-size")

            if not is_wsl:
                # TODO: change normalization location. (future)
                masks_pred = torch.sigmoid(masks_pred)  # binary segmentation

            losses = criterion(scores=scores,
                               labels=labels,
                               targets=None,
                               masks_pred=masks_pred.view(batch_sz, -1),
                               masks_trg=masks_trg.view(batch_sz, -1),
                               tags=tag,
                               weights=None,
                               avg=False
                               )

            # samples labeled 'normal' (0) are flagged so the Dice skips them.
            ignore_dice = (labels == 0).float().view(-1) if cnd_dice else None

            metrx = metrics(scores=scores,
                            labels=labels,
                            tr_loss=criterion,
                            masks_pred=masks_pred.view(batch_sz, -1),
                            masks_trg=masks_trg.view(batch_sz, -1),
                            avg=False,
                            ignore_dice=ignore_dice
                            )

            tracker["acc"] += metrx[0].item()
            tracker["dice_idx"] += metrx[1].item()
            for pos, loss_val in enumerate(losses):
                tracker[losses_kz[pos]] += loss_val.item()

            n_samples += batch_sz
            if cnd_dice:
                n_sam_dice += ignore_dice.numel() - ignore_dice.sum()
            else:
                n_sam_dice += batch_sz

            # release the large tensors before the next batch.
            del scores, masks_pred, maps

    # turn the sums into averages.
    for key in keys:
        if key != 'dice_idx' or not cnd_dice:
            tracker[key] /= float(n_samples)
            continue
        announce_msg("Ignore normal samples mode. "
                     "Total samples left: {}/Total: {}.".format(
                         n_sam_dice, n_samples))
        tracker[key] = (0. if n_sam_dice == 0
                        else tracker[key] / float(n_sam_dice))

    to_write = "VL {:>2d}-{:>2d}: ACC: {:.4f}, DICE: {:.4f}, time:{}".format(
        cycle, epoch, tracker["acc"] * 100., tracker["dice_idx"] * 100.,
        dt.datetime.now() - t0)
    print(to_write)
    if log_file:
        log(log_file, to_write)

    if stats is None:
        # first call: one-element list per key.
        stats = deepcopy({key: [tracker[key]] for key in keys})
    else:
        for key in keys:
            stats[key].append(tracker[key])

    set_default_seed()

    return stats
    def __call__(self, data, args, device, outd, label=None):
        """
        Compute pairwise proximities between the samples of `data`.

        In this code path, only the global soft-histogram proximity is
        computed; it requires `args.use_dist_global_hist`. The steps are:
        1. The per-channel histograms of every batch are computed and pickled
           into `outd/histj_<batch-index><tag>.pkl`.
        2. For each sample, its histogram is compared with `self.hist_prox`
           against all cached batch histograms. This gives a proximity tensor
           with one row per sample of `data` and one column per image
           channel, saved with `torch.save` into `outd/<id>.pt`.
        3. The cached histograms are deleted and a summary dict (`idss`,
           `labelss`, `acc`, `z`) is pickled into
           `outd/shared-stats<tag>.pkl`.

        `acc` is the percentage of samples whose nearest other sample has the
        same label. The nearest other sample is the second entry after an
        ascending sort of the summed proximities; the first is expected to be
        the sample itself. `z` is the column-wise sum of all proximity
        tensors.

        :param data: list of str-path to samples. They are loaded with a
            `PhotoDataset` whose tensor transform is `ToTensor`.
        :param args: object containing the main file input arguments.
        :param device: device on where the computation will take place.
        :param outd: path where we write the similarities (and temporary
            histogram files).
        :param label: str or None. Used only in the case of self.task is SEG:
            it is embedded in the tag of the written files so files of
            different labels do not mix.
        :return: 0 on the early-exit path, which is currently unreachable (see
            `already_done`).
        """
        # NOTE(review): `already_done` is hard-coded to False, so the early
        # return below is never taken.
        already_done = False
        if already_done:  # if we already computed this sim. nothing to do.
            return 0

        histc = None
        epsilon = 1e-8  # added to every histogram bin before normalization.
        if args.use_dist_global_hist:
            histc = SoftHistogram(bins=args.nbr_bins_histc,
                                  min=args.min_histc,
                                  max=args.max_histc,
                                  sigma=args.sigma_histc).to(device)
        # dataloader
        # NOTE(review): `transform_tensor` is not used below; the dataset is
        # built with `transforms.Compose([transforms.ToTensor()])` instead.
        transform_tensor = get_transforms_tensor(args)
        set_default_seed()
        dataset = PhotoDataset(
            data,
            args.dataset,
            args.name_classes,
            transforms.Compose([transforms.ToTensor()]),
            set_for_eval=False,
            transform_img=None,
            resize=None,
            crop_size=None,
            padding_size=(None, None),
            padding_mode=args.padding_mode,
            up_scale_small_dim_to=None,
            do_not_save_samples=True,
            ratio_scale_patch=1.,
            for_eval_flag=True,
            scale_algo=args.scale_algo,
            resize_h_to=args.resize_h_to_opt_mask,
            resize_mask=args.resize_mask_opt_mask,  # not important.
            enhance_color=args.enhance_color,
            enhance_color_fact=args.enhance_color_fact)
        set_default_seed()
        # shuffle=False: the rows of the proximities built below follow the
        # batch order and are indexed against `idss`/`labelss`, which are
        # filled in that same order.
        data_loader = DataLoader(dataset,
                                 batch_size=args.pair_w_batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 pin_memory=True,
                                 collate_fn=default_collate,
                                 worker_init_fn=_init_fn)
        set_default_seed()

        # created lazily in the first batch, once the number of channels is
        # known.
        gaussian_smoother = None

        # loop! on GPU.
        nbr_samples = len(dataset)
        nbr_batches = len(data_loader)
        # number of samples whose nearest other sample has the same label.
        acc_label_prop = 0.
        # column-wise sum of all proximity tensors.
        z = 0.
        # Pass 1: compute the histograms of every batch, store them on disc,
        # and collect the ids and labels of all samples.
        idss = []
        labelss = []
        list_projections = []  # unused in the current code.
        tag = ""
        # for the task SEG, the tag is helpful to avoid mixing files.
        if self.task == constants.SEG:
            tag = "_{}_{}".format(self.task, label)

        for j, (ids, imgs, masks, labels, tags, _) in enumerate(data_loader):
            with torch.no_grad():
                imgs = imgs.to(device)

                # 2. compute the histograms for matching.
                if args.use_dist_global_hist:
                    nbrs, c, h, w = imgs.shape

                    if args.smooth_img:
                        if gaussian_smoother is None:
                            gaussian_smoother = GaussianSmoothing(
                                channels=c,
                                kernel_size=args.smooth_img_ksz,
                                sigma=args.smooth_img_sigma,
                                dim=2,
                                exact_conv=True,
                                padding_mode='reflect').to(device)
                        # smooth the image.
                        imgs = gaussian_smoother(imgs)

                    re_imgs = imgs.view(nbrs * c, h * w)
                    hists_j = histc(re_imgs)  # nbrs * c, nbr_bins
                    # normalize to prob. dist
                    hists_j = hists_j + epsilon
                    hists_j = hists_j / hists_j.sum(dim=-1).unsqueeze(1)
                    hists_j = hists_j.view(nbrs, c, -1).cpu()

                    # cache on disc; reloaded by batch index in pass 2.
                    with open(join(outd, "histj_{}{}.pkl".format(j, tag)),
                              "wb") as fhist:
                        pkl.dump(hists_j, fhist, protocol=pkl.HIGHEST_PROTOCOL)

                # store some stuff.
                idss.extend(ids)
                labelss.extend(labels.numpy().tolist())
                # can't fit in memory
                # list_projections.append(copy.deepcopy(pr_j))

        # list_projections = [pr.to(device) for pr in list_projections]
        # compute sim.
        # Pass 2: for each sample, compare its histogram to every cached
        # batch of histograms.
        for i in tqdm.tqdm(range(nbr_samples), ncols=80, total=nbr_samples):
            id_sr, img_sr, mask_sr, label_sr, tag_sr, _ = dataset[i]
            # add the batch dim (and the channel dim for 2-D images).
            if img_sr.ndim == 2:
                img_sr = img_sr.view(1, 1, img_sr.size()[0], img_sr.size()[1])
            elif img_sr.ndim == 3:
                img_sr = img_sr.view(1,
                                     img_sr.size()[0],
                                     img_sr.size()[1],
                                     img_sr.size()[2])
            else:
                raise ValueError('Unexpected dim: {}.'.format(img_sr.ndim))

            img_sr = img_sr.to(device)
            # histo
            histo_trg = None
            if args.use_dist_global_hist:
                # NOTE(review): relies on `gaussian_smoother` having been
                # created in pass 1 under the same `args.smooth_img` condition.
                if args.smooth_img:
                    img_sr = gaussian_smoother(img_sr)

                nbrs, c, h, w = img_sr.shape  # only one image.-> nbrs=1
                histo_trg = histc(img_sr.view(nbrs * c, h * w))  # c,
                # nbrbins.
                # normalize to prob. dist
                histo_trg = histo_trg + epsilon
                histo_trg = histo_trg / histo_trg.sum(dim=-1).unsqueeze(1)

            # NOTE(review): `dists` is never assigned below, so it stays None.
            dists = None
            histo_prox = None
            for j in range(nbr_batches):
                with torch.no_grad():

                    # 2. histo proximity =======================================
                    if args.use_dist_global_hist:
                        with open(join(outd, "histj_{}{}.pkl".format(j, tag)),
                                  "rb") as fhisto:
                            hists_j = pkl.load(fhisto).to(device)  # bsize,
                            # c, nbrbins.

                        bs_sr, c_sr, nbr_bn_sr = hists_j.shape
                        # `repeat` tiles the (c, nbrbins) target histograms
                        # bs_sr times, so row k of both operands refers to the
                        # same channel (k % c).
                        tmp = self.hist_prox(
                            trg_his=histo_trg.repeat(bs_sr, 1),
                            src_his=hists_j.view(bs_sr * c_sr, -1))  # =>
                        # bs_sr * sr_c.
                        tmp = tmp.view(bs_sr, c_sr)

                        # tmp = self.sim(x, pr_j.to(device))
                        # NOTE(review): after the `view` above, `tmp` is 2-D,
                        # so this branch cannot trigger.
                        if tmp.ndim == 0:  # case of grey images with batch
                            # size of 1.
                            tmp = tmp.view(1, 1)

                        if histo_prox is None:
                            histo_prox = tmp
                        else:
                            histo_prox = torch.cat((histo_prox, tmp), dim=0)

            proximity = None
            if dists is not None:
                dists = dists.squeeze()  # remove the 1 dim. it happens when
                # batch_size == 1.
                dists = dists.cpu()
                proximity = dists.view(-1, 1)

            if histo_prox is not None:
                histo_prox = histo_prox.cpu()
            # shapes: dists: n. histo_prox: n, c where c is the number of
            # planes (channels) in the images.

            if args.use_dist_global_hist:
                # proximity = [l2 dist, r, g, b] or [l2 dist, grey]
                if proximity is not None:
                    proximity = torch.cat((proximity, histo_prox), dim=1)
                else:
                    proximity = histo_prox

            # NOTE(review): `proximity` is still None here when
            # `args.use_dist_global_hist` is False, which makes the next line
            # fail.
            z += proximity.sum(dim=0)
            # store sims.
            # ascending sort of the summed proximities; the sample itself is
            # expected to come first (value ~0).
            srt, idx = torch.sort(proximity.sum(dim=1).squeeze(),
                                  descending=False)

            msg = "ERROR: {}".format(proximity[idx[0]].sum())
            # floating point issue: 1.1920928955078125e-07.
            # assert proximity[idx[0]].sum() == 0., msg

            # NOTE(review): `idx[1]` requires at least two samples in `data`.
            label_pred = labelss[idx[1]]  # take the second because the first
            # is 0.
            # it is ok to overload the disc to avoid runtime cost.
            stats = {
                'id_sr': id_sr,  # id source
                'label_sr': label_sr,  # label source
                'label_pred': label_pred,
                'nearest_id': idss[idx[1]],  # closest sample.
                'proximity': proximity,
                'index_sort': idx  # so we do not have to sort again. [ok]
            }
            # NOTE(review): only `stats['nearest_id']` is read here, and
            # `stats` itself is not written to disc. The file saved below is
            # named `<id_sr>.pt`; it does not include the nearest id.
            id_nearest = stats['nearest_id']

            torch.save(proximity, join(outd, '{}.pt'.format(id_sr)))
            acc_label_prop += (label_sr == label_pred) * 1.

            if args.task == constants.SEG:
                msg = 'for weakly.sup.seg, all samples of the data provided' \
                      'to this function must have the same label. it does ' \
                      'not seem the case.W'
                assert label_sr == label_pred, msg

        # Cleaning: remove the histograms cached in pass 1.
        # NOTE(review): `path1` and `path2` are the same path, so the second
        # check is redundant.
        for j in range(nbr_batches):
            path1 = join(outd, "histj_{}{}.pkl".format(j, tag))
            path2 = join(outd, "histj_{}{}.pkl".format(j, tag))
            for path in [path1, path2]:
                if os.path.isfile(path):
                    os.remove(path)

        # store accuracy: the upper bound perf (when every sample is labeled
        # except one). this is useful only for classification task only.
        # NOTE(review): with an empty `data`, the `acc` division below fails
        # (and `z` would still be the float 0., which has no `.cpu()`).
        shared_stats = {
            'idss': idss,
            'labelss': labelss,
            'acc': 100. * acc_label_prop / nbr_samples,
            'z': z.cpu()
        }
        with open(join(outd, 'shared-stats{}.pkl'.format(tag)), 'wb') as fout:
            pkl.dump(shared_stats, fout, protocol=pkl.HIGHEST_PROTOCOL)

        if args.task == constants.SEG:
            msg = 'for weakly.sup.seg, accuracy is expected to be 100%. but' \
                  'found {}'.format(shared_stats['acc'])
            assert shared_stats['acc'] == 100., msg

        announce_msg('Upper bound classification accuracy: {}%'.format(
            shared_stats['acc']))
        announce_msg('Z: {}'.format(z))