def forward(self, x, seed=None, prngs_cuda=None):
    """
    Compute the WILDCAT pooling score of every class map.

    For each (sample, class) map, the activations are sorted in decreasing
    order. The score is the mean of the `kmax` largest activations; when
    `kmin > 0` and `self.alpha != 0`, `alpha` times the mean of the `kmin`
    smallest activations is added and the total is divided by 2.

    Input:
        x: torch tensor of size (n, c, h, w), where n is the batch size,
            c is the number of classes, h is the height of the feature
            map, w is its width.
        seed: seed handed to `reproducibility.force_seed` around the
            dropout call. If None, dropout is applied without touching
            any random generator state.
        prngs_cuda: CUDA RNG state installed right before dropout.
            Required (asserted) when `seed` is not None.
    Output:
        scores: torch tensor of size (n, c) holding the pooled score of
            each class for each sample.
    """
    b, c, h, w = x.shape
    n = h * w
    activations = x.view(b, c, n)

    sorted_features = torch.sort(activations, dim=-1, descending=True)[0]
    kmax = self.get_k(self.kmax, n)
    kmin = self.get_k(self.kmin, n)

    # dropout
    if self.dropout != 0.:
        if seed is not None:
            # Validate before taking the lock so a failed check cannot
            # leave the lock held.
            assert prngs_cuda is not None, "`prngs_cuda` is expected to not be None. Exiting .... [NOT OK]"
            thread_lock.acquire()
            try:
                # Save the current CUDA RNG state, then install the one
                # provided by the caller.
                prng_state = (torch.cuda.get_rng_state().cpu())
                reproducibility.force_seed(seed)
                torch.cuda.set_rng_state(prngs_cuda.cpu())

                sorted_features = self.dropout_md(
                    sorted_features)  # instruction that causes randomness.

                # Put the saved CUDA RNG state back.
                reproducibility.force_seed(seed)
                torch.cuda.set_rng_state(prng_state)
            finally:
                # Always release, even when dropout or the RNG calls raise.
                thread_lock.release()
        else:
            sorted_features = self.dropout_md(sorted_features)

    scores = sorted_features.narrow(-1, 0, kmax).sum(-1).div_(kmax)

    if kmin > 0 and self.alpha != 0.:
        # In-place `add_` is required: the out-of-place `add` returns a new
        # tensor, so the k-min term and the division by 2 would be lost.
        scores.add_(
            sorted_features.narrow(-1, n - kmin, kmin).sum(-1).mul_(
                self.alpha / kmin)).div_(2.)

    return scores
def test__LossExtendedLB():
    """
    Smoke test of `_LossExtendedLB`: run one forward pass on random input,
    call `update_t` ten times while printing `t_lb`, then print the loss.
    """
    force_seed(0, check_cudnn=False)
    loss_fn = _LossExtendedLB(init_t=1., max_t=10., mulcoef=1.01)
    announce_msg("Testing {}".format(loss_fn))

    # Use GPU 1 when CUDA is available, otherwise fall back to the CPU.
    gpu_index = 1
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        target_device = torch.device("cuda:{}".format(gpu_index))
    else:
        target_device = torch.device("cpu")
    if has_cuda:
        torch.cuda.set_device(int(gpu_index))
    loss_fn.to(target_device)

    batch_size = 16
    inputs = torch.rand(batch_size)
    inputs = inputs.to(target_device)

    loss_value = loss_fn(inputs)
    for epoch in range(10):
        loss_fn.update_t()
        print("epoch {}. t: {}.".format(epoch, loss_fn.t_lb))
    print("Loss ELB.sum(): {}".format(loss_value))
def get_eval_dataset(args, myseed, valid_samples, transform_tensor):
    """
    Build the evaluation dataset and a dataloader over it.

    The seed is forced before building the dataset, before building the
    loader and once more before returning.

    :param args: object providing `pad_eval`, `padding_size`,
        `padding_mode`, `dataset`, `name_classes`, `resize`,
        `up_scale_small_dim_to` and `num_workers`.
    :param myseed: seed passed to `reproducibility.force_seed`.
    :param valid_samples: samples handed to `PhotoDataset`.
    :param transform_tensor: tensor transform handed to `PhotoDataset`.
    :return: tuple (validset, valid_loader).
    """
    reproducibility.force_seed(myseed)

    # Padding options are forwarded only when padding at evaluation is on.
    if args.pad_eval:
        eval_pad_size = args.padding_size
        eval_pad_mode = args.padding_mode
    else:
        eval_pad_size = None
        eval_pad_mode = None

    dataset_options = dict(
        set_for_eval=False,
        transform_img=None,
        resize=args.resize,
        crop_size=None,
        padding_size=eval_pad_size,
        padding_mode=eval_pad_mode,
        force_div_32=False,
        up_scale_small_dim_to=args.up_scale_small_dim_to,
    )
    validset = PhotoDataset(valid_samples,
                            args.dataset,
                            args.name_classes,
                            transform_tensor,
                            **dataset_options)

    reproducibility.force_seed(myseed)
    # More workers are requested because the batch size is 1 and
    # set_for_eval is False (preparing a sample takes more time).
    nbr_workers = args.num_workers * FACTOR_MUL_WORKERS
    valid_loader = DataLoader(validset,
                              batch_size=1,
                              shuffle=False,
                              num_workers=nbr_workers,
                              pin_memory=True,
                              collate_fn=default_collate,
                              worker_init_fn=_init_fn)

    reproducibility.force_seed(myseed)
    return validset, valid_loader
Example #4
0
            os.path.isfile(valid_csv),
            os.path.isfile(test_csv)
    ]):
        raise ValueError("Missing *.cvs files ({}[{}], {}[{}], {}[{}])".format(
            train_csv, os.path.isfile(train_csv), valid_csv,
            os.path.isfile(valid_csv), test_csv, os.path.isfile(test_csv)))

    rootpath = get_rootpath_2_dataset(args)

    train_samples = csv_loader(train_csv, rootpath)
    valid_samples = csv_loader(valid_csv, rootpath)
    test_samples = csv_loader(test_csv, rootpath)

    # Just for debug to go fast.
    if DEBUG_MODE and (args.dataset == "Caltech-UCSD-Birds-200-2011"):
        reproducibility.force_seed(int(os.environ["MYSEED"]))
        warnings.warn("YOU ARE IN DEBUG MODE!!!!")
        train_samples = random.sample(train_samples, 100)
        valid_samples = random.sample(valid_samples, 5)
        test_samples = test_samples[:20]
        reproducibility.force_seed(int(os.environ["MYSEED"]))

    if DEBUG_MODE and (args.dataset == "Oxford-flowers-102"):
        reproducibility.force_seed(int(os.environ["MYSEED"]))
        warnings.warn("YOU ARE IN DEBUG MODE!!!!")
        # train_samples = random.sample(train_samples, 100)
        # valid_samples = random.sample(valid_samples, 5)
        # test_samples = test_samples[:20]
        reproducibility.force_seed(int(os.environ["MYSEED"]))

    if DEBUG_MODE and (args.dataset == "glas"):
def train_one_epoch(model,
                    optimizer,
                    dataloader,
                    criterion,
                    device,
                    tr_stats,
                    args,
                    epoch=0,
                    log_file=None,
                    ALLOW_MULTIGPUS=False,
                    NBRGPUS=1
                    ):
    """
    Perform one epoch of training.

    For every batch: move the batch to `device`, run the forward pass,
    compute the loss with `criterion`, back-propagate, step the optimizer,
    compute the metrics and append the per-batch values to `tr_stats`.
    A one-line summary is printed at the end (and logged if `log_file`).
    The base seed is read from the environment variable `MYSEED`.

    :param model: the model. Called as `model(x=, seed=, prngs_cuda=)` and
        expected to return (scores_pos, scores_neg, mask_pred, sc_cl_se).
    :param optimizer: optimizer stepped once per batch.
    :param dataloader: yields (data, masks, labels); `masks` is stacked
        with `torch.stack`.
    :param criterion: called as `criterion(scores_pos, sc_cl_se, labels,
        mask_pred, scores_neg)`; must return 4 values, the first one being
        the loss that is back-propagated.
    :param device: device where the batch and the metrics are moved.
    :param tr_stats: dict of lists with keys "total_loss", "loss_pos",
        "loss_neg", "acc", "f1pos", "f1neg", "miou". Updated in place.
    :param args: object providing `final_thres` (threshold of `Metrics`).
    :param epoch: int. Added to the seed and shown in the summary.
    :param log_file: if truthy, the summary is also passed to `log`.
    :param ALLOW_MULTIGPUS: bool. If True, we are in multiGPU mode.
    :param NBRGPUS: int. Number of GPUs; one seed and one CUDA PRNG state
        are generated per GPU in multiGPU mode.
    :return: tr_stats (the same object, updated).
    """
    model.train()

    metrics = Metrics(threshold=args.final_thres).to(device)
    metrics.eval()

    # Running sums of the metrics over the batches.
    f1pos_tr, f1neg_tr, miou_tr, acc_tr = 0., 0., 0., 0.
    cnt = 0.  # number of samples seen so far.

    length = len(dataloader)
    t0 = dt.datetime.now()
    myseed = int(os.environ["MYSEED"])

    for i, (data, masks, labels) in tqdm.tqdm(
            enumerate(dataloader), ncols=80, total=length):
        reproducibility.force_seed(myseed + epoch)

        data = data.to(device)
        labels = labels.to(device)
        masks = torch.stack(masks)
        masks = masks.to(device)

        model.zero_grad()

        # Placeholders; l_p and l_n are overwritten by the criterion below.
        t_l, l_p, l_n, l_c_s = 0., 0., 0., 0.
        prngs_cuda = None  # TODO: crack in optimal code.
        bsz = data.shape[0]

        # Optimization:
        # if model.nbr_times_erase == 0:  # no erasing.
        if not ALLOW_MULTIGPUS:
            # TODO: crack in optimal code.
            if "CC_CLUSTER" in os.environ.keys():
                msg = "Something wrong. You deactivated multigpu mode, " \
                      "but we find {} GPUs. This will not guarantee " \
                      "reproducibility. We do not know why you did that. " \
                      "Exiting .... [NOT OK]".format(NBRGPUS)
                assert NBRGPUS <= 1, msg
            seeds_threads = None
        else:
            msg = "Something is wrong. You asked for multigpu mode. " \
                  "But, we found {} GPUs. Exiting " \
                  ".... [NOT OK]".format(NBRGPUS)
            assert NBRGPUS > 1, msg
            # The seeds are generated randomly before calling the threads.
            reproducibility.force_seed(myseed + epoch + i)  # armor.
            seeds_threads = torch.randint(
                0, np.iinfo(np.uint32).max + 1, (NBRGPUS, )).to(device)
            reproducibility.force_seed(myseed + epoch + i)  # armor.
            prngs_cuda = []
            # Create different prng states of cuda before forking.
            for seed in seeds_threads:
                # get the corresponding state of the cuda prng with respect
                # to the seed.
                torch.manual_seed(seed)
                torch.cuda.manual_seed(seed)
                prngs_cuda.append(torch.cuda.get_rng_state())
            reproducibility.force_seed(myseed + epoch + i)  # armor.

        # TODO: crack in optimal code.
        # Stack the per-GPU PRNG states into one tensor (multiGPU mode only).
        if prngs_cuda is not None and prngs_cuda != []:
            prngs_cuda = torch.stack(prngs_cuda)

        # The seed is re-forced before and after the forward pass.
        reproducibility.force_seed(myseed + epoch + i)  # armor.
        scores_pos, scores_neg, mask_pred, sc_cl_se = model(
            x=data,
            seed=seeds_threads,
            prngs_cuda=prngs_cuda
        )
        reproducibility.force_seed(myseed + epoch + i)  # armor.

        msg = "shape mismatches: pred {}  true {}".format(
            masks.shape, mask_pred.shape)
        assert masks.shape == mask_pred.shape, msg

        t_loss, l_p, l_n, l_seg = criterion(scores_pos,
                                            sc_cl_se,
                                            labels,
                                            mask_pred,
                                            scores_neg
                                            )
        t_loss.backward()

        # Update params.
        optimizer.step()
        # End optimization.
        # Masks are flattened to (bsz, -1) before computing the metrics.
        acc, dice_forg, dice_back, miou = metrics(
            scores=scores_pos,
            labels=labels,
            masks_pred=mask_pred.contiguous().view(bsz, -1),
            masks_trg=masks.contiguous().view(bsz, -1),
            avg=True
            )

        # tracking
        tr_stats["total_loss"].append(t_loss.item())
        tr_stats["loss_pos"].append(l_p.item())
        tr_stats["loss_neg"].append(l_n.item())
        tr_stats["acc"].append(acc * 100.)
        tr_stats["f1pos"].append(dice_forg * 100.)
        tr_stats["f1neg"].append(dice_back * 100.)
        tr_stats['miou'].append(miou * 100.)

        f1pos_tr += dice_forg
        f1neg_tr += dice_back
        miou_tr += miou
        acc_tr += acc
        cnt += bsz

    # avg
    # NOTE(review): the metrics are requested with `avg=True` but the sums
    # are divided by the number of samples (`cnt`), not by the number of
    # batches. If `avg=True` returns a per-batch mean, the values printed
    # below are under-estimated -- confirm against `Metrics`.
    f1neg_tr = f1neg_tr * 100. / float(cnt)
    f1pos_tr = f1pos_tr * 100. / float(cnt)
    acc_tr = acc_tr * 100. / float(cnt)
    miou_tr = miou_tr * 100. / float(cnt)

    to_write = "Train epoch {:>2d}: f1+: {:.2f}, f1-: {:.2f}, " \
               "miou: {:.2f}, acc: {:.2f}, LR {}, t:{}".format(
                epoch, f1pos_tr, f1neg_tr, miou_tr, acc_tr,
                ['{:.2e}'.format(group["lr"]) for group in optimizer.param_groups],
                dt.datetime.now() - t0
                )
    print(to_write)
    if log_file:
        log(log_file, to_write)

    return tr_stats
def validate(model,
             dataset,
             dataloader,
             criterion,
             device,
             stats,
             args,
             folderout=None,
             epoch=0,
             log_file=None,
             name_set="",
             store_on_disc=False,
             store_imgs=False
             ):
    """
    Perform a validation over the validation set. Assumes a batch size of 1.
    (images do not have the same size,
    so we can't stack them in one tensor).
    Validation samples may be large to fit all in the GPU at once.

    Note: criterion is deppmil.criteria.TotalLossEval().

    :param model: the model. Called as `model(x=, seed=None)` and expected
        to return (scores_pos, scores_neg, mask_pred, sc_cl_se).
    :param dataset: dataset forwarded to `store_pred_img`.
    :param dataloader: yields (data, mask, label) with a batch size of 1.
    :param criterion: called as `criterion(scores_pos, sc_cl_se, labels,
        mask_pred, scores_neg)`; must return 4 values.
    :param device: device where the sample and the metrics are moved.
    :param stats: None, or dict of lists with keys "total_loss",
        "loss_pos", "loss_neg", "acc", "f1pos", "f1neg", "miou". The
        averages of this run are appended in place.
    :param args: object providing `final_thres`; forwarded to
        `store_pred_img`.
    :param folderout: None or path of the output folder. When set, the
        sub-folders "masks" and "masks_bin" are created in it.
    :param epoch: int. Added to the seed and shown in the summary.
    :param log_file: if truthy, the summaries are also passed to `log`.
    :param name_set: str. Name of the set, used in messages and file names.
    :param store_on_disc: bool. If True (and `folderout` is set), store
        per-sample binary masks and the aggregated results as pickle
        files, then archive and delete the "masks_bin" folder through a
        shell command.
    :param store_imgs: bool. If True (with `store_on_disc` and
        `folderout`), also call `store_pred_img` for every sample.
    :return: stats. Always returned (previously None was returned when the
        results were stored on disc).
    :raises ZeroDivisionError: if the dataloader is empty.
    """
    model.eval()
    metrics = Metrics(threshold=args.final_thres).to(device)
    metrics.eval()

    f1pos_, f1neg_, miou_, acc_ = 0., 0., 0., 0.
    cnt = 0.
    total_loss_ = 0.
    loss_pos_ = 0.
    loss_neg_ = 0.

    mask_fd = None
    bin_masks_fd = None
    name_fd_masks = "masks"  # where to store the predictions.
    name_fd_masks_bin = "masks_bin"  # where to store the bin masks and al.
    if folderout is not None:
        mask_fd = join(folderout, name_fd_masks)
        bin_masks_fd = join(folderout, name_fd_masks_bin)
        if not os.path.exists(mask_fd):
            os.makedirs(mask_fd)

        if not os.path.exists(bin_masks_fd):
            os.makedirs(bin_masks_fd)

    length = len(dataloader)
    t0 = dt.datetime.now()
    myseed = int(os.environ["MYSEED"])

    with torch.no_grad():
        for i, (data, mask, label) in tqdm.tqdm(
                enumerate(dataloader), ncols=80, total=length):

            reproducibility.force_seed(myseed + epoch + 1)

            msg = "Expected a batch size of 1. Found `{}`  .... " \
                  "[NOT OK]".format(data.size()[0])
            assert data.size()[0] == 1, msg
            bsz = data.size()[0]

            data = data.to(device)
            labels = label.to(device)
            mask = torch.tensor(mask[0])
            mask_t = mask.unsqueeze(0).to(device)
            assert mask_t.ndim == 4, "ndim = {} must be 4.".format(mask_t.ndim)

            # In validation, we do not need reproducibility since everything
            # is expected to be deterministic. Plus,
            # we use only one gpu since the batch size is 1.
            scores_pos, scores_neg, mask_pred, sc_cl_se = model(x=data,
                                                                seed=None
                                                                )
            t_loss, l_p, l_n, l_seg = criterion(scores_pos,
                                                sc_cl_se,
                                                labels,
                                                mask_pred,
                                                scores_neg
                                                )

            mask_pred = mask_pred.squeeze()
            # check sizes of the mask:
            _, _, h, w = mask_t.shape
            hp, wp = mask_pred.shape

            if (h != hp) or (w != wp):  # This means that we have padded the
                # input image. We crop the predicted mask in the center.
                mask_pred = mask_pred[int(hp / 2) - int(h / 2): int(hp / 2) + int(h / 2) + (h % 2),
                                      int(wp / 2) - int(w / 2): int(wp / 2) + int(w / 2) + (w % 2)]

            mask_pred = mask_pred.unsqueeze(0).unsqueeze(0)

            acc, dice_forg, dice_back, miou = metrics(
                scores=scores_pos,
                labels=labels,
                masks_pred=mask_pred.contiguous().view(bsz, -1),
                masks_trg=mask_t.contiguous().view(bsz, -1),
                avg=False
            )

            # tracking
            f1pos_ += dice_forg
            f1neg_ += dice_back
            miou_ += miou
            acc_ += acc
            cnt += bsz
            total_loss_ += t_loss.item()
            loss_pos_ += l_p.item()
            loss_neg_ += l_n.item()

            if (folderout is not None) and store_on_disc:
                # binary mask
                bin_pred_mask = metrics.get_binary_mask(mask_pred).squeeze()
                # The builtin `bool` replaces the `np.bool` alias, which
                # was removed from NumPy (AttributeError on NumPy >= 1.24).
                bin_pred_mask = bin_pred_mask.cpu().detach().numpy().astype(bool)
                to_save = {
                    "bin_pred_mask": bin_pred_mask,
                    "dice_forg": dice_forg,
                    "dice_back": dice_back,
                    "i": i
                }

                with open(join(bin_masks_fd, "{}.pkl".format(i)), "wb") as fbin:
                    pkl.dump(to_save, fbin, protocol=pkl.HIGHEST_PROTOCOL)

            # `bin_pred_mask` exists here since this branch also requires
            # `store_on_disc`.
            if (folderout is not None) and store_imgs and store_on_disc:
                pred_label = int(scores_pos.argmax().item())
                probs = softmax(scores_pos.cpu().detach().numpy())
                prob = float(probs[0, pred_label])

                store_pred_img(i,
                               dataset,
                               bin_pred_mask * 1.,
                               mask_pred.squeeze().cpu().detach().numpy(),
                               dice_forg,
                               dice_back,
                               prob,
                               pred_label,
                               args,
                               mask_fd,
                               )

    # avg
    total_loss_ /= float(cnt)
    loss_pos_ /= float(cnt)
    loss_neg_ /= float(cnt)
    acc_ *= (100. / float(cnt))
    f1pos_ *= (100. / float(cnt))
    f1neg_ *= (100. / float(cnt))
    miou_ *= (100. / float(cnt))

    if stats is not None:
        stats["total_loss"].append(total_loss_)
        stats["loss_pos"].append(loss_pos_)
        stats["loss_neg"].append(loss_neg_)
        stats["acc"].append(acc_)
        stats["f1pos"].append(f1pos_)
        stats["f1neg"].append(f1neg_)
        stats['miou'].append(miou_)

    to_write = "EVAL ({}): TLoss: {:.2f}, L+: {:.2f}, L-: {:.2f}, " \
               "F1+: {:.2f}%, F1-: {:.2f}%, MIOU: {:.2f}%, ACC: {:.2f}%, " \
               "t:{}, epoch {:>2d}.".format(
                name_set,
                total_loss_,
                loss_pos_,
                loss_neg_,
                f1pos_,
                f1neg_,
                miou_,
                acc_,
                dt.datetime.now() - t0,
                epoch
                )
    print(to_write)
    if log_file:
        log(log_file, to_write)

    if folderout is not None:
        msg = "EVAL {}: \n".format(name_set)
        msg += "ACC {}% \n".format(acc_)
        msg += "F1+ {}% \n".format(f1pos_)
        msg += "F1- {}% \n".format(f1neg_)
        msg += "MIOU {}% \n".format(miou_)
        announce_msg(msg)
        if log_file:
            log(log_file, msg)

    if (folderout is not None) and store_on_disc:
        pred = {
            "total_loss": total_loss_,
            "loss_pos": loss_pos_,
            "loss_neg": loss_neg_,
            "acc": acc_,
            "f1pos": f1pos_,
            "f1neg": f1neg_,
            "miou_": miou_
        }
        with open(
                join(folderout, "pred--{}.pkl".format(name_set)), "wb") as fout:
            pkl.dump(pred, fout, protocol=pkl.HIGHEST_PROTOCOL)

        # compress. delete folder.
        # The "masks" folder is only visited; archiving it is disabled.
        cmdx = [
            "cd {} ".format(mask_fd),
            "cd .. ",
            # "tar -cf {}.tar.gz {}".format(name_fd_masks, name_fd_masks),
            # "rm -r {}".format(name_fd_masks)
        ]

        cmdx += [
            "cd {} ".format(bin_masks_fd),
            "cd .. ",
            "tar -cf {}.tar.gz {}".format(name_fd_masks_bin, name_fd_masks_bin),
            "rm -r {}".format(name_fd_masks_bin)
        ]

        # NOTE(review): the paths are interpolated unquoted into a shell
        # command line; a `folderout` containing spaces or shell
        # metacharacters will break it.
        cmdx = " && ".join(cmdx)
        print("Running bash-cmds: \n{}".format(cmdx.replace("&& ", "\n")))
        subprocess.run(cmdx, shell=True, check=True)

    return stats
    def __getitem__(self, index):
        """
        Fetch the sample stored at position `index`.

        When `self.set_for_eval` is set, the pre-computed entries
        (`inputs_ready`, `masks_ready`, `labels_ready`) are returned as is.
        Otherwise the sample is loaded (or read from the preloaded lists),
        optionally upscaled, padded and randomly cropped, padded to a
        multiple of 32, transformed, and its mask is converted through
        `self.to_tensor`.

        :param index: int, the index of the sample within the whole dataset.
        :return: tuple (img, mask, target):
                 img: the image after all the enabled transformations.
                 mask: the mask of the regions of interest. Outside of the
                 evaluation mode, it is the output of `self.to_tensor`
                 applied to a float32 array of shape (h, w, 1) scaled by
                 1/255.
                 target: the label of the sample.
        """
        # Each sample has its own seed: a workaround that keeps the results
        # reproducible whatever the number of workers is.
        reproducibility.force_seed(self.seeds[index])

        # Evaluation mode: everything has been prepared beforehand.
        if self.set_for_eval:
            error_msg = "Something wrong. You didn't ask to set the data ready for evaluation, but here we are " \
                        ".... [NOT OK]"
            assert self.inputs_ready is not None and self.labels_ready is not None, error_msg
            return (self.inputs_ready[index],
                    self.masks_ready[index],
                    self.labels_ready[index])

        # Get the raw sample: from the preloaded lists or from disc.
        if not self.do_not_save_samples:
            assert self.preloaded, "Sorry, you need to preload the data first .... [NOT OK]"
            img = self.images[index]
            mask = self.masks[index]
            target = self.labels[index]
        else:
            img, mask, target = self.load_sample_i(index)

        # Upscale on the fly (only when requested) instead of keeping
        # upscaled images in memory. Useful for Caltech-UCSD-Birds-200-2011.
        if self.up_scale_small_dim_to is not None:
            width, height = img.size
            up_width, up_height = self.get_upscaled_dims(
                width, height, self.up_scale_small_dim_to)
            img = img.resize((up_width, up_height),
                             resample=PIL.Image.BILINEAR)

        # Random crop: training only. Do not crop for evaluation.
        if self.randomCropper:
            if self.padding_size:
                width, height = img.size
                ratio_h, ratio_w = self.padding_size
                pad_amount = (int(ratio_w * width), int(ratio_h * height))
                img = TF.pad(img,
                             padding=pad_amount,
                             padding_mode=self.padding_mode)
                # The mask follows the image: just for tracking.
                mask = TF.pad(mask,
                              padding=pad_amount,
                              padding_mode=self.padding_mode)

            img, crop_box = self.randomCropper(img)
            top, left, crop_h, crop_w = crop_box
            # Crop the mask the same way. Not used for actual training.
            mask = TF.crop(mask, top, left, crop_h, crop_w)

        # Pad the image so that both sides are divisible by 32. The mask is
        # left untouched.
        if self.force_div_32:
            width, height = img.size
            left_pad, right_pad = self.get_padding(width, 32)
            top_pad, bottom_pad = self.get_padding(height, 32)
            img = TF.pad(img,
                         padding=(left_pad, top_pad, right_pad, bottom_pad),
                         padding_mode="reflect")

        # Image-only transformations: the mask is not transformed.
        if self.transform_img:
            img = self.transform_img(img)

        if self.transform_tensor:
            img = self.transform_tensor(img)

        # Prepare the mask to be used on GPU to compute Dice index:
        # scale by 1/255 and add a trailing axis -> shape (h, w, 1).
        scaled_mask = np.array(mask, dtype=np.float32) / 255.
        mask = self.to_tensor(np.expand_dims(scaled_mask, axis=-1))

        return img, mask, target
        """
        Init. function.
        :param classes: int, number of classes.
        :param modalities: int, number of modalities per class.
        """
        super(ClassWisePooling, self).__init__()

        self.C = classes
        self.M = modalities

    def forward(self, input):
        """
        Average the `M` modality maps of each of the `C` classes.

        :param input: tensor of size (N, C * M, H, W).
        :return: tensor of size (N, C, H, W).
        """
        batch, channels, height, width = input.size()
        expected = self.C * self.M
        assert channels == expected, 'Wrong number of channels, expected {} channels but got {}'.format(
            expected, channels)
        # Group the channels per class, then average over the modalities.
        grouped = input.view(batch, self.C, self.M, -1)
        pooled = torch.mean(grouped, dim=2)
        return pooled.view(batch, self.C, height, width)

    def __repr__(self):
        """Return the class name followed by the `classes` and `modalities` values."""
        class_name = self.__class__.__name__
        settings = '(classes={}, modalities={})'.format(self.C, self.M)
        return class_name + settings


if __name__ == "__main__":
    # Smoke test: run every pooling module once on a random input and
    # print the size and the values of its output.
    batch_size, nbr_classes = 10, 2
    reproducibility.force_seed(0)
    modules = [WildCatPoolDecision(dropout=0.5)]
    sample = torch.randn(batch_size, nbr_classes, 12, 12)
    for module in modules:
        result = module(sample)
        print(module.__class__.__name__, '->', result.size(), result)
Example #9
0
def validate(model,
             dataset,
             dataloader,
             criterion,
             device,
             stats,
             epoch=0,
             callback=None,
             log_file=None,
             name_set=""):
    """
    Perform a validation over the validation set. Assumes a batch size of 1. (images do not have the same size,
    so we can't stack them in one tensor).
    Validation samples may be too large to fit all in the GPU at once.

    For every sample, the predicted label, the softmax probabilities, the predicted mask (brought back to the size
    of the true mask), the losses and the Dice index over the foreground/background are collected.

    Note: criterion is deppmil.criteria.TotalLossEval().

    :param model: the model. Called as `model(x=, seed=None)`; its output is indexed as (out_pos, out_neg, masks).
        Must expose `num_classes`.
    :param dataset: dataset object. `dataset_name`, `padding_size`, `original_images_size`, `force_div_32`,
        `get_upscaled_dims` and `up_scale_small_dim_to` are read for the datasets "Caltech-UCSD-Birds-200-2011" and
        "Oxford-flowers-102".
    :param dataloader: yields (data, mask, label) with a batch size of 1. Must not be shuffled for the two datasets
        above since `dataset.original_images_size` is accessed using the loop index.
    :param criterion: called as `criterion(output, label, mask_t)`; must return 4 values.
    :param device: device where the data, the label and the masks are moved.
    :param stats: dict with the keys "total_loss", "loss_pos", "loss_neg", "loss_class_seg", "errors", "f1pos",
        "f1neg". Its entries are replaced by `np.append` of the old entry and the values of this run.
    :param epoch: int. Added to the seed (read from the environment variable `MYSEED`) and shown in the summary.
    :param callback: if truthy, its `scalar` method receives the average validation loss and error.
    :param log_file: if truthy, the summary is also passed to `log`.
    :param name_set: str. Name of the set, used in the summary.
    :return: tuple (stats, stats_now, pred):
        stats: the updated `stats`.
        stats_now: dict with the statistics of this run only.
        pred: dict with the keys "predictions", "labels", "probs", "masks", "probs_pos", "probs_neg".
    """
    # TODO: [FUTURE] find a way to do the final processing (such as drawing) within the loop below. This will
    #  avoid STORING the results of validating over huge datasets, THEN processing the samples one by one, which is
    #  not efficient. This should be an option to be activated/deactivated when needed. For instance, during
    #  validation, it is not necessary, while at the end, it is necessary.
    model.eval()

    total_loss, loss_pos, loss_neg = AverageMeter(), AverageMeter(
    ), AverageMeter()
    loss_class_seg, f1pos, f1neg = AverageMeter(), AverageMeter(
    ), AverageMeter()
    errors = AverageMeter()

    # One entry per sample (the batch size is 1).
    length = len(dataloader)
    predictions = np.zeros(length, dtype=int)
    labels = np.zeros(length)
    probs = np.zeros(
        length)  # prob. of the predicted class (using positive region).
    probs_pos = np.zeros(
        (length, model.num_classes))  # prob. over the positive region.
    probs_neg = np.zeros(
        (length, model.num_classes))  # prob. over the negative region.
    masks_pred = []
    t0 = dt.datetime.now()

    with torch.no_grad():
        for i, (data, mask, label) in tqdm.tqdm(enumerate(dataloader),
                                                ncols=80,
                                                total=length):
            # TODO: FUTURE allow parallel validation over multiple GPUs with samples with different sizes. Make the
            #  batch a list for validation for instance. It does not make sense to validate sample by sample while
            #  you are able to run multiple forwards at once. An alternative is to create a particular dataparallel
            #  that loops over a list instead of the batch-size dim.
            reproducibility.force_seed(int(os.environ["MYSEED"]) + epoch)

            assert data.size(
            )[0] == 1, "Expected a batch size of 1. Found `{}`  .... [NOT OK]".format(
                data.size()[0])

            labels[i] = label.item()  # batch size 1.
            data = data.to(device)
            label = label.to(device)

            mask_t = [m.to(device) for m in mask]

            # In validation, we do not need reproducibility since everything is expected to be deterministic. Plus,
            # we use only one gpu since the batch size is 1.
            output = model(x=data, seed=None)  # --> out_pos, out_neg, masks

            out_pos = output[0][0][
                0]  # scores: take the first element of the batch.
            out_neg = output[1][0][
                0]  # scores: take the first element of the batch.
            assert out_pos.ndimension() == 1, "We expected only 1 dim. We found {}. Make sure you are using abatch " \
                                              "size of 1. .... [NOT OK]".format(out_pos.ndimension())

            pred_label = out_pos.argmax()
            predictions[i] = int(pred_label.item())

            # Probabilities over the classes for the positive and the negative regions.
            scores_pos = softmax(out_pos.cpu().detach().numpy())
            scores_neg = softmax(out_neg.cpu().detach().numpy())
            probs_pos[i] = scores_pos
            probs_neg[i] = scores_neg
            probs[i] = scores_pos[predictions[i]]
            mask_pred = torch.squeeze(
                output[2]).cpu().detach().numpy()  # predicted mask.
            # check sizes of the mask: (h, w) is the true mask, (hp, wp) is the predicted one.
            _, h, w = mask_t[0].size()
            hp, wp = mask_pred.shape

            if dataset.dataset_name in [
                    "Caltech-UCSD-Birds-200-2011", "Oxford-flowers-102"
            ]:
                # Remove the padding if there was any. (the only one allowed: force_div_32)
                assert dataset.padding_size is None, "dataset.padding_size is supposed to be None. We do not support" \
                                                     "padding of this type for this dataset."

                # Important: we assume that the dataloader of this dataset is not shuffled to make this access using
                # the index (i) correct. If shuffled, the access is not correct.
                w_mask_no_pad_forced, h_mask_no_pad_forced = dataset.original_images_size[
                    i]

                if dataset.force_div_32:
                    # Find the size of the mask without padding.
                    w_up, h_up = dataset.get_upscaled_dims(
                        w_mask_no_pad_forced, h_mask_no_pad_forced,
                        dataset.up_scale_small_dim_to)
                    # Remove the padded parts by cropping at the center.
                    mask_pred = mask_pred[int(hp / 2) -
                                          int(h_up / 2):int(hp / 2) +
                                          int(h_up / 2) + (h_up % 2),
                                          int(wp / 2) -
                                          int(w_up / 2):int(wp / 2) +
                                          int(w_up / 2) + (w_up % 2)]
                    assert mask_pred.shape[0] == h_up, "h_up={}, mask_pred.shape[0]={}. Expected to be the same." \
                                                       "[Not OK]".format(h_up, mask_pred.shape[0])
                    assert mask_pred.shape[1] == w_up, "w_up={}, mask_pred.shape[1]={}. Expected to be the same." \
                                                       "[Not OK]".format(w_up, mask_pred.shape[1])
                # Now, rescale the predicted mask to the original size of the image. We use
                # torch.nn.functional.interpolate.
                msk_pred_torch = torch.from_numpy(mask_pred).view(
                    1, 1, mask_pred.shape[0], mask_pred.shape[1])
                mask_pred = F.interpolate(
                    msk_pred_torch,
                    size=(h_mask_no_pad_forced, w_mask_no_pad_forced),
                    mode="bilinear",
                    align_corners=True).squeeze().numpy()

                # Now get the correct sizes:
                hp, wp = mask_pred.shape
                assert hp == h, "hp={}, h={} are supposed to be the same! .... [NOT OK]".format(
                    hp, h)
                assert wp == w, "wp={}, w={} are supposed to be the same! .... [NOT OK]".format(
                    wp, w)

            if (h != hp) or (
                    w !=
                    wp):  # This means that we have padded the input image.
                # We crop the predicted mask in the center.
                mask_pred = mask_pred[int(hp / 2) - int(h / 2):int(hp / 2) +
                                      int(h / 2) + (h % 2),
                                      int(wp / 2) - int(w / 2):int(wp / 2) +
                                      int(w / 2) + (w % 2)]
            masks_pred.append(mask_pred)

            loss = criterion(output, label, mask_t)
            t_l, l_p, l_n, l_c_s = loss

            total_loss.append(t_l.item())
            loss_pos.append(l_p.item())
            loss_neg.append(l_n.item())
            loss_class_seg.append(l_c_s.item())
            # We no longer compute the dice in the eval loss, since masks need some transformations before being
            # ready for dice computations. Such transformations are dataset-dependent and we do not want to crowd
            # the eval loss. It must be clean. Now, Dice is computed over CPU using a dice function.

            # Binarize the true mask (x1) and the predicted mask (x2) at 0.5.
            x1 = ((np.ravel(mask[0]) >= 0.5) * 1.).astype(np.float32)
            x2 = ((np.ravel(mask_pred) >= 0.5) * 1.).astype(np.float32)
            # Since F1 and Dice index are the same over binary data, and, for computation time reasons (Dice index
            # is way faster than F1 in terms of speed), we decided to call the Dice function.
            # Note: in practice, there may be a difference in terms of precision (e.g., 1e-7).
            f1pos.append(compute_dice_index(x1, x2))
            f1neg.append(compute_dice_index(1. - x1, 1. - x2))

            errors.append((pred_label != label).item())

    if callback:
        callback.scalar('Val_loss', epoch + 1, total_loss.avg)
        callback.scalar('Val_error', epoch + 1, errors.avg)

    to_write = ">>>>>>>>>>>>>>>>>> Total L.avg: {:.5f}, Pos.L.avg: {:.5f}, Neg.L.avg: {:.5f}, " \
               "Cl.Seg.L.avg: {:.5f}, F1+: {:.5f}, F1-: {:.5f}, Error.avg: {:.2f}, t:{}, Eval {} epoch {:>2d}, " \
               "".format(
                total_loss.avg, loss_pos.avg, loss_neg.avg, loss_class_seg.avg,
                f1pos.avg, f1neg.avg, errors.avg * 100, dt.datetime.now() - t0, name_set, epoch
                )
    print(to_write)
    if log_file:
        log(log_file, to_write)

    # Update stats: the losses are appended per sample, while the errors and the F1 values are appended as one
    # mean (in percent) per call.
    stats["total_loss"] = np.append(stats["total_loss"],
                                    np.array(total_loss.values))
    stats["loss_pos"] = np.append(stats["loss_pos"], np.array(loss_pos.values))
    stats["loss_neg"] = np.append(stats["loss_neg"], np.array(loss_neg.values))
    stats["loss_class_seg"] = np.append(stats["loss_class_seg"],
                                        np.array(loss_class_seg.values))
    stats["errors"] = np.append(stats["errors"],
                                np.array(errors.values).mean() * 100)
    stats["f1pos"] = np.append(stats["f1pos"],
                               np.array(f1pos.values).mean() * 100)
    stats["f1neg"] = np.append(stats["f1neg"],
                               np.array(f1neg.values).mean() * 100)
    pred = {
        "predictions": predictions,
        "labels": labels,
        "probs": probs,
        "masks": masks_pred,
        "probs_pos": probs_pos,
        "probs_neg": probs_neg
    }

    # Collect stats from only this epoch. (can be useful to plot distributions since we lose the actual stats due to
    # the above update!!!!)
    stats_now = {
        "total_loss": np.array(total_loss.values),
        "loss_pos": np.array(loss_pos.values),
        "loss_neg": np.array(loss_neg.values),
        "loss_class_seg": np.array(loss_class_seg.values),
        "errors": np.array(errors.values).mean() * 100,
        "f1pos": np.array(f1pos.values),
        "f1neg": np.array(f1neg.values)
    }

    return stats, stats_now, pred
Example #10
0
def _make_thread_seeds(allow_multigpus, nbr_gpus, device, seed):
    """
    Prepare the per-GPU seeds and CUDA PRNG states that are handed to the
    model so that its random operations are reproducible.

    :param allow_multigpus: bool. If True, we are in multiGPU mode.
    :param nbr_gpus: int. Number of GPUs announced by the caller.
    :param device: device on which the generated seeds are placed.
    :param seed: int. Seed that is re-forced before/after every random
        operation done here.
    :return: tuple (seeds_threads, prngs_cuda).
        Single-GPU mode: (None, None).
        MultiGPU mode: `seeds_threads` is a tensor of `nbr_gpus` integers
        drawn in [0, 2**32); `prngs_cuda` is the stack of the CUDA PRNG
        states obtained after seeding with each of those integers.
    """
    if not allow_multigpus:
        if "CC_CLUSTER" in os.environ:  # TODO: crack in optimal code.
            assert nbr_gpus <= 1, "Something wrong. You deactivated multigpu mode, but we find {} GPUs. " \
                                  "This will " \
                                  "not guarantee reproducibility. We do not know why you did that. Exiting " \
                                  ".... [NOT OK]".format(nbr_gpus)
        return None, None

    assert nbr_gpus > 1, "Something is wrong. You asked for multigpu mode. But, we found {} GPUs. Exiting " \
                         ".... [NOT OK]".format(nbr_gpus)
    # The seeds are generated randomly before calling the threads.
    reproducibility.force_seed(seed)  # armor.
    seeds_threads = torch.randint(0,
                                  np.iinfo(np.uint32).max + 1,
                                  (nbr_gpus, )).to(device)
    reproducibility.force_seed(seed)  # armor.
    prngs_cuda = []
    # Create different prng states of cuda before forking.
    for thread_seed in seeds_threads:
        # get the corresponding state of the cuda prng with respect to the
        # seed.
        torch.manual_seed(thread_seed)
        torch.cuda.manual_seed(thread_seed)
        prngs_cuda.append(torch.cuda.get_rng_state())
    reproducibility.force_seed(seed)  # armor.

    if prngs_cuda:  # TODO: crack in optimal code.
        prngs_cuda = torch.stack(prngs_cuda)
    return seeds_threads, prngs_cuda


def train_one_epoch(model,
                    optimizer,
                    dataloader,
                    criterion,
                    device,
                    tr_stats,
                    epoch=0,
                    callback=None,
                    log_file=None,
                    ALLOW_MULTIGPUS=False,
                    NBRGPUS=1):
    """
    Perform one epoch of training.

    The global PRNGs are re-seeded (`reproducibility.force_seed`) before and
    after every operation of the optimization step. The seed is
    int(os.environ["MYSEED"]) + epoch + minibatch index.

    :param model: the model to train. It is called directly (no erasing,
        `model.nbr_times_erase == 0`) or through `code="segment"`,
        `code="get_mask_xpos_xneg"` and `code="classify"` (erasing mode).
    :param optimizer: optimizer; `step()` is called once per minibatch.
    :param dataloader: yields tuples (data, masks, labels). `masks` is not
        used for training.
    :param criterion: training loss. Called as `criterion(output, labels)`
        and returns (total loss, positive loss, negative loss). It also
        provides `loss_class_head_seg`, `loss_class_head_seg_red_none` and
        `softmax`.
    :param device: device where the data and the labels are moved.
    :param tr_stats: dict of arrays with the keys "total_loss", "loss_pos",
        "loss_neg", "loss_class_seg" and "errors". The per-minibatch values
        of this epoch are appended to it.
    :param epoch: int. Index of the current epoch.
    :param callback: optional object with an attribute `fre` (plot
        frequency, in minibatches) and a method `scalar(name, x, value)`.
    :param log_file: optional log file to which the epoch summary is written.
    :param ALLOW_MULTIGPUS: bool. If True, we are in multiGPU mode.
    :param NBRGPUS: int. Number of GPUs. Must be > 1 in multiGPU mode.
    :return: tr_stats, updated.
    """
    model.train()

    total_loss, loss_pos, loss_neg = AverageMeter(), AverageMeter(
    ), AverageMeter()
    loss_class_seg, errors = AverageMeter(), AverageMeter()

    length = len(dataloader)
    t0 = dt.datetime.now()

    for i, (data, masks, labels) in tqdm.tqdm(enumerate(dataloader),
                                              ncols=80,
                                              total=length):
        epoch_seed = int(os.environ["MYSEED"]) + epoch
        reproducibility.force_seed(epoch_seed)
        # Seed used to armor every operation of this minibatch.
        step_seed = epoch_seed + i

        data = data.to(device)
        labels = labels.to(device)

        model.zero_grad()

        t_l, l_p, l_n, l_c_s = 0., 0., 0., 0.

        # Optimization:
        if model.nbr_times_erase == 0:  # no erasing.
            seeds_threads, prngs_cuda = _make_thread_seeds(
                ALLOW_MULTIGPUS, NBRGPUS, device, step_seed)

            reproducibility.force_seed(step_seed)  # armor.
            output = model(x=data, seed=seeds_threads,
                           prngs_cuda=prngs_cuda)  # --> out_pos, out_neg,
            reproducibility.force_seed(step_seed)  # armor.
            # masks
            _, _, _, scores_seg, _ = output
            reproducibility.force_seed(step_seed)  # armor.
            l_c_s = criterion.loss_class_head_seg(scores_seg, labels)
            reproducibility.force_seed(step_seed)  # armor.

            reproducibility.force_seed(step_seed)  # armor.
            loss = criterion(output, labels)
            reproducibility.force_seed(step_seed)  # armor.
            t_l, l_p, l_n = loss
            t_l = t_l + l_c_s
            reproducibility.force_seed(step_seed)  # armor.
            t_l.backward()
            reproducibility.force_seed(step_seed)  # armor.
        else:  # we need to erase some times.
            # Compute the cumulative mask.
            l_c_s = 0.
            l_c_s_per_sample = None
            m_pos = None
            # Untouched copy of the input: `data` is erased progressively in
            # the loop below.
            data_safe = data.clone()
            history = torch.ones_like(labels).type(
                torch.float)  # init. history tracker coefs. to 1.
            # if the model predicts the wrong label, we set forever the trust
            # for this sample to 0.
            er = 0
            while history.sum() > 0:
                # if we exceed the maximum, stop looping.
                if er >= model.nbr_times_erase:
                    break

                seeds_threads, prngs_cuda = _make_thread_seeds(
                    ALLOW_MULTIGPUS, NBRGPUS, device, step_seed)

                reproducibility.force_seed(step_seed)  # armor.
                mask, scores_seg, _ = model(
                    x=data,
                    code="segment",
                    seed=seeds_threads,
                    prngs_cuda=prngs_cuda)  # model.segment(data) mask is
                # detached! (mask is continuous)
                reproducibility.force_seed(step_seed)  # armor.
                mask, _, _ = model(x=data,
                                   code="get_mask_xpos_xneg",
                                   mask_c=mask,
                                   seed=seeds_threads,
                                   prngs_cuda=prngs_cuda)  # mask = M+.
                reproducibility.force_seed(step_seed)  # armor.

                probs_seg = criterion.softmax(scores_seg)
                reproducibility.force_seed(step_seed)  # armor.

                l_c_s_tmp = criterion.loss_class_head_seg_red_none(
                    scores_seg, labels)
                reproducibility.force_seed(step_seed)  # armor.
                l_c_s_tmp.mean().backward()
                reproducibility.force_seed(step_seed)  # armor.

                # avoid maintaining the previous graph. Therefore,
                # cut the dependency.
                l_c_s_tmp = l_c_s_tmp.detach()
                probs_seg = probs_seg.detach()

                l_c_s += l_c_s_tmp.mean()  # for tracking only.
                reproducibility.force_seed(step_seed)  # armor.

                # Update the mask (m_pos: M+): The negative mask is expected
                # to contain all the non-discriminative regions. However, it
                # may still contain some discriminative parts. In order to
                # find them, we apply M- over the input image (erase the
                # found discriminative parts) and try to localize NEW
                # discriminative regions.

                if er == 0:  # if the first time, create the tracking mask.
                    m_pos = torch.zeros_like(mask)
                    # Reference per-sample loss: the loss of the first pass.
                    l_c_s_per_sample = l_c_s_tmp.clone()

                trust = torch.ones_like(labels).type(
                    torch.float)  # init. trust coefs. to 1.
                p_y, pred = torch.max(probs_seg, dim=1)

                # overall trust:
                decay = np.exp(-float(er) / model.sigma_erase)
                # per-sample trust:
                check_loss = (l_c_s_tmp <= l_c_s_per_sample).type(torch.float)
                check_label = (pred == labels).type(torch.float)

                trust *= decay * check_label * check_loss * p_y

                # Update the history
                history = torch.min(check_label, history)

                # Apply the history to the trust:
                trust *= history

                trust = trust.view(trust.size()[0], 1, 1, 1)
                m_pos_tmp = trust * mask
                m_pos = torch.max(m_pos, m_pos_tmp)  # accumulate the masks.
                # Apply the cumulative negative mask over the image. Samples
                # with a zero trust are left as they are.
                data = data * (1 - m_pos) * (trust != 0).type(
                    torch.float) + data * (trust == 0).type(torch.float)
                er += 1
                reproducibility.force_seed(step_seed)  # armor.

            l_c_s /= (model.nbr_times_erase + 1)

            # Now: m_neg contains the smallest negative area == largest
            # positive area.
            # compute x_pos, x_neg
            m_neg = 1 - m_pos
            x_neg = data_safe * m_neg
            x_pos = data_safe * m_pos

            # Classify
            seeds_threads, prngs_cuda = _make_thread_seeds(
                ALLOW_MULTIGPUS, NBRGPUS, device, step_seed)

            reproducibility.force_seed(step_seed)  # armor.
            out_pos = model(x=x_pos,
                            code="classify",
                            seed=seeds_threads,
                            prngs_cuda=prngs_cuda)  # model.classify(x_pos)
            reproducibility.force_seed(step_seed)  # armor.
            out_neg = model(x=x_neg,
                            code="classify",
                            seed=seeds_threads,
                            prngs_cuda=prngs_cuda)  # model.classify(x_neg)
            reproducibility.force_seed(step_seed)  # armor.

            output = out_pos, out_neg, None, None, None
            reproducibility.force_seed(step_seed)  # armor.
            loss = criterion(output, labels)

            t_l, l_p, l_n = loss
            reproducibility.force_seed(step_seed)  # armor.
            t_l.backward()
            # Added after backward(): the segmentation-head loss has already
            # been back-propagated inside the erasing loop. Here, it is for
            # tracking only.
            t_l += l_c_s
            reproducibility.force_seed(step_seed)  # armor.

        # Update params.
        reproducibility.force_seed(step_seed)  # armor.
        optimizer.step()
        reproducibility.force_seed(step_seed)  # armor.
        # End optimization.

        total_loss.append(t_l.item())
        loss_pos.append(l_p.item())
        loss_neg.append(l_n.item())
        loss_class_seg.append(l_c_s.item())

        errors.append(
            (output[0][0].argmax(dim=1) != labels).float().mean().item() *
            100.)  # error over the minibatch.

        if callback and ((i + 1) % callback.fre == 0 or (i + 1) == length):
            callback.scalar("Train_loss", i / length + epoch,
                            total_loss.last_avg)
            callback.scalar("Train_error", i / length + epoch,
                            errors.last_avg)

    to_write = "Train epoch {:>2d}: Total L.avg: {:.5f}, Pos.L.avg: {:.5f}, Neg.L.avg: {:.5f}, " \
               "Cl.Seg.L.avg: {:.5f}, LR {}, t:{}".format(
                epoch, total_loss.avg, loss_pos.avg, loss_neg.avg, loss_class_seg.avg,
                ['{:.2e}'.format(group["lr"]) for group in optimizer.param_groups], dt.datetime.now() - t0
                )
    print(to_write)
    if log_file:
        log(log_file, to_write)

    # Update stats:
    tr_stats["total_loss"] = np.append(tr_stats["total_loss"],
                                       np.array(total_loss.values))
    tr_stats["loss_pos"] = np.append(tr_stats["loss_pos"],
                                     np.array(loss_pos.values))
    tr_stats["loss_neg"] = np.append(tr_stats["loss_neg"],
                                     np.array(loss_neg.values))
    tr_stats["loss_class_seg"] = np.append(tr_stats["loss_class_seg"],
                                           np.array(loss_class_seg.values))
    tr_stats["errors"] = np.append(tr_stats["errors"], np.array(errors.values))
    return tr_stats
def validate(model,
             dataset,
             dataloader,
             criterion,
             device,
             stats,
             args,
             folderout=None,
             epoch=0,
             log_file=None,
             name_set="",
             store_on_disc=False,
             store_imgs=False,
             final_mode=False,
             seg_threshold=None
             ):
    """
    Perform a validation over the validation set. Assumes a batch size of 1.
    (images do not have the same size,
    so we can't stack them in one tensor).
    Validation samples may be large to fit all in the GPU at once.

    Note: criterion is deppmil.criteria.TotalLossEval().

    :param model: the model to evaluate. Called as
        `model(x=data, glabels=labels, seed=None)` and expected to return
        (scores_pos, scores_neg, mask_pred, sc_cl_se, cams).
    :param dataset: the dataset; only forwarded to `store_pred_img`.
    :param dataloader: yields tuples (data, mask, label) with a batch size
        of 1.
    :param criterion: evaluation loss; returns
        (total loss, positive loss, negative loss, segmentation loss).
    :param device: device where the data, labels and masks are moved.
    :param stats: None, or a dict of lists with the keys "total_loss",
        "loss_pos", "loss_neg", "acc", "f1pos", "f1neg" and "miou" to which
        the averages of this evaluation are appended.
    :param args: configuration. The fields read here are `final_thres`,
        `padding_size`, `dataset` and `model['model_name']`.
    :param folderout: None, or the folder where to write the outputs.
    :param epoch: int. Index of the current epoch.
    :param log_file: optional log file to which the summary is written.
    :param name_set: str. Name of the evaluated set (used in messages and
        file names).
    :param store_on_disc: bool. If True (and `folderout` is not None), store
        the per-sample binary masks and the averaged metrics on disc.
    :param store_imgs: bool. If True (with `store_on_disc`), also store the
        prediction images.
    :param final_mode: bool. If True, announce the final metrics and write
        the average forward time. Requires `folderout`.
    :param seg_threshold: None, or the threshold used to binarize the
        predicted masks. If None, `args.final_thres` is used.
    :return: `stats` when nothing is stored on disc. None when
        `folderout` is not None and `store_on_disc` is True.
    """
    model.eval()
    fl_thres = seg_threshold if seg_threshold is not None else args.final_thres

    metrics = Metrics(threshold=fl_thres).to(device)
    dice_f = Dice()
    iou_f = IOU()
    metrics.eval()

    f1pos_, f1neg_, miou_, acc_ = 0., 0., 0., 0.
    l_f1pos_ = []
    l_f1neg_ = []
    l_miou_ = []

    cnt = 0.
    total_loss_ = 0.
    loss_pos_ = 0.
    loss_neg_ = 0.

    # camelyon16
    sizes_m = {
        'm_pred': [],  # metastatic.
        'm_true': [],  # metastatic
        'n_pred': [],  # normal
        'n_true': []  # normal
    }

    mask_fd = None
    name_fd_masks = "masks"  # where to store the predictions.
    name_fd_masks_bin = "masks_bin"  # where to store the bin masks and al.
    if folderout is not None:
        mask_fd = join(folderout, name_fd_masks)
        bin_masks_fd = join(folderout, name_fd_masks_bin)
        if not os.path.exists(mask_fd):
            os.makedirs(mask_fd)

        if not os.path.exists(bin_masks_fd):
            os.makedirs(bin_masks_fd)

    length = len(dataloader)
    t0 = dt.datetime.now()  # start of the whole evaluation.
    myseed = int(os.environ["MYSEED"])
    masks_sizes = []
    cancer_scores = []
    glabels = []

    avg_forward_t = dt.timedelta(0)

    with torch.no_grad():
        for i, (data, mask, label) in tqdm.tqdm(
                enumerate(dataloader), ncols=80, total=length):

            reproducibility.force_seed(myseed + epoch + 1)

            msg = "Expected a batch size of 1. Found `{}`  .... " \
                  "[NOT OK]".format(data.size()[0])
            assert data.size()[0] == 1, msg
            bsz = data.size()[0]

            data = data.to(device)
            labels = label.to(device)
            mask = mask[0].clone()
            mask_t = mask.unsqueeze(0).to(device)
            assert mask_t.ndim == 4, "ndim = {} must be 4.".format(mask_t.ndim)

            # In validation, we do not need reproducibility since everything
            # is expected to be deterministic. Plus,
            # we use only one gpu since the batch size is 1.
            # The forward pass is timed with its own variable so that `t0`
            # keeps the start time of the whole evaluation.
            t_forward = dt.datetime.now()
            scores_pos, scores_neg, mask_pred, sc_cl_se, cams = model(
                x=data, glabels=labels, seed=None)
            delta_t = dt.datetime.now() - t_forward
            avg_forward_t += delta_t

            t_loss, l_p, l_n, l_seg = criterion(
                scores_pos,
                sc_cl_se,
                labels,
                mask_pred,
                scores_neg,
                cams
            )

            mask_pred = mask_pred.squeeze()
            # check sizes of the mask: the unpacking below fails unless the
            # squeezed predicted mask has exactly two dimensions.
            _, _, h, w = mask_t.shape
            hp, wp = mask_pred.shape

            assert args.padding_size in [None, 'None', 0.0], args.padding_size
            # Resize the predicted mask to the size of the true mask.
            mask_pred = F.interpolate(
                input=mask_pred.unsqueeze(0).unsqueeze(0),
                size=(h, w),
                mode='bilinear',
                align_corners=True
            ).squeeze()

            mask_pred = mask_pred.unsqueeze(0).unsqueeze(0)

            if args.dataset == constants.GLAS:
                cl_scores = scores_pos
                acc, dice_forg, dice_back, miou = metrics(
                    scores=cl_scores,
                    labels=labels,
                    masks_pred=mask_pred.contiguous().view(bsz, -1),
                    masks_trg=mask_t.contiguous().view(bsz, -1),
                    avg=False
                )
                f1pos_ += dice_forg
                f1neg_ += dice_back
                miou_ += miou

            elif args.dataset == constants.CAMELYON16P512:
                cl_scores = sc_cl_se  # sc_cl_se, scores_pos

                plabels = metrics.predict_label(cl_scores)
                acc = ((plabels - labels) == 0.).float().sum()
                assert data.size()[0] == 1
                ppixels = metrics.get_binary_mask(
                    mask_pred.contiguous().view(bsz, -1), metrics.threshold)
                masks_trg = mask_t.contiguous().view(bsz, -1)
                dice_forg = 0.
                dice_back = 0.
                miou = 0.

                glabels.append(labels.item())
                masks_sizes.append(ppixels.mean().item())

                if labels.item() == 1:
                    sizes_m['m_pred'].append(ppixels.mean().item() * 100.)
                    sizes_m['m_true'].append(masks_trg.mean().item() * 100.)
                elif labels.item() == 0:
                    sizes_m['n_pred'].append(ppixels.mean().item() * 100.)
                    sizes_m['n_true'].append(0.0)
                else:
                    raise ValueError

                cancer_scores.append(
                    torch.softmax(cl_scores, dim=1)[0, 1].item())

                # Foreground dice (and mIOU) only when the true mask has
                # foreground pixels.
                if masks_trg.sum() > 0:
                    l_f1pos_.append(dice_f(ppixels.view(bsz, -1),
                                    masks_trg.view(bsz, -1)))
                    dice_forg = l_f1pos_[-1]

                    if (1. - masks_trg).sum() > 0:
                        tmp = iou_f(ppixels.view(1, -1),
                                    masks_trg.view(1, -1))
                        tmp = tmp + iou_f(1. - ppixels.view(1, -1),
                                          1. - masks_trg.view(1, -1))
                        l_miou_.append(tmp / 2.)
                        miou = l_miou_[-1]

                # Background dice only when the true mask has background
                # pixels.
                if (1. - masks_trg).sum() > 0:
                    l_f1neg_.append(dice_f(1. - ppixels.view(bsz, -1),
                                           1. - masks_trg.view(bsz, -1)))
                    dice_back = l_f1neg_[-1]

            else:
                raise NotImplementedError

            acc_ += acc
            cnt += bsz
            total_loss_ += t_loss.item()
            loss_pos_ += l_p.item()
            loss_neg_ += l_n.item()

            if (folderout is not None) and store_on_disc:
                # binary mask
                bin_pred_mask = metrics.get_binary_mask(mask_pred).squeeze()
                # `bool` instead of the alias `np.bool`, which no longer
                # exists in recent numpy versions.
                bin_pred_mask = bin_pred_mask.cpu().detach().numpy().astype(bool)
                if args.dataset == constants.GLAS:
                    to_save = {
                        "bin_pred_mask": bin_pred_mask,
                        "continuous_mask": mask_pred.cpu().detach().numpy(),
                        "dice_forg": dice_forg,
                        "dice_back": dice_back,
                        "i": i
                    }
                elif args.dataset == constants.CAMELYON16P512:
                    to_save = {
                        "bin_pred_mask": bin_pred_mask,
                        "dice_forg": dice_forg,
                        "dice_back": dice_back,
                        "i": i
                    }
                else:
                    raise NotImplementedError

                with open(join(bin_masks_fd, "{}.pkl".format(i)), "wb") as fbin:
                    pkl.dump(to_save, fbin, protocol=pkl.HIGHEST_PROTOCOL)

            if (folderout is not None) and store_imgs and store_on_disc:
                pred_label = int(cl_scores.argmax().item())
                probs = softmax(cl_scores.cpu().detach().numpy())
                prob = float(probs[0, pred_label])

                store_pred_img(i,
                               dataset,
                               bin_pred_mask * 1.,
                               mask_pred.squeeze().cpu().detach().numpy(),
                               dice_forg,
                               dice_back,
                               prob,
                               pred_label,
                               args,
                               mask_fd,
                               )

    # avg
    acc_ *= (100. / float(cnt))
    # The log files live in `folderout`: nothing to write without it.
    if args.dataset == constants.CAMELYON16P512 and folderout is not None:
        with open(join(folderout,
                       'log-cl-{}-final-{}.txt'.format(epoch,
                                                       final_mode)),
                  'a') as fz:
            fz.write("\nClassification accuracy: {} (%)".format(acc_))

        with open(join(folderout,
                       'log-seg-{}-final-{}.txt'.format(
                           epoch, final_mode)), 'a') as fz:
            fz.write("\nClassification accuracy: {} (%)".format(acc_))

    total_loss_ /= float(cnt)
    loss_pos_ /= float(cnt)
    loss_neg_ /= float(cnt)
    avg_forward_t /= float(cnt)

    if args.dataset == constants.GLAS:
        f1pos_ *= (100. / float(cnt))
        f1neg_ *= (100. / float(cnt))
        miou_ *= (100. / float(cnt))
    elif args.dataset == constants.CAMELYON16P512:
        # Average only over the samples for which the metric was computed.
        f1pos_ = 0.
        f1neg_ = 0.
        miou_ = 0.
        if l_f1pos_:
            f1pos_ = torch.stack(l_f1pos_).mean() * 100.
        if l_f1neg_:
            f1neg_ = torch.stack(l_f1neg_).mean() * 100.
        if l_miou_:
            miou_ = torch.stack(l_miou_).mean() * 100.
    else:
        raise NotImplementedError

    if stats is not None:
        stats["total_loss"].append(total_loss_)
        stats["loss_pos"].append(loss_pos_)
        stats["loss_neg"].append(loss_neg_)
        stats["acc"].append(acc_)
        stats["f1pos"].append(f1pos_)
        stats["f1neg"].append(f1neg_)
        stats['miou'].append(miou_)

    to_write = "EVAL ({}): TLoss: {:.2f}, L+: {:.2f}, L-: {:.2f}, " \
               "F1+: {:.2f}%, F1-: {:.2f}%, MIOU: {:.2f}%, ACC: {:.2f}%, " \
               "t:{}, epoch {:>2d}.".format(
        name_set,
        total_loss_,
        loss_pos_,
        loss_neg_,
        f1pos_,
        f1neg_,
        miou_,
        acc_,
        dt.datetime.now() - t0,
        epoch
        )

    print(to_write)
    if log_file:
        log(log_file, to_write)

    if final_mode:
        assert folderout is not None
        msg = "EVAL {}: \n".format(name_set)
        msg += "ACC {}% \n".format(acc_)
        msg += "F1+ {}% \n".format(f1pos_)
        msg += "F1- {}% \n".format(f1neg_)
        msg += "MIOU {}% \n".format(miou_)
        announce_msg(msg)
        if log_file:
            log(log_file, msg)

        with open(join(folderout, 'avg_forward_time.txt'), 'w') as fend:
            fend.write("Model: {}. \n Average forward time (eval mode): "
                       " {}.".format(args.model['model_name'], avg_forward_t
            ))

    if (folderout is not None) and store_on_disc:
        pred = {
            "total_loss": total_loss_,
            "loss_pos": loss_pos_,
            "loss_neg": loss_neg_,
            "acc": acc_,
            "f1pos": f1pos_,
            "f1neg": f1neg_,
            "miou_": miou_
        }
        with open(
                join(folderout, "pred--{}.pkl".format(name_set)), "wb") as fout:
            pkl.dump(pred, fout, protocol=pkl.HIGHEST_PROTOCOL)

        # compress. delete folder.
        # The commands are run from within `folderout` with explicit
        # argument lists (no shell): this works for relative paths and for
        # paths that contain spaces. The folder of the continuous masks
        # (`name_fd_masks`) is kept as it is.
        archive_cmds = [
            ["tar", "-cf", "{}.tar.gz".format(name_fd_masks_bin),
             name_fd_masks_bin],
            ["rm", "-r", name_fd_masks_bin]
        ]
        cmdx = "\n".join(" ".join(cmd) for cmd in archive_cmds)
        print("Running bash-cmds: \n{}".format(cmdx))
        try:
            # Stop at the first failure: the folder is deleted only if it
            # has been archived.
            for cmd in archive_cmds:
                subprocess.run(cmd, cwd=folderout, check=True)
        except (subprocess.SubprocessError, OSError) as e:
            print("Failed to run: {}. Error: {}".format(cmdx, e))

    else:
        return stats