Example #1
def init_seed(seed=None):
    """
    * Initialize the seed.
    * Set the seed of the relevant modules for reproducibility.

    Note:

    While this attempts to ensure reproducibility, it does not offer an
    absolute guarantee: results may only match up to some numerical
    precision, and may still diverge when extremely small differences are
    amplified.

    See:

    https://pytorch.org/docs/stable/notes/randomness.html
    https://stackoverflow.com/questions/50744565/
    how-to-handle-non-determinism-when-training-on-a-gpu

    :param seed: int, a seed. Default is None: use the default seed (0).
    :return:
    """
    if seed is None:
        seed = get_seed()
    else:
        os.environ["MYSEED"] = str(seed)
        announce_msg("SEED: {} ".format(os.environ["MYSEED"]))

    check_if_allow_multgpu_mode()
    reset_seed(seed)
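
# The helpers get_seed(), reset_seed() and check_if_allow_multgpu_mode() are
# project-local and not shown here. As a hedged illustration only, a
# reset_seed-style helper typically applies the settings from the PyTorch
# reproducibility notes linked above; the sketch below is an assumption, not
# the project's actual implementation.
import os
import random

import numpy as np
import torch


def reset_seed_sketch(seed):
    """Minimal sketch: seed Python, NumPy and PyTorch, and make cuDNN
    deterministic (at some cost in speed)."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)            # seeds the CPU RNG (and all GPUs in recent PyTorch)
    torch.cuda.manual_seed_all(seed)   # explicit for older versions
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False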
Example #2
    def __init__(self, inplans, modalities, num_classes, kmax=0.5, kmin=None,
                 alpha=0.6, dropout=0.0, tau=1.):
        """

        :param inplans:
        :param modalities:
        :param num_classes:
        :param kmax:
        :param kmin:
        :param alpha:
        :param dropout:
        :param tau: float > 0.
        """
        announce_msg("Using Poisson hard-wired output.")
        super(PoissonHead, self).__init__(
            inplans=inplans, modalities=modalities, num_classes=num_classes,
            kmax=kmax, kmin=kmin, alpha=alpha, dropout=dropout)

        msg = "`tau` should be float. found {} ...[NOT OK]".format(
            type(tau))
        assert isinstance(tau, float), msg
        msg = "`tau` must be in ]0., inf[. found {} ... [NOT OK]".format(tau)
        assert tau > 0., msg

        self.tau = tau
        self.pool_to_one = nn.Sequential(
            nn.Linear(num_classes, 1, bias=True),
            nn.Softplus()
        )
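
# Hedged illustration (not part of PoissonHead): the Linear + Softplus head
# above maps the per-class scores to a single strictly positive value, which
# is what a Poisson rate parameter requires. The sizes below are made up.
import torch
import torch.nn as nn

num_classes = 4
pool_to_one = nn.Sequential(nn.Linear(num_classes, 1, bias=True), nn.Softplus())

scores = torch.randn(2, num_classes)   # per-class scores for a batch of 2
rate = pool_to_one(scores)             # shape (2, 1)
assert (rate > 0).all()                # Softplus keeps the output positive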
Example #3
    log(results_log, "\n\n ########### Results #########\n\n")

    callback = None

    # ==========================================================
    # Data transformations: on PIL.Image.Image and torch.tensor.
    # ==========================================================

    train_transform_img = get_train_transforms_img(args)
    transform_tensor = get_transforms_tensor(args)

    # =======================================================================================================
    # Datasets: create folds, load csv, preprocess files and save on disc, load datasets: train, valid, test.
    # =======================================================================================================

    announce_msg("SPLIT: {} \t FOLD: {}".format(args.split, args.fold))

    relative_fold_path = join(args.fold_folder, args.dataset,
                              "split_" + str(args.split),
                              "fold_" + str(args.fold))
    if isinstance(args.name_classes, str):  # path
        path_classes = join(relative_fold_path, args.name_classes)
        assert os.path.isfile(
            path_classes), "File {} does not exist .... [NOT OK]".format(
                path_classes)
        with open(path_classes, "r") as fin:
            args.name_classes = yaml.safe_load(fin)

    train_csv = join(
        relative_fold_path,
        "train_s_" + str(args.split) + "_f_" + str(args.fold) + ".csv")
def validate(model,
             dataset,
             dataloader,
             criterion,
             device,
             stats,
             args,
             folderout=None,
             epoch=0,
             log_file=None,
             name_set="",
             store_on_disc=False,
             store_imgs=False
             ):
    """
    Perform a validation over the validation set. Assumes a batch size of 1.
    (images do not have the same size,
    so we can't stack them in one tensor).
    Validation samples may be too large to all fit in the GPU at once.

    Note: criterion is deepmil.criteria.TotalLossEval().
    """
    model.eval()
    metrics = Metrics(threshold=args.final_thres).to(device)
    metrics.eval()

    f1pos_, f1neg_, miou_, acc_ = 0., 0., 0., 0.
    cnt = 0.
    total_loss_ = 0.
    loss_pos_ = 0.
    loss_neg_ = 0.

    mask_fd = None
    name_fd_masks = "masks"  # where to store the predictions.
    name_fd_masks_bin = "masks_bin"  # where to store the bin masks and al.
    if folderout is not None:
        mask_fd = join(folderout, name_fd_masks)
        bin_masks_fd = join(folderout, name_fd_masks_bin)
        if not os.path.exists(mask_fd):
            os.makedirs(mask_fd)

        if not os.path.exists(bin_masks_fd):
            os.makedirs(bin_masks_fd)

    length = len(dataloader)
    t0 = dt.datetime.now()
    myseed = int(os.environ["MYSEED"])

    with torch.no_grad():
        for i, (data, mask, label) in tqdm.tqdm(
                enumerate(dataloader), ncols=80, total=length):

            reproducibility.force_seed(myseed + epoch + 1)

            msg = "Expected a batch size of 1. Found `{}`  .... " \
                  "[NOT OK]".format(data.size()[0])
            assert data.size()[0] == 1, msg
            bsz = data.size()[0]

            data = data.to(device)
            labels = label.to(device)
            mask = torch.tensor(mask[0])
            mask_t = mask.unsqueeze(0).to(device)
            assert mask_t.ndim == 4, "ndim = {} must be 4.".format(mask_t.ndim)

            # In validation, we do not need reproducibility since everything
            # is expected to be deterministic. Plus, we use only one GPU
            # since the batch size is 1.
            scores_pos, scores_neg, mask_pred, sc_cl_se = model(
                x=data, seed=None)
            t_loss, l_p, l_n, l_seg = criterion(scores_pos, sc_cl_se, labels,
                                                mask_pred, scores_neg)


            mask_pred = mask_pred.squeeze()
            # check sizes of the mask:
            _, _, h, w = mask_t.shape
            hp, wp = mask_pred.shape

            if (h != hp) or (w != wp):
                # The input image was padded: crop the predicted mask back to
                # the original (h, w) size, centered.
                top = int(hp / 2) - int(h / 2)
                left = int(wp / 2) - int(w / 2)
                mask_pred = mask_pred[top: top + h, left: left + w]

            mask_pred = mask_pred.unsqueeze(0).unsqueeze(0)

            acc, dice_forg, dice_back, miou = metrics(
                scores=scores_pos,
                labels=labels,
                masks_pred=mask_pred.contiguous().view(bsz, -1),
                masks_trg=mask_t.contiguous().view(bsz, -1),
                avg=False
            )

            # tracking
            f1pos_ += dice_forg
            f1neg_ += dice_back
            miou_ += miou
            acc_ += acc
            cnt += bsz
            total_loss_ += t_loss.item()
            loss_pos_ += l_p.item()
            loss_neg_ += l_n.item()

            if (folderout is not None) and store_on_disc:
                # binary mask
                bin_pred_mask = metrics.get_binary_mask(mask_pred).squeeze()
                bin_pred_mask = bin_pred_mask.cpu().detach().numpy().astype(bool)
                to_save = {
                    "bin_pred_mask": bin_pred_mask,
                    "dice_forg": dice_forg,
                    "dice_back": dice_back,
                    "i": i
                }

                with open(join(bin_masks_fd, "{}.pkl".format(i)), "wb") as fbin:
                    pkl.dump(to_save, fbin, protocol=pkl.HIGHEST_PROTOCOL)

            if (folderout is not None) and store_imgs and store_on_disc:
                pred_label = int(scores_pos.argmax().item())
                probs = softmax(scores_pos.cpu().detach().numpy())
                prob = float(probs[0, pred_label])

                store_pred_img(i,
                               dataset,
                               bin_pred_mask * 1.,
                               mask_pred.squeeze().cpu().detach().numpy(),
                               dice_forg,
                               dice_back,
                               prob,
                               pred_label,
                               args,
                               mask_fd,
                               )


    # avg
    total_loss_ /= float(cnt)
    loss_pos_ /= float(cnt)
    loss_neg_ /= float(cnt)
    acc_ *= (100. / float(cnt))
    f1pos_ *= (100. / float(cnt))
    f1neg_ *= (100. / float(cnt))
    miou_ *= (100. / float(cnt))

    if stats is not None:
        stats["total_loss"].append(total_loss_)
        stats["loss_pos"].append(loss_pos_)
        stats["loss_neg"].append(loss_neg_)
        stats["acc"].append(acc_)
        stats["f1pos"].append(f1pos_)
        stats["f1neg"].append(f1neg_)
        stats['miou'].append(miou_)

    to_write = "EVAL ({}): TLoss: {:.2f}, L+: {:.2f}, L-: {:.2f}, " \
               "F1+: {:.2f}%, F1-: {:.2f}%, MIOU: {:.2f}%, ACC: {:.2f}%, " \
               "t:{}, epoch {:>2d}.".format(
        name_set,
        total_loss_,
        loss_pos_,
        loss_neg_,
        f1pos_,
        f1neg_,
        miou_,
        acc_,
        dt.datetime.now() - t0,
        epoch
        )
    print(to_write)
    if log_file:
        log(log_file, to_write)


    if folderout is not None:
        msg = "EVAL {}: \n".format(name_set)
        msg += "ACC {}% \n".format(acc_)
        msg += "F1+ {}% \n".format(f1pos_)
        msg += "F1- {}% \n".format(f1neg_)
        msg += "MIOU {}% \n".format(miou_)
        announce_msg(msg)
        if log_file:
            log(log_file, msg)

    if (folderout is not None) and store_on_disc:
        pred = {
            "total_loss": total_loss_,
            "loss_pos": loss_pos_,
            "loss_neg": loss_neg_,
            "acc": acc_,
            "f1pos": f1pos_,
            "f1neg": f1neg_,
            "miou_": miou_
        }
        with open(
                join(folderout, "pred--{}.pkl".format(name_set)), "wb") as fout:
            pkl.dump(pred, fout, protocol=pkl.HIGHEST_PROTOCOL)

        # compress. delete folder.
        cmdx = [
            "cd {} ".format(mask_fd),
            "cd .. ",
           # "tar -cf {}.tar.gz {}".format(name_fd_masks, name_fd_masks),
           # "rm -r {}".format(name_fd_masks)
        ]

        cmdx += [
            "cd {} ".format(bin_masks_fd),
            "cd .. ",
            "tar -cf {}.tar.gz {}".format(name_fd_masks_bin, name_fd_masks_bin),
            "rm -r {}".format(name_fd_masks_bin)
        ]

        cmdx = " && ".join(cmdx)
        print("Running bash-cmds: \n{}".format(cmdx.replace("&& ", "\n")))
        subprocess.run(cmdx, shell=True, check=True)
    else:
        return stats
    callback = None

    # ==========================================================
    # Data transformations: on PIL.Image.Image and torch.tensor.
    # ==========================================================

    train_transform_img = get_train_transforms_img(args)
    transform_tensor = get_transforms_tensor(args)

    # ==========================================================================
    # Datasets: create folds, load csv, preprocess files and save on disc.
    # load datasets: train, valid, test.
    # ==========================================================================

    announce_msg("SPLIT: {} \t FOLD: {}".format(args.split, args.fold))

    relative_fold_path = join(args.fold_folder, args.dataset,
                              "split_" + str(args.split),
                              "fold_" + str(args.fold))
    if isinstance(args.name_classes, str):  # path
        path_classes = join(relative_fold_path, args.name_classes)
        assert os.path.isfile(path_classes), "File {} does not exist .... " \
                                             "[NOT OK]".format(path_classes)
        with open(path_classes, "r") as fin:
            args.name_classes = yaml.safe_load(fin)
    csvfiles = []
    for subp in ["train_s_", "valid_s_", "test_s_"]:
        csvfiles.append(
            join(relative_fold_path,
                 subp + str(args.split) + "_f_" + str(args.fold) + ".csv"))
Example #6
from deepmil.decision_pooling import WildCatPoolDecision, ClassWisePooling
thread_lock = threading.Lock()  # lock used by threads to protect the
# instructions that cause randomness, making them thread-safe.

import reproducibility

ACTIVATE_SYNC_BN = True
# Override ACTIVATE_SYNC_BN using an environment variable in Bash:
# $ export ACTIVATE_SYNC_BN="True"   ----> Activate
# $ export ACTIVATE_SYNC_BN="False"   ----> Deactivate

if "ACTIVATE_SYNC_BN" in os.environ.keys():
    ACTIVATE_SYNC_BN = (os.environ['ACTIVATE_SYNC_BN'] == "True")

announce_msg("ACTIVATE_SYNC_BN was set to {}".format(ACTIVATE_SYNC_BN))

if check_if_allow_multgpu_mode() and ACTIVATE_SYNC_BN:  # Activate Synch-BN.
    from deepmil.syncbn import nn as NN_Sync_BN
    BatchNorm2d = NN_Sync_BN.BatchNorm2d
    announce_msg("Synchronized BN has been activated. \n"
                 "MultiGPU mode has been activated. {} GPUs".format(torch.cuda.device_count()))
else:
    BatchNorm2d = nn.BatchNorm2d
    if check_if_allow_multgpu_mode():
        announce_msg("Synchronized BN has been deactivated.\n"
                     "MultiGPU mode has been activated. {} GPUs".format(torch.cuda.device_count()))
    else:
        announce_msg("Synchronized BN has been deactivated.\n"
                     "MultiGPU mode has been deactivated. {} GPUs".format(torch.cuda.device_count()))
Example #7
def create_k_folds_csv_bach_part_a(args):
    """
    Create k folds of the BACH (part A) 2018 dataset and store the image paths of each fold in a *.csv file.

    1. The test set is fixed for all the splits/folds.
    2. We do a k-fold split over the remaining data to create the train and validation sets.

    :param args: object, contain the arguments of splitting.
    :return:
    """
    announce_msg("Going to create the  splits, and the k-folds fro BACH (PART A) 2018 .... [OK]")

    rootpath = args.baseurl
    if args.dataset == "bc18bch":
        rootpath = join(rootpath, "ICIAR2018_BACH_Challenge/Photos")
    else:
        raise ValueError("Dataset name {} is unknown.".format(str(args.dataset)))

    samples = glob.glob(join(rootpath, "*", "*." + args.img_extension))
    # Originally, this function was written with 'samples' holding the absolute paths to the files.
    # We then realised that using absolute paths on different platforms leads to non-deterministic folds even
    # with the seed fixed. This is caused by glob.glob(), which returns a list of paths that most likely depends on
    # how the files are stored by the OS. To get rid of this, we keep only a short relative path that is constant
    # across all the hosts where the dataset is saved, then sort the list of paths BEFORE going further. This
    # guarantees that, whatever OS the code runs on, the sorted list is the same.
    samples = sorted([join(*sx.split(os.sep)[-4:]) for sx in samples])

    classes = {key: [s for s in samples if s.split(os.sep)[-2] == key] for key in args.name_classes.keys()}

    all_train = {}
    test_fix = []
    # Shuffle to avoid any bias.
    for key in classes.keys():
        for i in range(1000):
            random.shuffle(classes[key])

        nbr_test = int(len(classes[key]) * args.test_portion)
        test_fix += classes[key][:nbr_test]
        all_train[key] = classes[key][nbr_test:]

    # Test set is ready. Now, we need to do k-fold over the train.

    # Create the splits over the train
    splits = []
    for i in range(args.nbr_splits):
        for t in range(1000):
            for k in all_train.keys():
                random.shuffle(all_train[k])

        splits.append(copy.deepcopy(all_train))

    readme = "csv format:\n" \
             "relative path to the image file.\n" \
             "Example:\n" \
             "ICIAR2018_BACH_Challenge/Photos/Normal/n047.tif\n" \
             "There are four classes: normal, benign, in situ, and invasive.\n" \
             "The class of the sample may be infered from the parent folder of the image (in the example " \
             "above:\n " \
             "Normal:\n" \
             "Normal: class 'normal'\n" \
             "Benign: class 'benign'\n" \
             "InSitu: class 'in situ'\n" \
             "Invasive: class 'invasive'"

    # Create k-folds for each split.
    def create_folds_of_one_class(lsamps, s_tr, s_vl):
        """
        Create k folds from a list of samples of the same class; each fold contains a train and a valid set of
        predefined sizes.

        Samples need to be shuffled beforehand.

        :param lsamps: list of paths to samples of the same class.
        :param s_tr: int, number of samples in the train set.
        :param s_vl: int, number of samples in the valid set.
        :return: list_folds: list of k tuples (tr_set, vl_set), where each element is the list (of str paths)
                 of the samples of the train and valid sets, respectively.
        """
        assert len(lsamps) == s_tr + s_vl, "Something wrong with the provided sizes .... [NOT OK]"

        # chunk the data into chunks of size s_vl (the size of the valid set), so we can rotate the valid set.
        list_chunks = list(chunk_it(lsamps, s_vl))
        list_folds = []

        for i in range(len(list_chunks)):
            vl_set = list_chunks[i]

            right, left = [], []
            if i < len(list_chunks) - 1:
                right = list_chunks[i + 1:]
            if i > 0:
                left = list_chunks[:i]

            leftoverchunks = right + left

            leftoversamples = []
            for e in leftoverchunks:
                leftoversamples += e

            tr_set = leftoversamples
            list_folds.append((tr_set, vl_set))

        return list_folds
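
    # Worked example (hypothetical sizes): with s_vl = 2 and
    # lsamps = [a, b, c, d, e, f], chunk_it() yields the chunks
    # [a, b], [c, d], [e, f]; fold 0 then validates on [a, b] and trains on
    # [c, d, e, f], fold 1 validates on [c, d], and so on, rotating the
    # validation chunk.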

    # Save the folds into *.csv files.
    def dump_fold_into_csv(lsamples, outpath):
        """
        Write a list of RELATIVE paths into a csv file.
        Relative paths allow running the code on any device;
        the absolute path on the device is determined at run time.

        csv file format: one relative image path per row.

        :param lsamples: list of str of relative paths.
        :param outpath: str, output file name.
        :return:
        """
        with open(outpath, 'w') as fcsv:
            filewriter = csv.writer(fcsv, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
            for fname in lsamples:
                filewriter.writerow([fname])

    def create_one_split(split_i, test_samples, train_samples_all, nbr_folds):
        """
        Create one split of k-folds.

        :param split_i: int, the id of the split.
        :param test_samples: list of str, the samples of the (fixed) test set.
        :param train_samples_all: dict of lists, each key represents a class (all the train data).
        :param nbr_folds: int, number of folds [the k value in k-folds].
        :return:
        """
        # Create the k-folds
        list_folds_of_class = {}

        for key in train_samples_all.keys():
            vl_size = math.ceil(len(train_samples_all[key]) * args.folding["vl"] / 100.)
            tr_size = len(train_samples_all[key]) - vl_size
            list_folds_of_class[key] = create_folds_of_one_class(train_samples_all[key], tr_size, vl_size)

            assert len(list_folds_of_class[key]) == nbr_folds, "We didn't get `{}` folds, but `{}` .... " \
                                                               "[NOT OK]".format(
                nbr_folds, len(list_folds_of_class[key]))

            print("We obtained `{}` folds for the class {}.... [OK]".format(args.nbr_folds, key))

        outd = args.fold_folder
        for i in range(nbr_folds):
            print("Fold {}:\n\t".format(i))
            out_fold = join(outd, "split_" + str(split_i) + "/fold_" + str(i))
            if not os.path.exists(out_fold):
                os.makedirs(out_fold)

            # dump the test set
            dump_fold_into_csv(test_samples, join(out_fold, "test_s_" + str(split_i) + "_f_" + str(i) + ".csv"))

            train = []
            valid = []
            for key in list_folds_of_class.keys():
                # Train
                train += list_folds_of_class[key][i][0]

                # Valid.
                valid += list_folds_of_class[key][i][1]

            # shuffle
            for t in range(1000):
                random.shuffle(train)

            dump_fold_into_csv(train, join(out_fold, "train_s_" + str(split_i) + "_f_" + str(i) + ".csv"))
            dump_fold_into_csv(valid, join(out_fold, "valid_s_" + str(split_i) + "_f_" + str(i) + ".csv"))

            # dump the seed
            with open(join(out_fold, "seed.txt"), 'w') as fx:
                fx.write("MYSEED: " + os.environ["MYSEED"])

            with open(join(out_fold, "readme.md"), 'w') as fx:
                fx.write(readme)
        print("BACH (PART A) 2018 splitting N° `{}` ends with success .... [OK]".format(split_i))

    if not os.path.isdir(args.fold_folder):
        os.makedirs(args.fold_folder)

    # Creates the splits
    for i in range(args.nbr_splits):
        print("Split {}:\n\t".format(i))
        create_one_split(i, test_fix, splits[i], args.nbr_folds)

    with open(join(args.fold_folder, "readme.md"), 'w') as fx:
        fx.write(readme)
    print("All BACH (PART A) 2018 splitting (`{}`) ended with success .... [OK]".format(args.nbr_splits))
def validate(model,
             dataset,
             dataloader,
             criterion,
             device,
             stats,
             args,
             folderout=None,
             epoch=0,
             log_file=None,
             name_set="",
             store_on_disc=False,
             store_imgs=False,
             final_mode=False,
             seg_threshold=None
             ):
    """
    Perform a validation over the validation set. Assumes a batch size of 1.
    (images do not have the same size,
    so we can't stack them in one tensor).
    Validation samples may be too large to all fit in the GPU at once.

    Note: criterion is deepmil.criteria.TotalLossEval().
    """
    model.eval()
    fl_thres = seg_threshold if seg_threshold is not None else args.final_thres

    metrics = Metrics(threshold=fl_thres).to(device)
    dice_f = Dice()
    iou_f = IOU()
    metrics.eval()

    f1pos_, f1neg_, miou_, acc_ = 0., 0., 0., 0.
    l_f1pos_ = []
    l_f1neg_ = []
    l_miou_ = []

    cnt = 0.
    total_loss_ = 0.
    loss_pos_ = 0.
    loss_neg_ = 0.

    # camelyon16
    sizes_m = {
        'm_pred': [],  # metastatic.
        'm_true': [],  # metastatic
        'n_pred': [],  # normal
        'n_true': []  # normal
    }

    mask_fd = None
    name_fd_masks = "masks"  # where to store the predictions.
    name_fd_masks_bin = "masks_bin"  # where to store the bin masks and al.
    if folderout is not None:
        mask_fd = join(folderout, name_fd_masks)
        bin_masks_fd = join(folderout, name_fd_masks_bin)
        if not os.path.exists(mask_fd):
            os.makedirs(mask_fd)

        if not os.path.exists(bin_masks_fd):
            os.makedirs(bin_masks_fd)

    length = len(dataloader)
    t0 = dt.datetime.now()
    myseed = int(os.environ["MYSEED"])
    masks_sizes = []
    cancer_scores = []
    glabels = []

    avg_forward_t = dt.timedelta(0)

    with torch.no_grad():
        for i, (data, mask, label) in tqdm.tqdm(
                enumerate(dataloader), ncols=80, total=length):

            reproducibility.force_seed(myseed + epoch + 1)

            msg = "Expected a batch size of 1. Found `{}`  .... " \
                  "[NOT OK]".format(data.size()[0])
            assert data.size()[0] == 1, msg
            bsz = data.size()[0]

            data = data.to(device)
            labels = label.to(device)
            mask = mask[0].clone()
            mask_t = mask.unsqueeze(0).to(device)
            assert mask_t.ndim == 4, "ndim = {} must be 4.".format(mask_t.ndim)

            # In validation, we do not need reproducibility since everything
            # is expected to be deterministic. Plus, we use only one GPU
            # since the batch size is 1.
            # Time the forward pass with a separate variable so that t0 (the
            # start time of the whole validation) is not overwritten.
            t_fwd = dt.datetime.now()
            scores_pos, scores_neg, mask_pred, sc_cl_se, cams = model(
                x=data, glabels=labels, seed=None)
            delta_t = dt.datetime.now() - t_fwd
            avg_forward_t += delta_t

            t_loss, l_p, l_n, l_seg = criterion(
                scores_pos,
                sc_cl_se,
                labels,
                mask_pred,
                scores_neg,
                cams
            )


            mask_pred = mask_pred.squeeze()
            # check sizes of the mask:
            _, _, h, w = mask_t.shape
            hp, wp = mask_pred.shape

            assert args.padding_size in [None, 'None', 0.0], args.padding_size
            mask_pred = F.interpolate(
                input=mask_pred.unsqueeze(0).unsqueeze(0),
                size=(h, w),
                mode='bilinear',
                align_corners=True
            ).squeeze()

            mask_pred = mask_pred.unsqueeze(0).unsqueeze(0)

            if args.dataset == constants.GLAS:
                cl_scores = scores_pos
                acc, dice_forg, dice_back, miou = metrics(
                    scores=cl_scores,
                    labels=labels,
                    masks_pred=mask_pred.contiguous().view(bsz, -1),
                    masks_trg=mask_t.contiguous().view(bsz, -1),
                    avg=False
                )
                f1pos_ += dice_forg
                f1neg_ += dice_back
                miou_ += miou

            elif args.dataset == constants.CAMELYON16P512:
                cl_scores = sc_cl_se  # sc_cl_se, scores_pos

                plabels = metrics.predict_label(cl_scores)
                acc = ((plabels - labels) == 0.).float().sum()
                assert data.size()[0] == 1
                ppixels = metrics.get_binary_mask(
                    mask_pred.contiguous().view(bsz, -1), metrics.threshold)
                masks_trg = mask_t.contiguous().view(bsz, -1)
                dice_forg = 0.
                dice_back = 0.
                miou = 0.

                glabels.append(labels.item())
                masks_sizes.append(ppixels.mean().item())

                if labels.item() == 1:
                    sizes_m['m_pred'].append(ppixels.mean().item() * 100.)
                    sizes_m['m_true'].append(masks_trg.mean().item() * 100.)
                elif labels.item() == 0:
                    sizes_m['n_pred'].append(ppixels.mean().item() * 100.)
                    sizes_m['n_true'].append(0.0)
                else:
                    raise ValueError

                cancer_scores.append(
                    torch.softmax(cl_scores, dim=1)[0, 1].item())

                if masks_trg.sum() > 0:
                    l_f1pos_.append(dice_f(ppixels.view(bsz, -1),
                                    masks_trg.view(bsz, -1)))
                    dice_forg = l_f1pos_[-1]

                    if (1. - masks_trg).sum() > 0:
                        tmp = iou_f(ppixels.view(1, -1),
                                    masks_trg.view(1, -1))
                        tmp = tmp + iou_f(1. - ppixels.view(1, -1),
                                          1. - masks_trg.view(1, -1))
                        l_miou_.append(tmp / 2.)
                        miou = l_miou_[-1]

                if (1. - masks_trg).sum() > 0:
                    l_f1neg_.append(dice_f(1. - ppixels.view(bsz, -1),
                                           1. - masks_trg.view(bsz, -1)))
                    dice_back = l_f1neg_[-1]

            else:
                raise NotImplementedError

            acc_ += acc
            cnt += bsz
            total_loss_ += t_loss.item()
            loss_pos_ += l_p.item()
            loss_neg_ += l_n.item()

            if (folderout is not None) and store_on_disc:
                # binary mask
                bin_pred_mask = metrics.get_binary_mask(mask_pred).squeeze()
                bin_pred_mask = bin_pred_mask.cpu().detach().numpy().astype(bool)
                if args.dataset == constants.GLAS:
                    to_save = {
                        "bin_pred_mask": bin_pred_mask,
                        "continuous_mask": mask_pred.cpu().detach().numpy(),
                        "dice_forg": dice_forg,
                        "dice_back": dice_back,
                        "i": i
                    }
                elif args.dataset == constants.CAMELYON16P512:
                    to_save = {
                        "bin_pred_mask": bin_pred_mask,
                        "dice_forg": dice_forg,
                        "dice_back": dice_back,
                        "i": i
                    }
                else:
                    raise NotImplementedError

                with open(join(bin_masks_fd, "{}.pkl".format(i)), "wb") as fbin:
                    pkl.dump(to_save, fbin, protocol=pkl.HIGHEST_PROTOCOL)

            if (folderout is not None) and store_imgs and store_on_disc:
                pred_label = int(cl_scores.argmax().item())
                probs = softmax(cl_scores.cpu().detach().numpy())
                prob = float(probs[0, pred_label])

                store_pred_img(i,
                               dataset,
                               bin_pred_mask * 1.,
                               mask_pred.squeeze().cpu().detach().numpy(),
                               dice_forg,
                               dice_back,
                               prob,
                               pred_label,
                               args,
                               mask_fd,
                               )

    # avg
    acc_ *= (100. / float(cnt))
    if args.dataset == constants.CAMELYON16P512:
        with open(join(folderout,
                       'log-cl-{}-final-{}.txt'.format(epoch, final_mode)),
                  'a') as fz:
            fz.write("\nClassification accuracy: {} (%)".format(acc_))

        with open(join(folderout,
                       'log-seg-{}-final-{}.txt'.format(
                           epoch, final_mode)), 'a') as fz:
            fz.write("\nClassification accuracy: {} (%)".format(acc_))

    total_loss_ /= float(cnt)
    loss_pos_ /= float(cnt)
    loss_neg_ /= float(cnt)
    avg_forward_t /= float(cnt)

    if args.dataset == constants.GLAS:
        f1pos_ *= (100. / float(cnt))
        f1neg_ *= (100. / float(cnt))
        miou_ *= (100. / float(cnt))
    elif args.dataset == constants.CAMELYON16P512:
        f1pos_ = 0.
        f1neg_ = 0.
        miou_ = 0.
        if l_f1pos_:
            f1pos_ = torch.stack(l_f1pos_).mean() * 100.
        if l_f1neg_:
            f1neg_ = torch.stack(l_f1neg_).mean() * 100.
        if l_miou_:
            miou_ = torch.stack(l_miou_).mean() * 100.
    else:
        raise NotImplementedError

    if stats is not None:
        stats["total_loss"].append(total_loss_)
        stats["loss_pos"].append(loss_pos_)
        stats["loss_neg"].append(loss_neg_)
        stats["acc"].append(acc_)
        stats["f1pos"].append(f1pos_)
        stats["f1neg"].append(f1neg_)
        stats['miou'].append(miou_)

    to_write = "EVAL ({}): TLoss: {:.2f}, L+: {:.2f}, L-: {:.2f}, " \
               "F1+: {:.2f}%, F1-: {:.2f}%, MIOU: {:.2f}%, ACC: {:.2f}%, " \
               "t:{}, epoch {:>2d}.".format(
        name_set,
        total_loss_,
        loss_pos_,
        loss_neg_,
        f1pos_,
        f1neg_,
        miou_,
        acc_,
        dt.datetime.now() - t0,
        epoch
        )

    print(to_write)
    if log_file:
        log(log_file, to_write)

    if final_mode:
        assert folderout is not None
        msg = "EVAL {}: \n".format(name_set)
        msg += "ACC {}% \n".format(acc_)
        msg += "F1+ {}% \n".format(f1pos_)
        msg += "F1- {}% \n".format(f1neg_)
        msg += "MIOU {}% \n".format(miou_)
        announce_msg(msg)
        if log_file:
            log(log_file, msg)

        with open(join(folderout, 'avg_forward_time.txt'), 'w') as fend:
            fend.write("Model: {}. \n Average forward time (eval mode): "
                       " {}.".format(args.model['model_name'], avg_forward_t
            ))

    if (folderout is not None) and store_on_disc:
        pred = {
            "total_loss": total_loss_,
            "loss_pos": loss_pos_,
            "loss_neg": loss_neg_,
            "acc": acc_,
            "f1pos": f1pos_,
            "f1neg": f1neg_,
            "miou_": miou_
        }
        with open(
                join(folderout, "pred--{}.pkl".format(name_set)), "wb") as fout:
            pkl.dump(pred, fout, protocol=pkl.HIGHEST_PROTOCOL)

        # compress. delete folder.
        cmdx = [
            "cd {} ".format(mask_fd),
            "cd .. ",
           # "tar -cf {}.tar.gz {}".format(name_fd_masks, name_fd_masks),
           # "rm -r {}".format(name_fd_masks)
        ]

        cmdx += [
            "cd {} ".format(bin_masks_fd),
            "cd .. ",
            "tar -cf {}.tar.gz {}".format(name_fd_masks_bin, name_fd_masks_bin),
            "rm -r {}".format(name_fd_masks_bin)
        ]

        cmdx = " && ".join(cmdx)
        print("Running bash-cmds: \n{}".format(cmdx.replace("&& ", "\n")))
        try:
            subprocess.run(cmdx, shell=True, check=True)
        except subprocess.SubprocessError as e:
            print("Failed to run: {}. Error: {}".format(cmdx, e))

    else:
        return stats
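
# Hedged usage sketch: all objects below (model, loaders, criterion, args,
# stats, ...) are assumed to be built elsewhere; the names are placeholders.
# When folderout is None (as here), validate() returns the updated stats dict.
stats = validate(model=model,
                 dataset=valid_set,
                 dataloader=valid_loader,
                 criterion=criterion,       # deepmil.criteria.TotalLossEval()
                 device=torch.device("cuda:0"),
                 stats=stats,
                 args=args,
                 epoch=epoch,
                 log_file=log_file,
                 name_set="valid")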