def test__LBPModule():
    """
    Smoke test for `_LBPModule`.

    Builds the module with kernel sizes 3, 5 and 7, runs it once on a random
    batch and prints the shapes of its parameters (none are expected), the
    output shape, and the output range.
    :return:
    """
    reset_seed(0, check_cudnn=False)

    gpu_id = "0"
    device_name = "cuda:{}".format(gpu_id)
    print(device_name)
    if torch.cuda.is_available():
        device = torch.device(device_name)
    else:
        device = torch.device("cpu")

    # The module is built before the input is drawn, so the order in which
    # random numbers are consumed stays the same.
    module = _LBPModule(kernel_sizes=[3, 5, 7],
                        exact_conv=True,
                        normalize=True)
    module.to(device)

    batch_size, channels, height, width = 1, 5, 20, 30
    batch = torch.rand(batch_size, channels, height, width).to(device)
    result = module(batch)

    # expected: no parameters.
    for param in module.parameters():
        print("p: ", param.shape)
    print(result.shape)
    print(result.min(), result.max())
コード例 #2
0
    def forward(self, x, seed=None, prngs_cuda=None):
        """
        Pool each class map into one score from its highest and lowest
        activations.

        For every (sample, class) pair the spatial activations are sorted in
        decreasing order. The score is the mean of the `kmax` largest values.
        When `kmin > 0` and `self.alpha != 0`, `alpha` times the mean of the
        `kmin` smallest values is added and the sum is divided by 2.

        Input:
            x: torch tensor of size (n, c, h, w), where n is the batch
            size, c is the number of classes,
            h is the height of the feature map, w is its width.
            seed: seed for the thread to guarantee reproducibility over a
            fixed number of gpus. Only used when dropout is active. If it is
            not None, `prngs_cuda` must be provided as well.
            prngs_cuda: CUDA prng state to install while dropout is applied
            (required when `seed` is not None and dropout is active).
        Output:
            scores: torch tensor of size (n, c). Contains the pooled score of
            each class for each sample.
        Raises:
            AssertionError: if `kmax` resolves to 0, or if `seed` is given
            without `prngs_cuda` while dropout is active.
        """
        b, c, h, w = x.shape
        activations = x.view(b, c, h * w)

        n = h * w

        sorted_features = torch.sort(activations, dim=-1, descending=True)[0]
        kmax = self.get_k(self.kmax, n)
        kmin = self.get_k(self.kmin, n)

        # assert kmin != 0, "kmin=0"
        assert kmax != 0, "kmax=0"

        # dropout
        if self.dropout != 0.:
            if seed is not None:
                msg = "`prngs_cuda` is expected to not be None. Exiting " \
                      ".... [NOT OK]"
                # Validate before taking the lock so a failed check cannot
                # leave the lock held.
                assert prngs_cuda is not None, msg
                thread_lock.acquire()
                try:
                    # Save the current CUDA prng state, switch to the state
                    # dedicated to this thread, and restore afterwards.
                    prng_state = (torch.cuda.get_rng_state().cpu())
                    reproducibility.reset_seed(seed)
                    torch.cuda.set_rng_state(prngs_cuda.cpu())

                    # instruction that causes randomness.
                    sorted_features = self.dropout_md(sorted_features)

                    reproducibility.reset_seed(seed)
                    torch.cuda.set_rng_state(prng_state)
                finally:
                    # Always release, even if something above raised.
                    thread_lock.release()
            else:
                sorted_features = self.dropout_md(sorted_features)

        # Mean of the kmax highest activations.
        scores = sorted_features.narrow(-1, 0, kmax).sum(-1).div_(kmax)

        if kmin > 0 and self.alpha != 0.:
            # In-place `add_` is required: the out-of-place `add` returns a
            # new tensor, so the kmin term (and the division by 2) would be
            # silently dropped from the returned scores.
            scores.add_(
                sorted_features.narrow(-1, n - kmin, kmin).sum(-1).mul_(
                    self.alpha / kmin)).div_(2.)

        return scores
def test__HistogramOfGradientMagnitudesProp():
    """
    Smoke test for `_HistogramOfGradientMagnitudesProp`.

    Runs the module once on a random tensor (no grey conversion) and once on
    an image loaded from disk (with grey conversion). Prints shape, dtype,
    device and range of each output, then saves a bar plot of the histograms
    obtained on the image.
    """
    reset_seed(0)
    gpu_id = "0"
    device_name = "cuda:{}".format(gpu_id)
    print(device_name)
    use_gpu = torch.cuda.is_available()
    device = torch.device(device_name if use_gpu else "cpu")

    bins = 256
    sigma = 1e5

    # Part 1: random input.
    n_channels, height, width = 4, 10, 20
    module = _HistogramOfGradientMagnitudesProp(c=n_channels,
                                                bins=bins,
                                                sigma=sigma,
                                                convert_to_grey=False)
    module.to(device)

    rand_input = torch.rand((n_channels, height, width)).to(device)
    rand_mask = torch.rand((1, height, width)).to(device)

    hist = module(rand_input, mask=rand_mask)
    print(hist.shape, hist.dtype, hist.device)
    print(hist.min(), hist.max())

    # Part 2: test on an image.
    root = '../data/debug/input'
    file = 'Black_Footed_Albatross_0006_796065.jpg'
    pil_img = Image.open(join(root, file)).convert('RGB')
    to_tensor = transforms.Compose([transforms.ToTensor()])
    img = to_tensor(pil_img).to(device)
    n_channels, height, width = img.shape
    module = _HistogramOfGradientMagnitudesProp(c=n_channels,
                                                bins=bins,
                                                sigma=sigma,
                                                convert_to_grey=True)
    module.to(device)
    img_mask = torch.rand((1, height, width)).to(device)
    print('test on an image of shape {}.'.format(img.shape))
    hist = module(img, mask=img_mask)
    print(hist.shape, hist.dtype, hist.device)
    print(hist.min(), hist.max())

    # One bar series per histogram held by the module.
    fig = plt.figure()
    positions = list(range(256))
    for idx in range(module.c):
        heights = hist[idx].cpu().squeeze().numpy()
        plt.bar(positions, heights, label="{}".format(idx))

    out_name = "hgm-c={}-{}".format(module.c, file)
    fig.savefig(join("../data/debug/visualization/", out_name))
def test__SobelFilter2D():
    """
    Smoke test for `_SobelFilter2D`.

    Applies a 3-channel filter to a random batch and prints the output and
    input shapes, followed by the shape and `requires_grad` flag of every
    parameter of the module.
    :return:
    """
    reset_seed(0, check_cudnn=False)

    gpu_id = "0"
    device_name = "cuda:{}".format(gpu_id)
    print(device_name)
    if torch.cuda.is_available():
        device = torch.device(device_name)
    else:
        device = torch.device("cpu")

    edge_filter = _SobelFilter2D(channels=3, exact_conv=True)
    edge_filter.to(device)

    batch = torch.rand(1, 3, 100, 100).to(device)
    filtered = edge_filter(batch)
    print(filtered.shape, batch.shape)

    for param in edge_filter.parameters():
        print(param.shape, param.requires_grad)
def test_SoftHistogram():
    """
    Smoke test for `SoftHistogram`.

    Computes the soft histogram of a small random batch and prints the summed
    absolute difference, per sample, with respect to `torch.histc` computed
    over the same range and number of bins.
    """
    reset_seed(0)
    gpu_id = "0"
    device_name = "cuda:{}".format(gpu_id)
    print(device_name)
    use_gpu = torch.cuda.is_available()
    device = torch.device(device_name if use_gpu else "cpu")

    bins = 256
    # Histogram range (named so as not to hide the builtins min/max).
    lower, upper = 0., 1.
    sigma = 1e5
    module = SoftHistogram(bins=bins, min=lower, max=upper, sigma=sigma)
    module.to(device)

    batch_sz = 3
    x = torch.rand((batch_sz, 5)).to(device)
    print(x[0])

    mask = torch.rand((batch_sz, 5)).to(device)
    print(mask[0])
    # NOTE(review): `x` is passed as the mask while the random `mask` built
    # above is only printed -- confirm whether `mask=mask` was intended.
    output = module(x, mask=x)
    print(output.shape, output.dtype, output.device)
    print(output.shape)

    # Reference: one hard histogram per sample, stacked along dim 0.
    reference_rows = [
        torch.histc(x[i, :], bins=bins, min=lower, max=upper).view(1, -1)
        for i in range(batch_sz)
    ]
    better_hist = torch.cat(reference_rows, dim=0)
    print(better_hist.shape)
    print("error: {}".format(torch.abs(output - better_hist).sum(dim=1)))
def test_GaussianSmoothing():
    """
    Smoke test for `GaussianSmoothing`.

    Smooths a random 3-channel batch with a 3x3 kernel (sigma=1) and prints
    the output and input shapes, followed by the shape and `requires_grad`
    flag of every parameter of the module.
    :return:
    """
    reset_seed(0, check_cudnn=False)

    gpu_id = "0"
    device_name = "cuda:{}".format(gpu_id)
    print(device_name)
    if torch.cuda.is_available():
        device = torch.device(device_name)
    else:
        device = torch.device("cpu")

    blur = GaussianSmoothing(channels=3,
                             kernel_size=3,
                             sigma=1,
                             exact_conv=True)
    blur.to(device)

    batch = torch.rand(1, 3, 100, 100).to(device)
    blurred = blur(batch)
    print(blurred.shape, batch.shape)

    for param in blur.parameters():
        print(param.shape, param.requires_grad)
def test__HistogramProp():
    """
    Smoke test for `_HistogramProp`.

    Runs the module on a random (c, h, w) tensor with a random (1, h, w) mask
    and prints the shape, dtype and device of the output.
    """
    reset_seed(0)
    gpu_id = "0"
    device_name = "cuda:{}".format(gpu_id)
    print(device_name)
    use_gpu = torch.cuda.is_available()
    device = torch.device(device_name if use_gpu else "cpu")

    bins = 256
    # Histogram range (named so as not to hide the builtins min/max).
    lower, upper = 0., 1.
    sigma = 1e5
    module = _HistogramProp(bins=bins, min=lower, max=upper, sigma=sigma)
    module.to(device)

    n_channels, height, width = 4, 10, 20
    features = torch.rand((n_channels, height, width)).to(device)
    region = torch.rand((1, height, width)).to(device)

    result = module(features, mask=region)
    print(result.shape, result.dtype, result.device)
def test__HighOrderMomentsProp():
    """
    Smoke test for `_HighOrderMomentsProp`.

    Evaluates the second-order moment on a random (c, h, w) tensor with a
    random mask of the same shape, and prints the forward time together with
    the shape, dtype, device and values of the output.
    """
    reset_seed(0, check_cudnn=False)
    gpu_id = "0"
    device_name = "cuda:{}".format(gpu_id)
    print(device_name)
    if torch.cuda.is_available():
        DEVICE = torch.device(device_name)
    else:
        DEVICE = torch.device("cpu")

    n_channels, height, width = 3, 20, 30
    # Draw the input first and the mask second (same random-number order).
    x = torch.rand(n_channels, height, width).to(DEVICE)
    mask = torch.rand(n_channels, height, width).to(DEVICE)

    moments = _HighOrderMomentsProp(min_order_moment=2,
                                    max_order_moment=2,
                                    normalize_area_msk=True)

    started = dt.datetime.now()
    output = moments(x=x, mask=mask)
    elapsed = dt.datetime.now() - started
    print("forward time: {}".format(elapsed))
    print(output.shape, output.dtype, output.device)
    print(output)
    announce_msg("start training")
    tx0 = dt.datetime.now()

    set_default_seed()

    epoch = 0
    cycle = 0
    while epoch < args.max_epochs:
        # debug.

        # if (epoch == 0) and (cycle == 0):
        #     epoch = args.max_epochs - 1

        # reseeding tr/vl samples.
        reset_seed(seedj(epoch, 1, cycle, CONST1))
        trainset.set_up_new_seeds()
        reset_seed(seedj(epoch, 2, cycle, CONST1))

        tr_stats= train_one_epoch(model,
                                  optimizer,
                                  train_loader,
                                  CRITERION,
                                  DEVICE,
                                  tr_stats,
                                  args,
                                  trainset,
                                  epoch,
                                  cycle,
                                  training_log,
                                  ALLOW_MULTIGPUS=ALLOW_MULTIGPUS,
コード例 #10
0
        """
        super(ClassWisePooling, self).__init__()

        self.C = classes
        self.M = modalities

    def forward(self, inputs):
        """
        Average the modality maps that belong to each class.

        :param inputs: tensor of size (N, C * M, H, W), where the channels
        are viewed as C classes with M modalities each.
        :return: tensor of size (N, C, H, W): for each class, the mean over
        its M modality maps.
        """
        batch, channels, height, width = inputs.size()
        expected = self.C * self.M
        assert channels == expected, (
            'Wrong number of channels, expected {} '
            'channels but got {}'.format(expected, channels))

        # Expose the modality axis, reduce over it, restore the spatial dims.
        per_modality = inputs.view(batch, self.C, self.M, -1)
        pooled = per_modality.mean(dim=2)
        return pooled.view(batch, self.C, height, width)

    def __str__(self):
        """
        Return the class name followed by the number of classes and
        modalities.
        """
        class_name = self.__class__.__name__
        settings = '(classes={}, modalities={})'.format(self.C, self.M)
        return class_name + settings

    def __repr__(self):
        """
        Return the representation provided by the parent class.
        """
        parent = super(ClassWisePooling, self)
        return parent.__repr__()


if __name__ == "__main__":
    # Quick manual check: run every pooling module on one random batch and
    # print its class name, output size, and output values.
    batch_size, n_classes = 10, 2
    reproducibility.reset_seed(0)
    modules = [WildCatPoolDecision(dropout=0.5)]
    sample = torch.randn(batch_size, n_classes, 12, 12)
    for module in modules:
        result = module(sample)
        print(module.__class__.__name__, '->', result.size(), result)
コード例 #11
0
    def __getitem__(self, index):
        """
        Return one sample, its mask, and its label.

        When `self.set_for_eval` is true, the entries stored in
        `self.inputs_ready`, `self.masks_ready` and `self.labels_ready` are
        returned as they are. Otherwise the sample is loaded (or taken from
        the preloaded lists), optionally upscaled, padded and randomly
        cropped, rescaled, and passed through `self.transform_img` and
        `self.transform_tensor` when these are set.

        :param index: int, the index of the sample within the whole dataset.
        :return: (img, mask, target):
                 img: the sample after the steps above. Presumably a
                 torch.FloatTensor of size (C, H, W) once the transforms
                 have been applied, where C is the number of color channels,
                 H the height and W the width -- TODO confirm against the
                 transforms in use.
                 mask: the mask of the regions of interest, converted to a
                 float32 array divided by 255 with a trailing axis added, and
                 passed through `self.to_tensor`; or None if there is no
                 ground truth mask.
                 target: the label of the sample.
        """
        # Force seeding: a workaround to keep results reproducible when using
        # a different number of workers. Each sample has its own seed.
        reproducibility.reset_seed(self.seeds[index])

        if self.set_for_eval:
            error_msg = "Something wrong. You didn't ask to set the data " \
                        "ready for evaluation, but here we are " \
                        ".... [NOT OK]"
            # NOTE(review): `self.masks_ready` is indexed below but is not
            # covered by this check -- confirm it is always set together with
            # the two other lists.
            assert self.inputs_ready is not None and \
                self.labels_ready is not None, error_msg
            img = self.inputs_ready[index]
            mask = self.masks_ready[index]
            target = self.labels_ready[index]

            return img, mask, target

        if self.do_not_save_samples:
            img, mask, target = self.load_sample_i(index)
        else:
            assert self.preloaded, "Sorry, you need to preload the data " \
                                   "first .... [NOT OK]"
            img, mask, target = self.images[index], self.masks[index], \
                self.labels[index]

        # Upscale on the fly. This may add extra time, but we do not want to
        # keep upscaled images in memory: it takes a lot of space, especially
        # for large datasets. As a compromise, upscale only when necessary.
        # Useful for Caltech-UCSD-Birds-200-2011 for instance.
        # Only the image is resized here; the mask is left as it is.
        if self.up_scale_small_dim_to is not None:
            w, h = img.size
            w_up, h_up = self.get_upscaled_dims(w, h,
                                                self.up_scale_small_dim_to)
            # make a resized copy.
            img = img.resize((w_up, h_up), resample=PIL.Image.BILINEAR)

        # (The upscaling above is meant only for Caltech-UCSD-Birds-200-2011
        # and similar datasets.)

        # crop a patch (training only). Do not crop for evaluation.
        if self.randomCropper:
            msg = "Something's wrong. This is expected to be False." \
                  "We do not crop for evaluation."
            assert not self.for_eval_flag, msg
            # Padding: `self.padding_size` holds two ratios (ph, pw) that are
            # multiplied by the image height and width respectively.
            if self.padding_size:
                w, h = img.size
                ph, pw = self.padding_size
                padding = (int(pw * w), int(ph * h))
                img = TF.pad(img,
                             padding=padding,
                             padding_mode=self.padding_mode)
                # just for tracking.
                if mask is not None:
                    mask = TF.pad(mask,
                                  padding=padding,
                                  padding_mode=self.padding_mode)

            # The cropper also returns the crop coordinates so that the same
            # crop can be applied to the mask.
            img, (i, j, h, w) = self.randomCropper(img)
            # print("Dadaloader Index {} i  {}  j {} seed {}".format(
            # index, i, j, self.seeds[index]))
            # crop the mask just for tracking. Not used for actual training.
            if mask is not None:
                mask = TF.crop(mask, i, j, h, w)

            if self.ratio_scale_patch < 1.:
                img = img.resize((int(img.size[0] * self.ratio_scale_patch),
                                  int(img.size[1] * self.ratio_scale_patch)))

        # rescale the image with the same ratio that we use to rescale the
        # cropped patches.
        if self.for_eval_flag and (self.ratio_scale_patch < 1.):
            img = img.resize((int(img.size[0] * self.ratio_scale_patch),
                              int(img.size[1] * self.ratio_scale_patch)))

        # just for training: do not transform the mask (since it is not used).
        if self.transform_img:
            img = self.transform_img(img)

        # same here: only the image goes through the tensor transform.
        if self.transform_tensor:
            img = self.transform_tensor(img)

        # Prepare the mask to be used on GPU to compute Dice index.
        if mask is not None:
            # Original comment: "full of 0 and 1" -- presumably the stored
            # mask holds only 0 and 255; verify against the dataset.
            mask = np.array(mask, dtype=np.float32) / 255.  # full of 0 and 1.
            # make the mask with shape (h, w, 1):
            mask = self.to_tensor(np.expand_dims(mask, axis=-1))

        return img, mask, target
def train_one_epoch(model,
                    optimizer,
                    dataloader,
                    criterion,
                    device,
                    tr_stats,
                    args,
                    dataset,
                    epoch=0,
                    cycle=0,
                    log_file=None,
                    ALLOW_MULTIGPUS=False,
                    NBRGPUS=1
                    ):
    """
    Perform one epoch of training.

    The random generators are re-seeded around every step (forward, loss,
    backward, optimizer step) so the statement order in this function matters
    for reproducibility.

    :param model: instance of a model. Called as
    `model(x=..., seed=..., prngs_cuda=...)` and expected to return
    (scores, masks_pred, maps).
    :param optimizer: instance of an optimizer.
    :param dataloader: iterable of batches; each batch is unpacked as
    (ids, data, mask, label, tag, crop_cord).
    :param criterion: instance of a learning criterion. Returns a sequence
    of losses whose first element is the total loss.
    :param device: a device.
    :param tr_stats: dict that holds the states of the training. or
    None.
    :param args: args of the main.py
    :param dataset: the dataset.
    :param epoch: int, the current epoch.
    :param cycle: int, the current cycle.
    :param log_file: a logfile.
    :param ALLOW_MULTIGPUS: bool. If True, we are in multiGPU mode.
    :param NBRGPUS: int, number of GPUs.
    :return: tr_stats: dict with the keys "acc", "dice_idx", "total_loss",
    "cl_loss", "seg_l_loss", "seg_lp_loss". Each value is a list to which
    the statistic of this epoch has been appended (a new dict is created when
    `tr_stats` is None).
    """
    reset_seed(seedj(epoch, 3, cycle, CONST1))

    model.train()
    criterion.train()
    metrics = Metrics(threshold=args.seg_threshold).to(device)
    metrics.eval()

    length = len(dataloader)
    t0 = dt.datetime.now()
    # to track stats.
    keys = ["acc", "dice_idx"]
    losses_kz = ["total_loss", "cl_loss", "seg_l_loss", "seg_lp_loss"]
    keys = keys + losses_kz
    tracker = dict()
    for k in keys:
        tracker[k] = 0.

    n_samples = 0.
    n_sam_dice = 0.
    # ignoring samples is on only: 1. cam16 dataset. 2. wsl method.
    # for the rest of the methods, normal samples are completely dropped before
    # starting training.
    cnd_dice = (args.dataset == constants.CAM16)
    cnd_dice &= (args.al_type == constants.AL_WSL)

    for i, (ids, data, mask, label, tag, crop_cord) in tqdm.tqdm(
            enumerate(dataloader), ncols=80, total=length):
        # Base seed of this iteration: depends on the environment variable
        # MYSEED, the epoch and the cycle. The batch index is added below.
        seedx = int(os.environ["MYSEED"]) + epoch + (cycle + 1) * 10
        reset_seed(seedx)

        targets = None
        masks_pred, masks_trg = None, None
        if mask is not None:
            masks_trg = mask.to(device)

        if (args.al_type == constants.AL_LP) and (args.task == constants.CL):
            targets = build_target_cl(args=args, labels=label, tags=tag,
                                      dataset=dataset, device=device)
            reset_seed(seedx)

        seedx += i

        data = data.to(device)
        labels = label.to(device)

        if targets is not None:
            targets = targets.to(device)

        bsz = data.size()[0]

        model.zero_grad()
        prngs_cuda = None

        # Optimization:
        if not ALLOW_MULTIGPUS:
            if "CC_CLUSTER" in os.environ.keys():
                msg = "Something wrong. You deactivated multigpu mode, " \
                      "but we find {} GPUs. This will not guarantee " \
                      "reproducibility. We do not know why you did that. " \
                      "Exiting ... [NOT OK]".format(NBRGPUS)
                assert NBRGPUS <= 1, msg
            seeds_threads = None
        else:
            msg = "Something is wrong. You asked for multigpu mode. But, " \
                  "we found {} GPUs. Exiting .... [NOT OK]".format(NBRGPUS)
            assert NBRGPUS > 1, msg
            # The seeds are generated randomly before calling the threads.
            reset_seed(seedx)
            seeds_threads = torch.randint(
                0, np.iinfo(np.uint32).max + 1, (NBRGPUS, )).to(device)
            reset_seed(seedx)
            prngs_cuda = []
            # Create different prng states of cuda before forking.
            for seed in seeds_threads:
                # get the corresponding state of the cuda prng with respect to
                # the seed.
                inter_seed = seed.cpu().item()
                # change the internal state of the prng to a random one using
                # the random seed so to capture it.
                torch.manual_seed(inter_seed)
                torch.cuda.manual_seed(inter_seed)
                # capture the prng state.
                prngs_cuda.append(torch.cuda.get_rng_state())
            reset_seed(seedx)

        # One row per captured prng state.
        if prngs_cuda is not None and prngs_cuda != []:
            prngs_cuda = torch.stack(prngs_cuda)

        reset_seed(seedx)
        scores, masks_pred, maps = model(x=data,
                                         seed=seeds_threads,
                                         prngs_cuda=prngs_cuda
                                        )
        reset_seed(seedx)

        # change to the predicted mask to be the maps if WSL.
        if args.al_type == constants.AL_WSL:
            lb_ = labels.view(-1, 1, 1, 1)
            masks_pred = (maps.argmax(dim=1, keepdim=True) == lb_).float()
            # masks_pred contains binary values with float format.
            # shape (bsz, 1, h, w)
        else:
            # TODO: change normalization location. (future)
            masks_pred = torch.sigmoid(masks_pred)  # binary segmentation.


        # NOTE(review): `masks_trg` stays None when `mask` is None, in which
        # case `masks_trg.view` below raises -- confirm the dataloader always
        # provides a mask.
        losses = criterion(scores=scores,
                           labels=labels,
                           targets=targets,
                           masks_pred=masks_pred.view(bsz, -1),
                           masks_trg=masks_trg.view(bsz, -1),
                           tags=tag,
                           weights=None,
                           avg=True
                           )
        reset_seed(seedx)

        losses[0].backward()  # total loss.
        reset_seed(seedx)
        # Update params.
        optimizer.step()
        reset_seed(seedx)
        # End optimization.

        # metrics

        ignore_dice = None
        if cnd_dice:
            # ignore samples with label 'normal' (0) when computing dice.
            ignore_dice = (labels == 0).float().view(-1)


        # avg=False: the values accumulated below are divided by the sample
        # counts after the loop.
        metrx = metrics(scores=scores,
                        labels=labels,
                        tr_loss=criterion,
                        masks_pred=masks_pred.view(bsz, -1),
                        masks_trg=masks_trg.view(bsz, -1),
                        avg=False,
                        ignore_dice=ignore_dice
                        )

        # Update the tracker.
        tracker["acc"] += metrx[0].item()
        tracker["dice_idx"] += metrx[1].item()

        # The losses are expected in the same order as `losses_kz`.
        for j, los in enumerate(losses):
            tracker[losses_kz[j]] += los.item()

        n_samples += bsz
        if cnd_dice:
            # number of samples that are NOT ignored in this batch.
            n_sam_dice += ignore_dice.numel() - ignore_dice.sum()
        else:
            n_sam_dice += bsz

        # clear the memory
        del scores
        del masks_pred
        del maps

    # average: losses over the number of batches, dice over the samples kept
    # for dice (when in ignore mode), everything else over all the samples.
    for k in keys:
        if k in losses_kz:
            tracker[k] /= float(length)
        elif (k == 'dice_idx') and cnd_dice:
            announce_msg("Ignore normal samples mode. "
                         "Total samples left: {}/Total: {}.".format(
                n_sam_dice, n_samples))
            if n_sam_dice == 0:
                tracker[k] = 0.
            else:
                tracker[k] /= float(n_sam_dice)
        else:
            tracker[k] /= float(n_samples)

    # The first field printed is the cycle, the second one is the epoch.
    to_write = "Tr.Ep {:>2d}-{:>2d}: ACC: {:.2f}%, DICE: {:.2f}%, LR: {}, " \
               "time:{}".format(
        cycle,
        epoch,
        tracker["acc"] * 100.,
        tracker["dice_idx"] * 100.,
        ['{:.2e}'.format(group["lr"]) for group in optimizer.param_groups],
        dt.datetime.now() - t0
    )
    print(to_write)
    if log_file:
        log(log_file, to_write)

    # Update stats:
    if tr_stats is not None:
        for k in keys:
            tr_stats[k].append(tracker[k])
    else:
        # convert each element into a list
        for k in keys:
            tracker[k] = [tracker[k]]
        tr_stats = deepcopy(tracker)

    reset_seed(seedj(epoch, 4, cycle, CONST1))

    return tr_stats
def train_one_epoch(model, optimizer, dataloader, criterion, device, tr_stats,
                    epoch=0, log_file=None, ALLOW_MULTIGPUS=False, NBRGPUS=1):
    """
    Perform one epoch of training.

    The random generators are re-seeded around every step (forward, loss,
    backward, optimizer step) so the statement order in this function matters
    for reproducibility.

    :param model: instance of a model. Called as
    `model(x=..., seed=..., prngs_cuda=...)`; the first of its two outputs is
    used as the scores.
    :param optimizer: instance of an optimizer.
    :param dataloader: instance of a dataloader. Each batch is unpacked as
    (data, masks, labels).
    :param criterion: instance of a learning criterion. Called as
    `criterion(scores, labels)`.
    :param device: a device.
    :param tr_stats: numpy matrix that holds the states of the training. or
    None.
    :param epoch: int, the current epoch.
    :param log_file: a logfile.
    :param ALLOW_MULTIGPUS: bool. If True, we are in multiGPU mode.
    :param NBRGPUS: int, number of GPUs.
    :return: tr_stats: numpy matrix with one row per batch and 5 columns
    (acc, mae, soi_y, soi_py, loss). The rows of this epoch are stacked under
    the rows of the given `tr_stats` when it is not None.
    """
    model.train()
    metrics = Metrics().to(device)

    length = len(dataloader)
    t0 = dt.datetime.now()
    # acc, mae, soi_y, soi_py, loss
    tracker = np.zeros((length, 5), dtype=np.float32)

    for i, (data, masks, labels) in tqdm.tqdm(
            enumerate(dataloader), ncols=80, total=length):
        # Base seed of this iteration: depends on the environment variable
        # MYSEED and the epoch. The batch index is added right after.
        seedx = int(os.environ["MYSEED"]) + epoch
        reset_seed(seedx)
        seedx += i

        data = data.to(device)
        labels = labels.to(device)

        model.zero_grad()
        prngs_cuda = None

        # Optimization:
        if not ALLOW_MULTIGPUS:
            if "CC_CLUSTER" in os.environ.keys():
                msg = "Something wrong. You deactivated multigpu mode, " \
                      "but we find {} GPUs. This will not guarantee " \
                      "reproducibility. We do not know why you did that. " \
                      "Exiting .... [NOT OK]".format(NBRGPUS)
                assert NBRGPUS <= 1, msg
            seeds_threads = None
        else:
            msg = "Something is wrong. You asked for multigpu mode. But, " \
                  "we found {} GPUs. Exiting .... [NOT OK]".format(NBRGPUS)
            assert NBRGPUS > 1, msg
            # The seeds are generated randomly before calling the threads.
            reset_seed(seedx)
            seeds_threads = torch.randint(
                0, np.iinfo(np.uint32).max + 1, (NBRGPUS, )).to(device)
            reset_seed(seedx)
            prngs_cuda = []
            # Create different prng states of cuda before forking.
            for seed in seeds_threads:
                # get the corresponding state of the cuda prng with respect to
                # the seed.
                inter_seed = seed.cpu().item()
                # change the internal state of the prng to a random one using
                # the random seed so to capture it.
                torch.manual_seed(inter_seed)
                torch.cuda.manual_seed(inter_seed)
                # capture the prng state.
                prngs_cuda.append(torch.cuda.get_rng_state())
            reset_seed(seedx)

        # One row per captured prng state.
        if prngs_cuda is not None and prngs_cuda != []:
            prngs_cuda = torch.stack(prngs_cuda)

        reset_seed(seedx)
        scores, _ = model(x=data, seed=seeds_threads, prngs_cuda=prngs_cuda)
        reset_seed(seedx)
        loss = criterion(scores, labels)
        reset_seed(seedx)
        loss.backward()
        reset_seed(seedx)
        # Update params.
        optimizer.step()
        reset_seed(seedx)
        # End optimization.
        # The loss goes in the last column of the row of this batch.
        tracker[i, -1] = loss.item()

        # metrics: fill the remaining (first four) columns.
        batch_metrics = metrics(
            scores=scores, labels=labels, tr_loss=criterion,
            avg=True).cpu().numpy()
        tracker[i, :-1] = batch_metrics

    # NOTE(review): `criterion.lossCT` is accessed without a guard, so a
    # criterion that has no `lossCT` attribute raises AttributeError here --
    # confirm every criterion passed in provides it.
    t_lb = 0.
    if hasattr(criterion.lossCT, "t_lb"):
        t_lb = criterion.lossCT.t_lb.item()  # assume gpu.

    # Report the mean of every column over the batches of this epoch.
    to_write = "Tr.Ep {:>2d}: ACC: {:.4f}, MAE: {:.4f}, SOI_Y: {:.4f}, " \
               "SOI_PY: {:.4f}, Loss: {:.4f}, LR: {}, t: {:.4f}, " \
               "time:{}".format(
                epoch, tracker[:, 0].mean(), tracker[:, 1].mean(),
                tracker[:, 2].mean(), tracker[:, 3].mean(),
                tracker[:, 4].mean(),
                ['{:.2e}'.format(group["lr"]) for group in
                 optimizer.param_groups], t_lb,
                dt.datetime.now() - t0
                )
    print(to_write)
    if log_file:
        log(log_file, to_write)

    # Update stats:
    if tr_stats is not None:
        tr_stats = np.vstack([tr_stats, tracker])
    else:
        tr_stats = copy.deepcopy(tracker)
    return tr_stats
def estimate_best_seg_thres(model,
                            dataset,
                            dataloader,
                            criterion,
                            device,
                            args,
                            epoch=0,
                            cycle=0,
                            log_file=None
                            ):
    """
    Perform a validation over a set.
    Allows estimating a segmentation threshold based on this evaluation.

    This is intended for the `validationset` where we can estimate the best
    threshold since the samples are pixel-wise labeled.

    We specify a set of thresholds (0.05 to 0.99 with a step of 0.01), and
    pick the one that has the best mean IOU score. MIOU is better than Dice
    index.

    Dataset passed here must be in `validation` mode.
    Task: SEG.

    :param model: instance of a model. Called as `model(x=..., seed=None)`
    and expected to return (scores, masks_pred, maps).
    :param dataset: the dataset. Used to resize the predicted masks to the
    size of the target masks when the shapes differ.
    :param dataloader: iterable of batches; each batch is unpacked as
    (ids, data, mask, label, tag, crop_cord).
    :param criterion: instance of a learning criterion (forwarded to the
    metrics as `tr_loss`).
    :param device: a device.
    :param args: args of the main.py
    :param epoch: int, the current epoch (used for logging only).
    :param cycle: int, the current cycle (used for logging only).
    :param log_file: a logfile.
    :return: dict with the keys "best_threshold", "best_dice", "best_miou".
    :raises AssertionError: if `args.al_type` is `constants.AL_WSL`.
    """
    msg = "Can't use/estimate threshold over AL_WSL. masks are already binary."
    assert args.al_type != constants.AL_WSL, msg

    set_default_seed()

    announce_msg("Estimating threshold on validation set.")
    set_default_seed()

    model.eval()
    # the specified threshold here does not matter. we will use different
    # ones later.
    metrics = Metrics(threshold=args.seg_threshold).to(device)
    metrics.eval()

    length = len(dataloader)
    l_thress = np.arange(start=0.05, stop=1., step=0.01)
    # to track stats: one accumulator per candidate threshold.
    avg_dice = np.zeros(l_thress.shape)
    avg_miou = np.zeros(l_thress.shape)

    nbr_ths = l_thress.size

    n_samples = 0.
    n_sam_dice = 0.
    # ignoring samples is on only: 1. cam16 dataset. 2. wsl method.
    # for the rest of the methods, normal samples are completely dropped before
    # starting training.
    # NOTE(review): the assertion at the top rejects AL_WSL, so `cnd_dice`
    # looks like it can never be true here (unless asserts are disabled) --
    # confirm whether the ignore mode below is still needed.
    cnd_dice = (args.dataset == constants.CAM16)
    cnd_dice &= (args.al_type == constants.AL_WSL)

    t0 = dt.datetime.now()

    with torch.no_grad():
        for i, (ids, data, mask, label, tag, crop_cord) in tqdm.tqdm(
                enumerate(dataloader), ncols=80, total=length):

            reset_seed(int(os.environ["MYSEED"]))

            targets = None  # not needed for SEG task.
            masks_trg = None
            bsz = data.size()[0]
            # assert bsz == 1, "batch size must be 1. found {}. ".format(bsz)

            data = data.to(device)
            labels = label.to(device)
            if mask is not None:
                masks_trg = mask.to(device)

            scores, masks_pred, maps = model(x=data, seed=None)

            # resize pred mask if sizes mismatch.
            # NOTE(review): `masks_trg` stays None when `mask` is None, in
            # which case `masks_trg.shape` raises -- confirm the dataloader
            # always provides a mask.
            if masks_pred.shape != masks_trg.shape:
                _, _, oh, ow = masks_trg.shape
                masks_pred = dataset.turnback_mask_tensor_into_original_size(
                    pred_masks=masks_pred, oh=oh, ow=ow
                )
            # TODO: change normalization location. (future)
            masks_pred = torch.sigmoid(masks_pred)

            ignore_dice = None
            if cnd_dice:
                # ignore samples with label 'normal' (0) when computing dice.
                ignore_dice = (labels == 0).float().view(-1)

            # Evaluate the same predictions under every candidate threshold.
            # avg=False: the values are divided by the sample counts after
            # the loop.
            for ii, tt in enumerate(l_thress):
                metrx = metrics(scores=scores,
                                labels=labels,
                                tr_loss=criterion,
                                masks_pred=masks_pred.view(bsz, -1),
                                masks_trg=masks_trg.view(bsz, -1),
                                avg=False,
                                threshold=tt,
                                ignore_dice=ignore_dice
                                )

                avg_dice[ii] += metrx[1].item()
                avg_miou[ii] += metrx[2].item()

            n_samples += bsz  # nbr samples.

            if cnd_dice:
                # number of samples that are NOT ignored in this batch.
                n_sam_dice += ignore_dice.numel() - ignore_dice.sum()
            else:
                n_sam_dice += bsz

            # clear the memory
            del scores
            del masks_pred

    # The best threshold is the one with the largest accumulated MIOU; the
    # reported dice is the one obtained at that same threshold.
    idx = avg_miou.argmax()
    best_threshold = l_thress[idx]
    if cnd_dice and (n_sam_dice != 0):
        best_dice = avg_dice[idx] / float(n_sam_dice)
        best_miou = avg_miou[idx] / float(n_sam_dice)
    else:
        best_dice = avg_dice[idx] / float(n_samples)
        best_miou = avg_miou[idx] / float(n_samples)

    out = {
        "best_threshold": best_threshold,
        "best_dice": best_dice,
        "best_miou": best_miou
    }

    # The first field printed is the cycle, the second one is the epoch.
    to_write = "VL [ESTIM-THRESH] {:>2d}-{:>2d}. BEST-DICE: {:.2f}, " \
               "BEST-MIOU: {:.2f}. " \
               "time:{}".format(cycle,
                                epoch,
                                best_dice * 100.,
                                best_miou * 100.,
                                dt.datetime.now() - t0
    )
    print(to_write)
    if log_file:
        log(log_file, to_write)

    to_write =  "best threshold: {:.3f}. " \
                 "BEST-DICE-VL: {:.2f}%, " \
                 "BEST-MIOU-VL: {:.2f}%.".format(
        best_threshold, best_dice * 100., best_miou * 100.
    )
    announce_msg(to_write)
    if log_file:
        log(log_file, to_write)

    set_default_seed()

    return out
def validate(model,
             dataset,
             dataloader,
             criterion,
             device,
             stats,
             args,
             epoch=0,
             cycle=0,
             log_file=None
             ):
    """
    Run one evaluation pass of `model` over `dataloader` (task: SEG).

    The dataset behind the loader must be in `validation` mode. Accuracy,
    Dice index and every loss term are summed batch after batch, then
    averaged over the number of processed samples. When samples labeled
    'normal' (0) are ignored for Dice (cam16 dataset combined with the wsl
    method), Dice is averaged over the remaining samples only.

    :param model: called as `model(x=data, seed=None)`; must return the
           triplet (scores, masks_pred, maps).
    :param dataset: only used to resize the predicted masks back to the size
           of the target masks when the shapes differ.
    :param dataloader: yields (ids, data, mask, label, tag, crop_cord).
    :param criterion: training loss; also forwarded to the metrics.
    :param device: device on which the batch is processed.
    :param stats: dict of lists (one list per tracked key) or None. When
           None, a new dict is created.
    :param args: options. Reads `seg_threshold`, `dataset` and `al_type`.
    :param epoch: int, used only in the log message.
    :param cycle: int, used only in the log message.
    :param log_file: file to log into, or None to only print.
    :return: `stats` extended with the averages of this pass.
    """
    set_default_seed()

    model.eval()
    criterion.eval()
    metrics = Metrics(threshold=args.seg_threshold).to(device)
    metrics.eval()

    nbr_batches = len(dataloader)

    # tracked quantities: performance measures first, then the loss terms
    # (in the order the criterion returns them).
    losses_kz = ["total_loss", "cl_loss", "seg_l_loss", "seg_lp_loss"]
    keys = ["acc", "dice_idx"] + losses_kz
    tracker = {name: 0. for name in keys}

    n_samples = 0.
    n_sam_dice = 0.

    # ignoring samples is on only: 1. cam16 dataset. 2. wsl method.
    # for the rest of the methods, normal samples are completely dropped before
    # starting training.
    on_cam16 = (args.dataset == constants.CAM16)
    is_wsl = (args.al_type == constants.AL_WSL)
    cnd_dice = on_cam16 & is_wsl

    t0 = dt.datetime.now()

    with torch.no_grad():
        batches = tqdm.tqdm(enumerate(dataloader), ncols=80,
                            total=nbr_batches)
        for i, (ids, data, mask, label, tag, crop_cord) in batches:

            reset_seed(int(os.environ["MYSEED"]))

            bsz = data.shape[0]

            data = data.to(device)
            labels = label.to(device)
            masks_trg = mask.to(device) if mask is not None else None

            scores, masks_pred, maps = model(x=data, seed=None)

            check_nans(maps, "fresh-maps")

            if is_wsl:
                # WSL: the predicted mask is read from the maps (pixels
                # whose strongest map is the one of the sample's label).
                lb_ = labels.view(-1, 1, 1, 1)
                masks_pred = (maps.argmax(dim=1, keepdim=True) == lb_).float()

            # bring the prediction back to the size of the target mask.
            if masks_pred.shape != masks_trg.shape:
                _, _, oh, ow = masks_trg.shape
                masks_pred = dataset.turnback_mask_tensor_into_original_size(
                    pred_masks=masks_pred, oh=oh, ow=ow
                )
                check_nans(masks_pred, "mask-pred-back-to-normal-size")

            if not is_wsl:
                # TODO: change normalization location. (future)
                masks_pred = torch.sigmoid(masks_pred)  # binary segmentation

            flat_pred = masks_pred.view(bsz, -1)
            flat_trg = masks_trg.view(bsz, -1)

            losses = criterion(scores=scores,
                               labels=labels,
                               targets=None,  # not needed for SEG task.
                               masks_pred=flat_pred,
                               masks_trg=flat_trg,
                               tags=tag,
                               weights=None,
                               avg=False
                               )

            # ignore samples with label 'normal' (0) when computing dice.
            ignore_dice = None
            if cnd_dice:
                ignore_dice = (labels == 0).float().view(-1)

            metrx = metrics(scores=scores,
                            labels=labels,
                            tr_loss=criterion,
                            masks_pred=flat_pred,
                            masks_trg=flat_trg,
                            avg=False,
                            ignore_dice=ignore_dice
                            )

            tracker["acc"] += metrx[0].item()
            tracker["dice_idx"] += metrx[1].item()

            for pos, loss_term in enumerate(losses):
                tracker[losses_kz[pos]] += loss_term.item()

            n_samples += bsz  # nbr samples.

            # nbr samples that actually count for dice.
            if cnd_dice:
                n_sam_dice += ignore_dice.numel() - ignore_dice.sum()
            else:
                n_sam_dice += bsz

            # clear the memory
            del scores, masks_pred, maps

    # turn the sums into averages.
    for name in keys:
        if not ((name == 'dice_idx') and cnd_dice):
            tracker[name] /= float(n_samples)
            continue

        announce_msg("Ignore normal samples mode. "
                     "Total samples left: {}/Total: {}.".format(
                         n_sam_dice, n_samples))
        if n_sam_dice == 0:
            tracker[name] = 0.
        else:
            tracker[name] /= float(n_sam_dice)

    to_write = "VL {:>2d}-{:>2d}: ACC: {:.4f}, DICE: {:.4f}, time:{}".format(
        cycle, epoch, tracker["acc"] * 100., tracker["dice_idx"] * 100.,
        dt.datetime.now() - t0
    )
    print(to_write)
    if log_file:
        log(log_file, to_write)

    # Update stats.
    if stats is None:
        # first call: every tracked value starts its own list.
        stats = deepcopy({name: [tracker[name]] for name in keys})
    else:
        for name in keys:
            stats[name].append(tracker[name])

    set_default_seed()

    return stats
    # ################################ Training ################################
    set_default_seed()
    tr_stats, vl_stats = None, None

    best_val_metric = None
    best_epoch = 0

    # TODO: validate before start training.

    announce_msg("start training")
    set_default_seed()
    tx0 = dt.datetime.now()

    for epoch in range(args.max_epochs):
        # reseeding tr/vl samples.
        reset_seed(int(os.environ["MYSEED"]) + (epoch + 1) * CONST1)
        trainset.set_up_new_seeds()
        reset_seed(int(os.environ["MYSEED"]) + (epoch + 2) * CONST1)
        validset.set_up_new_seeds()

        # Start the training with fresh seeds.
        reset_seed(int(os.environ["MYSEED"]) + (epoch + 3) * CONST1)

        tr_stats = train_one_epoch(model,
                                   optimizer,
                                   train_loader,
                                   CRITERION,
                                   DEVICE,
                                   tr_stats,
                                   epoch,
                                   training_log,
def validate(model, dataloader, criterion, device, stats, epoch=0,
             log_file=None):
    """
    Evaluate the model over the validation set, one sample per batch.

    A batch size of 1 is required: images do not share the same size, so
    they cannot be stacked in one tensor, and validation samples may be too
    large to fit together on the GPU.

    :param model: called as `model(x=data, seed=None)`; the first returned
           element is used as the scores.
    :param dataloader: yields (data, mask, label) with a batch size of 1.
    :param criterion: called as `criterion(scores, labels)`; also forwarded
           to the metrics. Its `lossCT.t_lb` is logged when it exists.
    :param device: device on which each sample is processed.
    :param stats: numpy array of the previous passes (one row per pass), or
           None.
    :param epoch: int, added to the seed and shown in the log message.
    :param log_file: file to log into, or None to only print.
    :return: numpy array: `stats` with the row of this pass stacked below
             it, or a copy of that row alone when `stats` is None.
    """
    model.eval()
    metrics = Metrics().to(device)

    nbr_batches = len(dataloader)
    # one row with the columns: acc, mae, soi_y, soi_py, loss
    tracker = np.zeros((1, 5), dtype=np.float32)
    t0 = dt.datetime.now()

    with torch.no_grad():
        batches = tqdm.tqdm(enumerate(dataloader), ncols=80,
                            total=nbr_batches)
        for i, (data, mask, label) in batches:

            reset_seed(int(os.environ["MYSEED"]) + epoch)

            assert data.shape[0] == 1, \
                "Expected a batch size of 1. Found `{}`  .... " \
                "[NOT OK]".format(data.size()[0])

            data = data.to(device)
            labels = label.to(device)

            # In validation, we do not need reproducibility since everything
            # is expected to be deterministic. Plus, we use only one gpu
            # since the batch size is 1.
            scores, _ = model(x=data, seed=None)
            loss = criterion(scores, labels)
            batch_metrics = metrics(scores=scores,
                                    labels=labels,
                                    tr_loss=criterion,
                                    avg=False)
            batch_metrics = batch_metrics.cpu().numpy()

            # the loss sits in the last column, the metrics before it.
            tracker[0, -1] += loss.item()
            tracker[0, :-1] += batch_metrics

    # sums -> averages over the number of batches.
    tracker /= float(nbr_batches)

    if hasattr(criterion.lossCT, "t_lb"):
        t_lb = criterion.lossCT.t_lb.item()  # assume gpu.
    else:
        t_lb = 0.

    acc, mae, soi_y, soi_py, avg_loss = tracker.mean(axis=0)
    to_write = "Vl.Ep {:>2d}: ACC: {:.4f}, MAE: {:.4f}, SOI_Y: {:.4f}, " \
               "SOI_PY: {:.4f}, Loss: {:.4f}, t:{:.4f}, time:{}".format(
                   epoch, acc, mae, soi_y, soi_py, avg_loss, t_lb,
                   dt.datetime.now() - t0)
    print(to_write)
    if log_file:
        log(log_file, to_write)

    # Update stats
    if stats is None:
        return copy.deepcopy(tracker)
    return np.vstack([stats, tracker])
def _pseudo_annotate(model,
                     dataset,
                     dataloader,
                     criterion,
                     device,
                     pairs,
                     metrics_fd,
                     fd_p_msks,
                     args,
                     SHARED_OPT_MASKS,
                     threshold=0.5
                     ):
    """
    Evaluate a set sample by sample and store the resulting pseudo-masks.

    For every sample, the predicted mask is binarized and written as a
    '.bmp' image in `fd_p_msks`, and its Dice index is pickled in
    `metrics_fd`. Both files are named '<id>-<pairs[id]>'. Once the whole
    set is processed, `clean_shared_masks` is called to remove duplicates.

    Dataset passed here must be in `validation` mode. Batch size must be 1.
    Task: SEG.

    :param model: called as `model(x=data, seed=None)`; must return the
           triplet (scores, masks_pred, maps).
    :param dataset: only used to resize the predicted masks back to the size
           of the target masks when the shapes differ.
    :param dataloader: yields (ids, data, mask, label, tag, crop_cord).
    :param criterion: forwarded to the metrics.
    :param device: device on which each sample is processed.
    :param pairs: mapping indexed by a sample id; the value is used as the
           second part of the name of the stored files.
    :param metrics_fd: str, folder where the '.pkl' metric files are written.
    :param fd_p_msks: str, folder where the '.bmp' masks are written.
    :param args: forwarded to `clean_shared_masks`.
    :param SHARED_OPT_MASKS: forwarded to `clean_shared_masks`.
    :param threshold: float, threshold used to binarize the predicted mask.
    :return: (sum_dice, n_samples): sum of the Dice index over the samples,
             and the number of samples processed.
    """
    set_default_seed()

    model.eval()
    metrics = Metrics(threshold=threshold).to(device)
    metrics.eval()

    nbr_batches = len(dataloader)
    # to track stats.
    sum_dice = 0.
    n_samples = 0.

    with torch.no_grad():
        batches = tqdm.tqdm(enumerate(dataloader), ncols=80,
                            total=nbr_batches)
        for i, (ids, data, mask, label, tag, crop_cord) in batches:

            reset_seed(int(os.environ["MYSEED"]))

            bsz = data.shape[0]
            assert bsz == 1, "batch size must be 1. found {}. ".format(bsz)

            # names of the output files of this sample.
            id_u = ids[0]
            name_file = "{}-{}".format(id_u, pairs[id_u])
            path_to_save_mask = join(fd_p_msks, "{}.bmp".format(name_file))
            path_to_save_metric = join(metrics_fd, "{}.pkl".format(name_file))

            data = data.to(device)
            labels = label.to(device)
            masks_trg = mask.to(device) if mask is not None else None

            scores, masks_pred, maps = model(x=data, seed=None)

            # bring the prediction back to the size of the target mask.
            if masks_pred.shape != masks_trg.shape:
                _, _, oh, ow = masks_trg.shape
                masks_pred = dataset.turnback_mask_tensor_into_original_size(
                    pred_masks=masks_pred, oh=oh, ow=ow
                )
            # TODO: change normalization location. (future)
            masks_pred = torch.sigmoid(masks_pred)

            metrx = metrics(scores=scores,
                            labels=labels,
                            tr_loss=criterion,
                            masks_pred=masks_pred.view(bsz, -1),
                            masks_trg=masks_trg.view(bsz, -1),
                            avg=False
                            )

            dice_u = metrx[1].item()
            sum_dice += dice_u
            n_samples += bsz  # nbr samples.

            # store the mask and metric. =======================================
            bin_mask = metrics.get_binary_mask(masks_pred, threshold=threshold)
            bin_mask = bin_mask.detach().cpu().squeeze().numpy()
            # issue with mode=1: the image is built in mode 'L' first, then
            # converted. see:
            # https://stackoverflow.com/questions/32159076/python-pil-bitmap-png-
            # from-array-with-mode-1
            gray_mask = Image.fromarray(bin_mask.astype(np.uint8) * 255,
                                        mode='L')
            gray_mask.convert('1').save(path_to_save_mask)

            with open(path_to_save_metric, "wb") as fout:
                pkl.dump({"dice_u": dice_u},
                         fout,
                         protocol=pkl.HIGHEST_PROTOCOL)
            # ==================================================================

            # clear the memory
            del scores, masks_pred

    # clean: remove duplicates.
    clean_shared_masks(args=args,
                       SHARED_OPT_MASKS=SHARED_OPT_MASKS,
                       metrics_fd=metrics_fd,
                       all_pairs=pairs
                       )

    set_default_seed()

    return sum_dice, n_samples
def final_validate(model, dataloader, criterion, device, dataset, outd,
                   log_file=None, name_set=""):
    """
    Final evaluation of a set, one sample per batch, with a drawing of the
    prediction of every sample.

    A batch size of 1 is required: images do not share the same size, so
    they cannot be stacked in one tensor, and samples may be too large to
    fit together on the GPU.

    The drawings are written as '.jpeg' files, compressed into one archive
    ('.zip', or '.tar.gz' when the zip command fails), and the per-sample
    statistics are pickled in 'tracker-<name_set>.pkl' inside `outd`.

    :param model: called as `model(x=data, seed=None)`; the first returned
           element is used as the scores.
    :param dataloader: yields (data, mask, label) with a batch size of 1.
    :param criterion: called as `criterion(scores, labels)`; also forwarded
           to the metrics. Its `literal` is passed to the visualisor, and
           its `lossCT.t_lb` is logged when it exists.
    :param device: device on which each sample is processed.
    :param dataset: provides the original image, label and path of the
           sample of index i, and `name_classes`.
    :param outd: str, output directory of this dataset.
    :param log_file: file to log into, or None to only print.
    :param name_set: str, name to indicate which set is being processed. e.g.:
    trainset, validset, testset.
    """
    visualisor = VisualisePP(floating=4, height_tag=60)
    outd_data = join(outd, "prediction")
    if not os.path.exists(outd_data):
        os.makedirs(outd_data)

    # Deal with overloaded quota of files on servers: use the node disc.
    node_dir = ""
    if "CC_CLUSTER" in os.environ:
        node_dir = join(os.environ["SLURM_TMPDIR"], "prediction")
        if not os.path.exists(node_dir):
            os.makedirs(node_dir)

    # folder that receives the drawings (and that is compressed later).
    work_dir = node_dir or outd_data

    model.eval()
    metrics = Metrics().to(device)

    nbr_batches = len(dataloader)
    num_classes = len(list(dataset.name_classes.keys()))
    # one row per sample with the columns: acc, mae, soi_y, soi_py, loss,
    # followed by one column per class.
    tracker = np.zeros((nbr_batches, 5 + num_classes), dtype=np.float32)

    t0 = dt.datetime.now()

    with torch.no_grad():
        batches = tqdm.tqdm(enumerate(dataloader), ncols=80,
                            total=nbr_batches)
        for i, (data, mask, label) in batches:

            reset_seed(int(os.environ["MYSEED"]))

            assert data.shape[0] == 1, \
                "Expected a batch size of 1. Found `{}`  .... " \
                "[NOT OK]".format(data.size()[0])

            data = data.to(device)
            labels = label.to(device)

            # In validation, we do not need reproducibility since everything
            # is expected to be deterministic. Plus, we use only one gpu
            # since the batch size is 1.
            scores, _ = model(x=data, seed=None)
            loss = criterion(scores, labels)
            batch_metrics = metrics(scores=scores,
                                    labels=labels,
                                    tr_loss=criterion,
                                    avg=False)
            batch_metrics = batch_metrics.cpu().numpy()

            tracker[i, 4] = loss.item()
            tracker[i, :4] = batch_metrics
            tracker[i, 5:] = softmax(scores.cpu().detach().numpy())

            # draw the prediction of this sample, and store the drawing.
            basef = basename(dataset.get_path_input_img(i))
            img_out = visualisor(
                input_img=dataset.get_original_input_img(i),
                stats=[tracker[i, :]],
                label=dataset.get_original_input_label_int(i),
                name_classes=dataset.name_classes,
                loss_name=[criterion.literal],
                name_file=basef
            )
            out_name = "{}.jpeg".format(basef.split('.')[0])
            img_out.save(join(work_dir, out_name), "JPEG")

    # overlap distributions.
    # vis_over_dis = VisualiseOverlDist()
    # vis_over_dis(tracker[:, 5:], dataset.name_classes, outd)

    # compress, then delete files to prevent overloading the disc quota of
    # number of files. 'zip' is tried first; 'tar' is the fallback.
    archive_ext = 'zip'
    cmd_compress = 'zip -rjq {}.zip {}'.format(work_dir, work_dir)
    print("Run: `{}`".format(cmd_compress))
    try:
        subprocess.run(cmd_compress, shell=True, check=True)
    except subprocess.CalledProcessError:
        cmd_compress = 'tar -zcf {}.tar.gz -C {} .'.format(work_dir, work_dir)
        print("Run: `{}`".format(cmd_compress))
        subprocess.run(cmd_compress, shell=True, check=True)
        archive_ext = 'tar.gz'

    cmd_del = 'rm -r {}'.format(outd_data)
    print("Run: `{}`".format(cmd_del))
    os.system(cmd_del)
    if node_dir:
        # the archive was built on the node disc: bring it to `outd`.
        shcopy("{}.{}".format(node_dir, archive_ext), outd)

    # average of every column over the samples.
    avg = tracker.mean(axis=0)

    if hasattr(criterion.lossCT, "t_lb"):
        t_lb = criterion.lossCT.t_lb.item()  # assume gpu.
    else:
        t_lb = 0.

    to_write = "EVAL.FINAL {}: ACC: {:.4f}, MAE: {:.4f}, SOI_Y: {:.4f}, " \
               "SOI_PY: {:.4f}, Loss: {:.4f}, t:{:.4f}, time:{}".format(
                   name_set, avg[0], avg[1], avg[2], avg[3], avg[4], t_lb,
                   dt.datetime.now() - t0)
    bar = 10 * "="
    to_write = "{} \n{} \n{}".format(bar, to_write, bar)
    print(to_write)
    if log_file:
        log(log_file, to_write)

    # store the stats in pickle.
    tracker_path = join(outd, 'tracker-{}.pkl'.format(name_set))
    with open(tracker_path, 'wb') as fout:
        pkl.dump(tracker, fout, protocol=pkl.HIGHEST_PROTOCOL)
def final_validate(model,
                   dataloader,
                   criterion,
                   device,
                   dataset,
                   outd,
                   args,
                   log_file=None,
                   name_set="",
                   pseudo_labeled=False,
                   store_results=False,
                   apply_selection_tech=False
                   ):
    """
    Perform a final evaluation of a set.
    (images do not have the same size, so we can't stack them in one tensor).
    Validation samples may be large to fit all in the GPU at once.

    This is similar to validate() but it can operate on all types of
    datasets (train, valid, test). It selects only L. Also, it does some other
    operations related to drawing and final training tasks.

    What is written in `outd`:
      - 'prediction' folder with the drawings (only when results are stored
        for the SEG task / SUBCLSEG subtask), compressed into an archive and
        then deleted.
      - 'final-tracker-*.pkl': accuracy and Dice index, in percent.
      - 'final-pred-*.pkl': summed losses, and per-sample predictions / ids.
      - 'entropy-*.pkl' / 'mc-dropout-*.pkl': per-sample selection scores,
        when the matching selection technique is active.

    Selection scores (task SEG):
      - Entropy: entropy of the (foreground, background) distribution of
        each pixel, i.e. (p, 1 - p), averaged over the image.
      - MC-Dropout: variance of each pixel across `args.mcdropout_t`
        stochastic forward passes, averaged over the image. The stochastic
        passes use their own variables, so the deterministic prediction is
        the one that is drawn and stored.

    :param model: called as `model(x=data, seed=None)`; must return the
           triplet (scores, masks_pred, maps).
    :param dataloader: yields (ids, data, mask, label, tag, crop_cord).
    :param criterion: training loss; also forwarded to the metrics.
    :param device: device on which each batch is processed.
    :param dataset: used to resize the predicted masks back to the size of
           the target masks, and to get the original image / label for the
           drawings.
    :param outd: str, output directory of this dataset.
    :param args: options. Reads `height_tag`, `seg_threshold`, `dataset`,
           `al_type`, `clustering`, `task`, `subtask`, `mcdropout_t`.
    :param log_file: file to log into, or None to only print.
    :param name_set: str, name to indicate which set is being processed. e.g.:
           trainset, validset, testset.
    :param pseudo_labeled: bool. if true, the dataloader is loading the set
           of samples that has been pseudo-labeled. this is the case only for
           our method. if false, the samples are fully labeled. this allows to
           measure the performance over the pseudo-labeled samples separately.
    :param store_results: bool (default: False). if true, results (
           predictions) are stored. Useful for test set to store predictions.
           can be disabled for train/valid sets.
    :param apply_selection_tech: bool. if true, we compute stats that will be
           used for sample selection for the next AL round. useful on the
           leftovers.
    """
    set_default_seed()

    outd_data = join(outd, "prediction")
    if not os.path.exists(outd_data):
        os.makedirs(outd_data)

    # TODO: rescale depending on the dataset.
    visualisor = VisualsePredSegmentation(height_tag=args.height_tag,
                                          show_tags=True,
                                          scale=0.5
                                          )

    model.eval()
    criterion.eval()
    metrics = Metrics(threshold=args.seg_threshold).to(device)
    metrics.eval()

    length = len(dataloader)
    # to track stats: scalars are summed, per-sample entries are listed.
    keys_perf = ["acc", "dice_idx"]
    losses_kz = ["total_loss", "cl_loss", "seg_l_loss", "seg_lp_loss"]
    per_sample_kz = ["pred_labels", "pred_masks", "ids"]
    keys = keys_perf + losses_kz + per_sample_kz
    tracker = dict()
    for k in keys:
        if k not in per_sample_kz:
            tracker[k] = 0.
        else:
            tracker[k] = []

    n_samples = 0.
    n_sam_dice = 0.
    # ignoring samples is on only: 1. cam16 dataset. 2. wsl method.
    # for the rest of the methods, normal samples are completely dropped before
    # starting training.
    cnd_dice = (args.dataset == constants.CAM16)
    cnd_dice &= (args.al_type == constants.AL_WSL)

    # Selection criteria.
    cp_entropy = Entropy()  # to compute entropy.
    entropy = []
    mcdropout_var = []

    t0 = dt.datetime.now()

    # conditioning
    entropy_cnd = (args.al_type == constants.AL_ENTROPY)
    entropy_cnd = entropy_cnd or ((args.al_type == constants.AL_LP) and (
            args.clustering == constants.CLUSTER_ENTROPY))
    mcdropout_cnd = (args.al_type == constants.AL_MCDROPOUT)

    # a batch size of 1 is enforced as soon as something is stored or
    # computed per sample.
    # cnd_asrt = (args.task == constants.SEG)
    cnd_asrt = store_results
    cnd_asrt |= apply_selection_tech

    with torch.no_grad():
        for i, (ids, data, mask, label, tag, crop_cord) in tqdm.tqdm(
                enumerate(dataloader), ncols=80, total=length):

            reset_seed(int(os.environ["MYSEED"]))

            targets = None
            bsz = data.size()[0]

            if cnd_asrt:
                msg = "batchsize must be 1 for segmentation task. " \
                      "found {}.".format(bsz)
                assert bsz == 1, msg

            data = data.to(device)
            labels = label.to(device)
            masks_trg = mask.to(device)

            scores, masks_pred, maps = model(x=data, seed=None)

            # change to the predicted mask to be the maps if WSL.
            if args.al_type == constants.AL_WSL:
                lb_ = labels.view(-1, 1, 1, 1)
                masks_pred = (maps.argmax(dim=1, keepdim=True) == lb_).float()

            # resize pred mask if sizes mismatch.
            if masks_pred.shape != masks_trg.shape:
                _, _, oh, ow = masks_trg.shape
                masks_pred = dataset.turnback_mask_tensor_into_original_size(
                    pred_masks=masks_pred,
                    oh=oh,
                    ow=ow
                )

            if args.al_type != constants.AL_WSL:
                # TODO: change normalization location. (future)
                masks_pred = torch.sigmoid(masks_pred)  # binary segmentation

            losses_ = criterion(scores=scores,
                                labels=labels,
                                targets=targets,
                                masks_pred=masks_pred.view(bsz, -1),
                                masks_trg=masks_trg.view(bsz, -1),
                                tags=tag,
                                weights=None,
                                avg=False
                                )

            # metrics
            ignore_dice = None
            if cnd_dice:
                # ignore samples with label 'normal' (0) when computing dice.
                ignore_dice = (labels == 0).float().view(-1)

            metrx = metrics(scores=scores,
                            labels=labels,
                            tr_loss=criterion,
                            masks_pred=masks_pred.view(bsz, -1),
                            masks_trg=masks_trg.view(bsz, -1),
                            avg=False,
                            ignore_dice=ignore_dice
                            )
            tracker["acc"] += metrx[0].item()
            tracker["dice_idx"] += metrx[1].item()

            for j, los in enumerate(losses_):
                tracker[losses_kz[j]] += los.item()

            # stored things
            # always store ids.
            tracker["ids"].extend(ids)

            if store_results:
                tracker["pred_labels"].extend(
                    scores.argmax(dim=1, keepdim=False).cpu().numpy().tolist()
                )
                # store binary masks
                tracker["pred_masks"].extend(
                    [(metrics.get_binary_mask(
                        masks_pred[kk]).detach().cpu().numpy() > 0.) for kk in
                     range(bsz)]
                )

            n_samples += bsz
            if cnd_dice:
                n_sam_dice += ignore_dice.numel() - ignore_dice.sum()
            else:
                n_sam_dice += bsz

            # ==================================================================
            #                    START: COMPUTE INFO. FOR SELECTION CRITERIA.
            # ==================================================================
            # 1. Entropy
            if entropy_cnd and apply_selection_tech:
                if args.task == constants.CL:
                    entropy.extend(
                        cp_entropy(
                            F.softmax(scores, dim=1)).cpu().numpy().tolist()
                    )
                elif args.task == constants.SEG:
                    assert bsz == 1, "batchsize must be 1. " \
                                     "found {}.".format(bsz)
                    # nbr_pixels, 2 (forg, backg). masks_pred holds the
                    # foreground probability, so the background is its
                    # complement (two identical columns would not form a
                    # distribution).
                    fg_probs = masks_pred.view(-1, 1)
                    pixel_probs = torch.cat((fg_probs, 1. - fg_probs), dim=1)
                    avg_pixel_entropy = cp_entropy(pixel_probs).mean()
                    entropy.append(avg_pixel_entropy.cpu().item())
                else:
                    raise ValueError("Unknown task {}.".format(args.task))

            # 2. MC-Dropout
            if mcdropout_cnd and apply_selection_tech:

                reset_seed(int(os.environ["MYSEED"]))

                if args.task == constants.CL:
                    # todo
                    raise NotImplementedError
                elif args.task == constants.SEG:
                    assert bsz == 1, "batchsize must be 1. " \
                                     "found {}.".format(bsz)
                    mc_columns = []
                    # turn on dropout
                    model.set_dropout_to_train_mode()
                    for it_mc in range(args.mcdropout_t):
                        # dedicated names: the deterministic `scores` and
                        # `masks_pred` are still needed for the drawing.
                        _, mc_masks, _ = model(x=data, seed=None)

                        # resize pred mask if sizes mismatch.
                        if mc_masks.shape != masks_trg.shape:
                            _, _, oh, ow = masks_trg.shape
                            mc_masks = \
                                dataset.turnback_mask_tensor_into_original_size(
                                    pred_masks=mc_masks, oh=oh, ow=ow
                                )

                        # TODO: change normalization location. (future)
                        mc_masks = torch.sigmoid(mc_masks)
                        # one column (nbr_pixels, 1) per stochastic pass.
                        mc_columns.append(mc_masks.view(-1, 1))

                    # stack flatten masks horizontally:
                    # (nbr_pixels, mcdropout_t).
                    stacked_masks = torch.cat(mc_columns, dim=1)
                    # compute variance per pixel (i.e. across the passes,
                    # dim=1), then average over the image. images have
                    # different sizes. if it is not the case, it is fine
                    # since we divide all sum_var with a constant.
                    variance = stacked_masks.var(
                        dim=1, unbiased=True).mean()
                    mcdropout_var.append(variance.cpu().item())

                    # turn off dropout
                    model.set_dropout_to_eval_mode()
                else:
                    raise ValueError(
                        "Unknown task {}.".format(args.task))
            # ==================================================================
            #                    END: COMPUTE INFO. FOR SELECTION CRITERIA.
            # ==================================================================

            # Visualize
            cnd = (args.task == constants.SEG)
            cnd &= (args.subtask == constants.SUBCLSEG)
            cnd &= store_results

            # todo: separate storing stats from plotting figure + storing it.
            # todo: add plot_figure var.
            if cnd:
                output_file = join(outd_data, "{}.jpeg".format(ids[0]))

                visualisor(
                    img_in=dataset.get_original_input_img(i),
                    mask_pred=metrics.get_binary_mask(
                        masks_pred).detach().cpu().squeeze().numpy(),
                    true_label=dataset.get_original_input_label_int(i),
                    label_pred=scores.argmax(dim=1, keepdim=False).cpu().numpy(
                    ).tolist()[0],
                    id_sample=ids[0],
                    name_classes=dataset.name_classes,
                    true_mask=masks_trg.cpu().squeeze().numpy(),
                    dice=metrx[1].item(),
                    output_file=output_file,
                    scale=None,
                    binarize_pred_mask=False,
                    cont_pred_msk=masks_pred.detach().cpu().squeeze().numpy()
                )

            # clear memory
            del scores
            del masks_pred
            del maps

    # avg: acc, dice_idx
    for k in keys_perf:
        if (k == 'dice_idx') and cnd_dice:
            announce_msg("Ignore normal samples mode. "
                         "Total samples left: {}/Total: {}.".format(
                n_sam_dice, n_samples))
            if n_sam_dice != 0:
                tracker[k] /= float(n_sam_dice)
            else:
                tracker[k] = 0.

        else:
            tracker[k] /= float(n_samples)

    # compress, then delete files to prevent overloading the disc quota of
    # number of files.
    def compress_del(src):
        """
        Compress a folder with name 'src' into 'src.zip' or 'src.tar.gz'.
        Then, delete the folder 'src'.
        :param src: str, absolute path to the folder to compress.
        :return:
        """
        try:
            cmd_compress = 'zip -rjq {}.zip {}'.format(src, src)
            print("Run: `{}`".format(cmd_compress))
            subprocess.run(cmd_compress, shell=True, check=True)
        except subprocess.CalledProcessError:
            # zip failed (or is missing): fall back to tar.
            cmd_compress = 'tar -zcf {}.tar.gz -C {} .'.format(src, src)
            print("Run: `{}`".format(cmd_compress))
            subprocess.run(cmd_compress, shell=True, check=True)

        cmd_del = 'rm -r {}'.format(src)
        print("Run: `{}`".format(cmd_del))
        os.system(cmd_del)

    compress_del(outd_data)

    to_write = "EVAL.FINAL {} -- pseudo-labeled {}: ACC: {:.3f}%," \
               " DICE: {:.3f}%, time:{}".format(
                name_set, pseudo_labeled, tracker["acc"] * 100.,
                tracker["dice_idx"] * 100., dt.datetime.now() - t0
               )
    to_write = "{} \n{} \n{}".format(10 * "=", to_write, 10 * "=")
    print(to_write)
    if log_file:
        log(log_file, to_write)

    # store the stats in pickle.
    final_stats = dict()  # for a quick access: perf.
    for k in keys_perf:
        final_stats[k] = tracker[k] * 100.

    pred_stats = dict()  # it is heavy.
    for k in losses_kz + per_sample_kz:
        pred_stats[k] = tracker[k]

    pred_stats["pseudo-labeled"] = pseudo_labeled

    if pseudo_labeled:
        outfile_tracker = join(outd, 'final-tracker-{}-pseudoL-{}.pkl'.format(
            name_set, pseudo_labeled))
        outfile_pred = join(outd, 'final-pred-{}-pseudoL-{}.pkl'.format(
            name_set, pseudo_labeled))
    else:
        outfile_tracker = join(outd, 'final-tracker-{}.pkl'.format(name_set))
        outfile_pred = join(outd, 'final-pred-{}.pkl'.format(name_set))

    with open(outfile_tracker, 'wb') as fout:
        pkl.dump(final_stats, fout, protocol=pkl.HIGHEST_PROTOCOL)
    with open(outfile_pred, 'wb') as fout:
        pkl.dump(pred_stats, fout, protocol=pkl.HIGHEST_PROTOCOL)

    # store info. for selection techniques.
    # 1. Entropy.
    if entropy_cnd and apply_selection_tech:
        with open(join(outd, 'entropy-{}.pkl'.format(name_set)), 'wb') as fout:
            pkl.dump({'entropy': entropy,
                      'ids': tracker["ids"]}, fout,
                     protocol=pkl.HIGHEST_PROTOCOL)

    # 2. MC-Dropout.
    if mcdropout_cnd and apply_selection_tech:
        with open(join(outd, 'mc-dropout-{}.pkl'.format(name_set)),
                  'wb') as fout:
            pkl.dump({'mc-dropout-var': mcdropout_var,
                      'ids': tracker["ids"]}, fout,
                     protocol=pkl.HIGHEST_PROTOCOL)

    set_default_seed()