Exemplo n.º 1
0
    def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
        """Distance-map-weighted squared-error loss on the classes in ``self.idc``.

        Every position is weighted by the sum of the squared distance maps of the
        target and of the current arg-max prediction, both produced by
        ``one_hot2hd_dist``.

        Args:
            probs: network probabilities of shape (B, K, *spatial); must pass ``simplex``.
            target: target tensor with the same shape as ``probs``; must pass ``simplex``.

        Returns:
            Scalar tensor: mean of ``(pc - tc)**2 * (tdm**2 + pdm**2)`` over the
            batch, the selected classes and all spatial positions.
        """
        assert simplex(probs)
        assert simplex(target)
        assert probs.shape == target.shape

        B, K, *xyz = probs.shape  # type: ignore

        pc = cast(Tensor, probs[:, self.idc, ...].type(torch.float32))
        tc = cast(Tensor, target[:, self.idc, ...].type(torch.float32))
        assert pc.shape == tc.shape == (B, len(self.idc), *xyz)

        # Distance maps are computed per sample in NumPy on the CPU; they are not
        # differentiated through, hence the detach.
        target_dm_npy: np.ndarray = np.stack([one_hot2hd_dist(tc[b].cpu().detach().numpy())
                                              for b in range(B)], axis=0)
        assert target_dm_npy.shape == tc.shape == pc.shape
        tdm: Tensor = torch.tensor(target_dm_npy, device=probs.device, dtype=torch.float32)

        pred_segmentation: Tensor = probs2one_hot(probs).cpu().detach()
        pred_dm_npy: np.ndarray = np.stack([one_hot2hd_dist(pred_segmentation[b, self.idc, ...].numpy())
                                            for b in range(B)], axis=0)
        assert pred_dm_npy.shape == tc.shape == pc.shape
        pdm: Tensor = torch.tensor(pred_dm_npy, device=probs.device, dtype=torch.float32)

        delta = (pc - tc)**2
        dtm = tdm**2 + pdm**2

        # Element-wise product: identical to einsum("bkwh,bkwh->bkwh") for 2-D
        # images, but also valid for any number of spatial axes (*xyz above).
        multipled = delta * dtm

        loss = multipled.mean()

        return loss
Exemplo n.º 2
0
def compute_metrics(pred_probs, gt, labels):
    """Compute per-image evaluation metrics for one batch.

    Args:
        pred_probs: predicted probabilities; turned into a 4-D one-hot mask
            with ``probs2one_hot``.
        gt: ground-truth mask tensor.
        labels: baseline mask, scored against ``gt`` with the same Dice routine.

    Returns:
        ``(dices, baseline_dices, posim, haussdorf_res)`` as NumPy arrays:
        Dice of the prediction, Dice of ``labels``, a per-image flag that is
        True when ``gt`` has any mass outside channel 0, and the Hausdorff
        result of shape (b, c).
    """
    one_hot_pred: Tensor = probs2one_hot(pred_probs)
    n_images, n_classes, _, _ = one_hot_pred.shape

    pred_nograd = one_hot_pred.detach()
    gt_nograd = gt.detach()

    dice_scores = dice_coef(pred_nograd, gt_nograd).cpu().numpy()
    baseline_scores = dice_coef(labels.detach(), gt_nograd).cpu().numpy()
    hd_scores = haussdorf(pred_nograd, gt_nograd, dtype=pred_probs.dtype).cpu().numpy()
    assert hd_scores.shape == (n_images, n_classes)

    # Total ground-truth mass in channels >= 1, per image.
    foreground_mass = torch.einsum("bcwh->b", [gt[:, 1:, :, :]]).detach()
    has_foreground = (foreground_mass > 0).cpu().numpy()

    return dice_scores, baseline_scores, has_foreground, hd_scores
Exemplo n.º 3
0
    def __call__(self, probs: Tensor, target: Tensor, bounds) -> "Tuple[Tensor, Tensor, Tensor]":
        """Class-weighted self-entropy plus a class-proportion prior term.

        Args:
            probs: network probabilities of shape (b, c, w, h); must pass
                ``simplex`` (the target is deliberately not checked).
            target: ground truth; only used when ``self.curi`` is false, where
                the reference proportions are ``self.__fn__(target, self.power)``.
            bounds: used when ``self.curi`` is true; the reference proportions
                are ``bounds / (w * h)``.  When ``self.ivd`` is set, only index 0
                of the third axis of ``bounds`` is kept.

        Returns:
            ``(self.lamb_se * loss_se, self.lamb_consprior * loss_cons_prior, est_prop)``
            where ``loss_cons_prior`` is ``sum(est_prop * (log est_prop - log gt_prop)) / b``
            and ``loss_se`` is the class-weighted self-entropy of ``probs``
            normalised by ``probs.sum()``.
        """
        assert simplex(probs)  # and simplex(target)  # Actually, does not care about second part
        b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]

        est_prop: Tensor = self.__fn__(probs, self.power)
        if self.curi:
            if self.ivd:
                bounds = bounds[:, :, 0].unsqueeze(2)
            gt_prop = torch.ones_like(est_prop) * bounds / (w * h)
            gt_prop = gt_prop[:, :, 0]
        else:
            # the power here is actually useless if we have 0/1 gt labels
            gt_prop = self.__fn__(target, self.power).squeeze(2)
        est_prop = est_prop.squeeze(2)

        log_est_prop: Tensor = (est_prop + 1e-10).log()
        log_gt_prop: Tensor = (gt_prop + 1e-10).log()

        loss_cons_prior = (-torch.einsum("bc,bc->", [est_prop, log_gt_prop])
                           + torch.einsum("bc,bc->", [est_prop, log_est_prop]))
        # Adding division by batch_size to normalise
        loss_cons_prior /= b

        log_p: Tensor = (probs + 1e-10).log()
        mask: Tensor = probs.type(torch.float32)
        mask_weighted = torch.einsum(
            "bcwh,c->bcwh", [mask, Tensor(self.weights).to(mask.device)])
        loss_se = -torch.einsum("bcwh,bcwh->", [mask_weighted, log_p])
        loss_se /= mask.sum() + 1e-10

        assert loss_se.requires_grad == probs.requires_grad  # Handle the case for validation

        return self.lamb_se * loss_se, self.lamb_consprior * loss_cons_prior, est_prop
Exemplo n.º 4
0
    def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
        """Replace ``target`` with a generated pseudo-mask, then defer to the parent loss.

        For each sample, the low/high bounds of the single class in
        ``self.mask_idc`` are read from ``bounds``. ``pathak_generator`` is run
        on the detached probabilities and target, and the one-hot version of
        the generated distributions is passed to ``super().__call__`` in place
        of ``target``.
        """
        assert simplex(probs)
        assert probs.shape == target.shape
        assert len(self.mask_idc) == 1, "Cannot handle more at the time, I guess"

        _, n_cls, width, height = probs.shape

        generated: Tensor = torch.zeros_like(probs, dtype=torch.float32)
        for idx, (sample_probs, sample_target) in enumerate(zip(probs, target)):
            sample_bounds: Tensor = bounds[idx, self.mask_idc][0, 0]
            low: Tensor = sample_bounds[0]
            high: Tensor = sample_bounds[1]

            candidate = self.pathak_generator(sample_probs.detach(), sample_target.detach(), low, high)
            assert simplex(candidate, axis=0)
            assert candidate.shape == (n_cls, width, height)

            generated[idx] = candidate

        pseudo_mask: Tensor = probs2one_hot(generated)
        assert pseudo_mask.shape == probs.shape == target.shape

        return super().__call__(probs, pseudo_mask, bounds)
Exemplo n.º 5
0
    def pathak_generator(self, probs: Tensor, target: Tensor, bounds) -> Tensor:
        """Project one sample's probabilities onto a bound-constrained distribution.

        Pixels carrying a weak label keep that label, and all other pixels keep
        the network probabilities. A projected gradient loop on the dual
        variable ``beta`` (clamped to be non-negative) then re-weights this
        distribution as ``p * exp(-beta . f)``, normalised per pixel. The
        arg-max of the result is returned one-hot encoded.

        Args:
            probs: probabilities of a single sample, shape (c, w, h).
            target: weak labels of the same sample, shape (c, w, h); not modified.
            bounds: indexable object whose last entry unpacks to ``(lower, upper)``.

        Returns:
            One-hot tensor of shape (c, w, h).
        """
        c, w, h = probs.shape

        # Replace the probabilities with certainty for the few weak labels that we have.
        # ``target[...]`` is a view sharing storage with ``target``; clone so that
        # zeroing the ignored classes does not overwrite the caller's tensor.
        weak_labels = target.clone()
        weak_labels[self.ignore, ...] = 0
        # Tensors here are (c, w, h): the class axis is 0 (cf. simplex(res, axis=0) in the caller).
        assert not simplex(weak_labels, axis=0) and simplex(target, axis=0)
        lower, upper = bounds[-1]

        labeled_pixels = (weak_labels != 0).any(dim=0)
        assert w * h == (labeled_pixels.sum() + (~labeled_pixels).sum())  # make sure all pixels are covered
        # einsum requires matching dtypes, so the boolean mask is cast to probs' dtype.
        unlabeled = (~labeled_pixels).type_as(probs)
        scribbled_probs = weak_labels.type_as(probs) + einsum("cwh,wh->cwh", probs, unlabeled)
        assert simplex(scribbled_probs, axis=0)

        max_iter: int = 100
        lr: float = 0.00005
        # Create all dual-problem tensors on probs' device/dtype (probs may live on a GPU).
        b: Tensor = torch.tensor([-float(lower), float(upper)], dtype=probs.dtype, device=probs.device)
        beta: Tensor = torch.zeros(2, dtype=probs.dtype, device=probs.device)
        f: Tensor = torch.zeros(2, c, w, h, dtype=probs.dtype, device=probs.device)
        f[0, ...] = -1
        f[1, ...] = 1

        # Start from the scribble-completed distribution, as the comment above intends
        # (it was previously computed but unused).
        u: Tensor = scribbled_probs
        for _ in range(max_iter):
            # exp(-beta . f): the minus sign must be inside the exponential.
            exped = (-einsum("i,icwh->cwh", beta, f)).exp()
            u_star = einsum("cwh,cwh->cwh", scribbled_probs, exped)
            u_star = u_star / u_star.sum(dim=0, keepdim=True)
            assert simplex(u_star, axis=0)

            d_beta = einsum("cwh,icwh->i", u_star, f) - b
            beta = torch.clamp(beta + lr * d_beta, min=0)  # projection onto beta >= 0

            u = u_star

        # probs2one_hot is applied to (b, c, w, h) batches everywhere else; add and
        # strip a batch axis so the class axis is where it expects it.
        return probs2one_hot(u[None])[0]
Exemplo n.º 6
0
def extra_eval_model(model, dataloaders, log_dir="./log/", logger=None, opt=None):
    """Evaluate infection classification / seg2cls on ``dataloaders["tgt_cls_extra_val"]``.

    Per batch, depending on ``opt.do_seg``, ``opt.use_aux`` and ``opt.do_cls``:
    the arg-max segmentation of the main (and aux) head is written with
    ``cv2.imwrite`` under ``opt.logs/tgt_cls_extra_val[_au]/eval``. An image is
    labelled positive by the segmentation when any pixel is assigned to the
    last class, and the classification head is thresholded at 0.5. Results are
    scored with ``get_results`` and logged.

    Assumes a batch size of 1: ``dataset.annotations[batch_idx]`` is used as
    the image name, and masks are ``squeeze()``-d before saving.

    Args:
        model: network returning ``(cls_logits, _, seg_logits, _, seg_logits_au)``.
        dataloaders: mapping containing the ``"tgt_cls_extra_val"`` loader.
        log_dir: unused; kept for interface compatibility.
        logger: logger for progress and results (required despite the default).
        opt: options namespace (required despite the default).
    """
    since = time.time()
    split = "tgt_cls_extra_val"
    loader = dataloaders[split]

    def _eval_seg_head(seg_logits, save_dir, img_name):
        """Save one head's mask; return (max last-class prob, seg2cls prediction)."""
        seg_probs = torch.softmax(seg_logits, dim=1)
        max_prob = seg_probs[:, -1:].detach().cpu().view(seg_probs.shape[0], 1, -1).max(-1)[0].numpy()

        predicted_mask_onehot = probs2one_hot(seg_probs.detach())

        # Classes 1 and 2 are stored as grey levels 128 and 255 (needs >= 3 classes).
        predicted_mask = predicted_mask_onehot.squeeze().cpu().numpy()
        mask_inone = (np.zeros_like(predicted_mask[0]) + predicted_mask[1] * 128
                      + predicted_mask[2] * 255).astype(np.uint8)
        os.makedirs(save_dir, exist_ok=True)
        cv2.imwrite(os.path.join(save_dir, img_name), mask_inone)

        # seg2cls: positive if any pixel is predicted as the last class.
        preds_cls_seg = (predicted_mask_onehot[:, -1:].sum(-1).sum(-1) > 0).cpu().numpy().astype(np.uint8)
        return max_prob, preds_cls_seg

    def _report(gt, preds, probs, cf_png_name, metric_name, message):
        """Concatenate per-batch results, score them with get_results and log."""
        preds = np.concatenate(preds, axis=0)
        probs = np.concatenate(probs, axis=0)
        save_cf_png_dir = os.path.join(opt.logs, "cf", cf_png_name)
        save_metric_dir = os.path.join(opt.logs, metric_name)
        result_str = get_results(gt, preds, probs, save_cf_png_dir, save_metric_dir)
        logger.info(message % (result_str))

    # eval infection segmentation and cls
    logger.info("-"*8+"extra eval infection cls"+"-"*8)
    model.eval()

    val_gt = []
    val_cls_pred = []
    val_cls_probs = []  # for VOC
    val_seg_pred = []
    val_seg_probs = []  # for VOC
    val_seg_pred_au = []
    val_seg_probs_au = []  # for VOC

    save_dir = os.path.join(opt.logs, split, "eval")
    save_dir_au = os.path.join(opt.logs, split + "_au", "eval")

    for batch_idx, (inputs, labels) in enumerate(loader, 0):
        inputs = inputs.to(device)
        val_gt.append(labels.cpu().data.numpy())

        with torch.set_grad_enabled(False):
            annotation = loader.dataset.annotations[batch_idx]
            img_name = Path(annotation.strip().split(',')[0]).name

            # The fifth output (aux logits) is only used when opt.use_aux is set.
            cls_logits, _, seg_logits, _, seg_logits_au = model(inputs)

            if opt.do_seg:
                max_prob, seg_pred = _eval_seg_head(seg_logits, save_dir, img_name)
                val_seg_probs.append(max_prob)
                val_seg_pred.append(seg_pred)

            if opt.do_seg and opt.use_aux:
                max_prob_au, seg_pred_au = _eval_seg_head(seg_logits_au, save_dir_au, img_name)
                val_seg_probs_au.append(max_prob_au)
                val_seg_pred_au.append(seg_pred_au)

            if opt.do_cls:
                probs_cls = torch.softmax(cls_logits, dim=1)
                val_cls_probs.append(probs_cls[..., 1:].detach().cpu().numpy())
                preds_cls = (probs_cls[..., 1:] > 0.5).type(torch.long)
                val_cls_pred.append(preds_cls.cpu().data.numpy())

    os.makedirs(os.path.join(opt.logs, "cf"), exist_ok=True)

    val_gt = np.concatenate(val_gt, axis=0)

    if opt.do_cls:
        _report(val_gt, val_cls_pred, val_cls_probs,
                "extra_eval_cls_cf.png", "extra_eval_metric_cls.txt",
                "tgt_cls_extra_val:[cls]: %s")

    if opt.do_seg:
        _report(val_gt, val_seg_pred, val_seg_probs,
                "extra_eval_seg_cf.png", "extra_eval_metric_seg.txt",
                "tgt_seg_extra_val:[seg2cls]: %s")

    if opt.do_seg and opt.use_aux:
        _report(val_gt, val_seg_pred_au, val_seg_probs_au,
                "extra_eval_seg_au_cf.png", "extra_eval_metric_seg_au.txt",
                "tgt_seg_au_extra_val:[seg2cls]: %s")

    time_elapsed = time.time() - since
    logger.info("Extra_Eval complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
Exemplo n.º 7
0
def eval_model(model, dataloaders, log_dir="./log/", logger=None, opt=None, eval_lung=False):
    """Evaluate ``model`` on the target-domain validation splits.

    If ``eval_lung`` is true (default False, matching the previously hard-coded
    ``if False`` switch), lung segmentation is evaluated on
    ``dataloaders["tgt_lung_seg_val"]``. The infection channel is merged into
    the lung channel, a running Dice is logged after every batch, and masks are
    saved under ``opt.logs/tgt_lung_seg_val[_au]/eval``.

    Infection classification / seg2cls is always evaluated on
    ``dataloaders["tgt_cls_val"]``. Masks are saved under
    ``opt.logs/tgt_cls_val[_au]/eval``, and results are scored with
    ``get_results`` and logged.

    Assumes a batch size of 1: ``dataset.annotations[batch_idx]`` is used as
    the image name, and masks are ``squeeze()``-d before saving.

    Args:
        model: network returning ``(cls_logits, _, seg_logits, _, seg_logits_au)``.
        dataloaders: mapping with the ``"tgt_cls_val"`` (and, for ``eval_lung``,
            ``"tgt_lung_seg_val"``) loaders.
        log_dir: unused; kept for interface compatibility.
        logger: logger for progress and results (required despite the default).
        opt: options namespace (required despite the default).
        eval_lung: also run the lung-segmentation evaluation.
    """
    since = time.time()

    def _lung_head(seg_logits, labels, save_dir, img_name):
        """Return one head's lung Dice (B, C) and save its lung mask."""
        seg_probs = torch.softmax(seg_logits, dim=1)
        predicted_mask = probs2one_hot(seg_probs.detach())

        # change the infection to Lung: add the last channel into the one before it
        predicted_mask_lung = predicted_mask[:, :-1]
        predicted_mask_lung[:, -1] += predicted_mask[:, -1]
        dices = dice_coef(predicted_mask_lung, labels.detach().type_as(predicted_mask)).cpu().numpy()

        predicted_mask_lung = predicted_mask_lung.squeeze().cpu().numpy()
        mask_inone = (np.zeros_like(predicted_mask_lung[0]) + predicted_mask_lung[1] * 255).astype(np.uint8)
        os.makedirs(save_dir, exist_ok=True)
        cv2.imwrite(os.path.join(save_dir, img_name), mask_inone)
        return dices

    def _log_running_dice(tag, dices_so_far, batch_idx, n_steps):
        """Log the mean Dice over all batches processed so far for one head."""
        stacked = np.concatenate(dices_so_far, 0)
        avg_dice = np.mean(stacked, 0)
        logger.info("%s:[%d/%d],dice0:%.03f,dice1:%.03f,dice:%.03f"
                    % (tag, batch_idx, n_steps, avg_dice[0], avg_dice[1], np.mean(stacked)))

    def _eval_seg_head(seg_logits, save_dir, img_name):
        """Save one head's mask; return (max last-class prob, seg2cls prediction)."""
        seg_probs = torch.softmax(seg_logits, dim=1)
        max_prob = seg_probs[:, -1:].detach().cpu().view(seg_probs.shape[0], 1, -1).max(-1)[0].numpy()

        predicted_mask_onehot = probs2one_hot(seg_probs.detach())

        # Classes 1 and 2 are stored as grey levels 128 and 255 (needs >= 3 classes).
        predicted_mask = predicted_mask_onehot.squeeze().cpu().numpy()
        mask_inone = (np.zeros_like(predicted_mask[0]) + predicted_mask[1] * 128
                      + predicted_mask[2] * 255).astype(np.uint8)
        os.makedirs(save_dir, exist_ok=True)
        cv2.imwrite(os.path.join(save_dir, img_name), mask_inone)

        # seg2cls: positive if any pixel is predicted as the last class.
        preds_cls_seg = (predicted_mask_onehot[:, -1:].sum(-1).sum(-1) > 0).cpu().numpy().astype(np.uint8)
        return max_prob, preds_cls_seg

    def _report(gt, preds, probs, cf_png_name, metric_name, message):
        """Concatenate per-batch results, score them with get_results and log."""
        preds = np.concatenate(preds, axis=0)
        probs = np.concatenate(probs, axis=0)
        save_cf_png_dir = os.path.join(opt.logs, "cf", cf_png_name)
        save_metric_dir = os.path.join(opt.logs, metric_name)
        result_str = get_results(gt, preds, probs, save_cf_png_dir, save_metric_dir)
        logger.info(message % (result_str))

    if eval_lung:
        # eval lung segmentation
        logger.info("-"*8+"eval lung segmentation"+"-"*8)
        model.eval()
        lung_loader = dataloaders["tgt_lung_seg_val"]
        lung_save_dir = os.path.join(opt.logs, "tgt_lung_seg_val", "eval")
        lung_save_dir_au = os.path.join(opt.logs, "tgt_lung_seg_val_au", "eval")
        all_dices = []
        all_dices_au = []

        for batch_idx, (inputs, labels) in enumerate(lung_loader, 0):
            annotation = lung_loader.dataset.annotations[batch_idx]
            img_name = Path(annotation.strip().split(',')[0]).name

            inputs = inputs.to(device)
            # Remap the X-ray lung value to 1, then one-hot on label values 0 and 1.
            labels[labels == opt.xray_mask_value_dict["lung"]] = 1
            labels = labels[:, -1].to(device)
            labels = torch.stack([labels == c for c in range(2)], dim=1)

            with torch.set_grad_enabled(False):
                _, _, seg_logits, _, seg_logits_au = model(inputs)
                all_dices.append(_lung_head(seg_logits, labels, lung_save_dir, img_name))  # [(B,C)]
                if opt.use_aux:
                    all_dices_au.append(_lung_head(seg_logits_au, labels, lung_save_dir_au, img_name))

            n_steps = len(lung_loader.dataset) // inputs.shape[0]
            _log_running_dice("tgt_lung_seg_val", all_dices, batch_idx, n_steps)
            if opt.use_aux:
                _log_running_dice("tgt_lung_seg_val_au", all_dices_au, batch_idx, n_steps)

    # eval infection segmentation and cls
    logger.info("-"*8+"eval infection cls"+"-"*8)
    model.eval()

    split = "tgt_cls_val"
    loader = dataloaders[split]
    save_dir = os.path.join(opt.logs, split, "eval")
    save_dir_au = os.path.join(opt.logs, split + "_au", "eval")

    val_gt = []
    val_cls_pred = []
    val_cls_probs = []  # for VOC
    val_seg_pred = []
    val_seg_probs = []  # for VOC
    val_seg_pred_au = []
    val_seg_probs_au = []  # for VOC

    for batch_idx, (inputs, labels) in enumerate(loader, 0):
        inputs = inputs.to(device)
        val_gt.append(labels.cpu().data.numpy())

        with torch.set_grad_enabled(False):
            annotation = loader.dataset.annotations[batch_idx]
            img_name = Path(annotation.strip().split(',')[0]).name

            # The fifth output (aux logits) is only used when opt.use_aux is set.
            cls_logits, _, seg_logits, _, seg_logits_au = model(inputs)

            if opt.do_seg:
                max_prob, seg_pred = _eval_seg_head(seg_logits, save_dir, img_name)
                val_seg_probs.append(max_prob)
                val_seg_pred.append(seg_pred)

            if opt.do_seg and opt.use_aux:
                max_prob_au, seg_pred_au = _eval_seg_head(seg_logits_au, save_dir_au, img_name)
                val_seg_probs_au.append(max_prob_au)
                val_seg_pred_au.append(seg_pred_au)

            if opt.do_cls:
                probs_cls = torch.softmax(cls_logits, dim=1)
                val_cls_probs.append(probs_cls[..., 1:].detach().cpu().numpy())
                preds_cls = (probs_cls[..., 1:] > 0.5).type(torch.long)
                val_cls_pred.append(preds_cls.cpu().data.numpy())

    os.makedirs(os.path.join(opt.logs, "cf"), exist_ok=True)

    val_gt = np.concatenate(val_gt, axis=0)

    if opt.do_cls:
        _report(val_gt, val_cls_pred, val_cls_probs,
                "eval_cls_cf.png", "eval_metric_cls.txt",
                "tgt_cls_val:[cls]: %s")

    if opt.do_seg:
        _report(val_gt, val_seg_pred, val_seg_probs,
                "eval_seg_cf.png", "eval_metric_seg.txt",
                "tgt_seg_val:[seg2cls]: %s")

    if opt.do_seg and opt.use_aux:
        _report(val_gt, val_seg_pred_au, val_seg_probs_au,
                "eval_seg_au_cf.png", "eval_metric_seg_au.txt",
                "tgt_seg_au_val:[seg2cls]: %s")

    time_elapsed = time.time() - since
    logger.info("Eval complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
Exemplo n.º 8
0
def train_model(model,
                dataloaders,
                criterion,
                optimizer,
                num_epochs=25,
                is_inception=False,
                log_dir="./log/",
                scheduler=None,
                writer=None,
                logger=None,
                opt=None):
    """Train ``model`` on the source domain with optional LMMD adaptation.

    Each epoch:
      * One pass is made over ``dataloaders["train"]``. The objective is the sum
        of the segmentation loss (``criterion``), the classification
        cross-entropy and, when enabled, the class/segmentation LMMD terms.
        Target batches are drawn cyclically from ``dataloaders["tgt_cls_train"]``.
      * ``scheduler.step()`` is called and a checkpoint is saved to
        ``log_dir/latest.pth``.
      * Every ``opt.eval_times`` epochs (and at the last epoch), lung
        segmentation is evaluated if ``opt.do_seg`` is set.
      * Every ``opt.eval_cls_times`` epochs (and at the last epoch), infection
        classification / seg2cls is evaluated.

    Args:
        model: network returning ``(cls_logits, loss_cls_lmmd, seg_logits, loss_seg_lmmd, _)``.
        dataloaders: mapping with ``"train"``, ``"tgt_cls_train"``,
            ``"tgt_lung_seg_val"`` and ``"tgt_cls_val"`` loaders.
        criterion: segmentation loss, called as ``criterion(labels, logits, class_weights=...)``.
        optimizer: optimizer stepping the model parameters.
        num_epochs: number of training epochs.
        is_inception: unused; kept for interface compatibility.
        log_dir: directory for the ``latest.pth`` checkpoint.
        scheduler, writer, logger, opt: required despite the ``None`` defaults.
    """
    print(opt)
    since = time.time()

    print_step = 5  # log training losses every 5 batches

    tgt_cls_train_iter = iter(dataloaders['tgt_cls_train'])

    def _next_target_inputs():
        """Return the next target-domain inputs, restarting the loader when exhausted."""
        nonlocal tgt_cls_train_iter
        try:
            tgt_inputs, _ = next(tgt_cls_train_iter)
        except StopIteration:
            tgt_cls_train_iter = iter(dataloaders['tgt_cls_train'])
            tgt_inputs, _ = next(tgt_cls_train_iter)
        return tgt_inputs

    def _eval_lung_segmentation(epoch):
        """Log a running lung Dice on tgt_lung_seg_val and save masks for ``epoch``."""
        logger.info("-" * 8 + "eval lung segmentation" + "-" * 8)
        model.eval()
        loader = dataloaders["tgt_lung_seg_val"]
        save_dir = os.path.join(opt.logs, "tgt_lung_seg_val", "ep%03d" % epoch)
        all_dices = []

        for batch_idx, (inputs, labels) in enumerate(loader, 0):
            # Relies on batch size 1: annotation index == batch index.
            annotation = loader.dataset.annotations[batch_idx]
            img_name = Path(annotation.strip().split(',')[0]).name

            inputs = inputs.to(device)
            # Remap the X-ray lung value to 1, then one-hot on label values 0 and 1.
            labels[labels == opt.xray_mask_value_dict["lung"]] = 1
            labels = labels[:, -1].to(device)
            labels = torch.stack([labels == c for c in range(2)], dim=1)

            with torch.set_grad_enabled(False):
                _, _, seg_logits, _, _ = model(inputs)
                seg_probs = torch.softmax(seg_logits, dim=1)
                predicted_mask = probs2one_hot(seg_probs.detach())

                # change the infection to Lung: add the last channel into the one before it
                predicted_mask_lung = predicted_mask[:, :-1]
                predicted_mask_lung[:, -1] += predicted_mask[:, -1]
                dices = dice_coef(
                    predicted_mask_lung,
                    labels.detach().type_as(predicted_mask)).cpu().numpy()
                all_dices.append(dices)  # [(B,C)]

                predicted_mask_lung = predicted_mask_lung.squeeze().cpu().numpy()
                mask_inone = (np.zeros_like(predicted_mask_lung[0]) +
                              predicted_mask_lung[1] * 255).astype(np.uint8)
                os.makedirs(save_dir, exist_ok=True)
                cv2.imwrite(os.path.join(save_dir, img_name), mask_inone)

            # Running average over the batches processed so far.
            stacked = np.concatenate(all_dices, 0)
            avg_dice = np.mean(stacked, 0)
            logger.info(
                "tgt_lung_seg_val:EP%03d,[%d/%d],dice0:%.03f,dice1:%.03f,dice:%.03f"
                % (epoch, batch_idx, len(loader.dataset) // inputs.shape[0],
                   avg_dice[0], avg_dice[1], np.mean(stacked)))

    def _eval_infection(epoch):
        """Evaluate classification and seg2cls on tgt_cls_val for ``epoch``."""
        logger.info("-" * 8 + "eval infection cls" + "-" * 8)
        model.eval()
        loader = dataloaders["tgt_cls_val"]
        save_dir = os.path.join(opt.logs, "tgt_cls_val", "ep%03d" % epoch)

        val_gt = []
        val_cls_pred = []
        val_cls_probs = []
        val_seg_pred = []
        val_seg_probs = []

        for batch_idx, (inputs, labels) in enumerate(loader, 0):
            inputs = inputs.to(device)
            val_gt.append(labels.cpu().data.numpy())

            with torch.set_grad_enabled(False):
                annotation = loader.dataset.annotations[batch_idx]
                img_name = Path(annotation.strip().split(',')[0]).name

                cls_logits, _, seg_logits, _, _ = model(inputs)
                if opt.do_seg:
                    seg_probs = torch.softmax(seg_logits, dim=1)
                    # Highest last-class probability per image: the seg2cls score.
                    val_seg_probs.append(
                        seg_probs[:, -1:].detach().cpu().view(seg_probs.shape[0], 1, -1).max(-1)[0].numpy())
                    predicted_mask_onehot = probs2one_hot(seg_probs.detach())

                    # Classes 1 and 2 are stored as grey levels 128 and 255.
                    predicted_mask = predicted_mask_onehot.squeeze().cpu().numpy()
                    mask_inone = (np.zeros_like(predicted_mask[0]) +
                                  predicted_mask[1] * 128 +
                                  predicted_mask[2] * 255).astype(np.uint8)
                    os.makedirs(save_dir, exist_ok=True)
                    cv2.imwrite(os.path.join(save_dir, img_name), mask_inone)

                    # seg2cls: positive if any pixel is predicted as the last class.
                    preds_cls_seg = (
                        predicted_mask_onehot[:, -1:].sum(-1).sum(-1) >
                        0).cpu().numpy().astype(np.uint8)
                    val_seg_pred.append(preds_cls_seg)

                if opt.do_cls:
                    probs_cls = torch.softmax(cls_logits, dim=1)
                    val_cls_probs.append(probs_cls[..., 1:].detach().cpu().numpy())
                    preds_cls = (probs_cls[..., 1:] > 0.5).type(torch.long)
                    val_cls_pred.append(preds_cls.cpu().data.numpy())

        os.makedirs(os.path.join(opt.logs, "cf"), exist_ok=True)

        val_gt = np.concatenate(val_gt, axis=0)

        # get_results takes (gt, pred, probs, cf_png, metric_txt), as in eval_model.
        if opt.do_cls:
            save_cf_png_dir = os.path.join(opt.logs, "cf", "ep%03d_cls_cf.png" % epoch)
            save_metric_dir = os.path.join(opt.logs, "metric_cls.txt")
            result_str = get_results(val_gt,
                                     np.concatenate(val_cls_pred, axis=0),
                                     np.concatenate(val_cls_probs, axis=0),
                                     save_cf_png_dir, save_metric_dir)
            logger.info("tgt_cls_val:EP%03d,[cls]: %s" % (epoch, result_str))

        if opt.do_seg:
            save_cf_png_dir = os.path.join(opt.logs, "cf", "ep%03d_seg_cf.png" % epoch)
            save_metric_dir = os.path.join(opt.logs, "metric_seg.txt")
            result_str = get_results(val_gt,
                                     np.concatenate(val_seg_pred, axis=0),
                                     np.concatenate(val_seg_probs, axis=0),
                                     save_cf_png_dir, save_metric_dir)
            logger.info("tgt_seg_val:EP%03d,[seg2cls]: %s" % (epoch, result_str))

    for epoch in range(num_epochs):
        logger.info('Epoch {}/{}'.format(epoch, num_epochs - 1))
        logger.info('-' * 10)

        learning_rate = get_learning_rate(optimizer)
        writer.add_scalar("lr", learning_rate, epoch)

        # MMD ramp-up weight: 0 at epoch 0, increasing towards 1 as epoch -> num_epochs.
        lambd = 2 / (1 + math.exp(-10 * (epoch) / num_epochs)) - 1

        model.train()  # Set model to training mode

        for batch_idx, (inputs, labels) in enumerate(dataloaders["train"], 0):
            inputs = inputs.to(device)

            # Remap the DRR mask values: lung -> 1, infection -> 2.
            labels[labels == opt.drr_mask_value_dict["lung"]] = 1
            labels[labels == opt.drr_mask_value_dict["infection"]] = 2

            labels = labels[:, -1].to(device)
            # Image-level label: 1 if any pixel is infection.
            tag_labels = ((labels == 2).sum(-1).sum(-1) > 0).type(
                torch.long).to(device)

            c_labels = tag_labels if opt.do_cls_mmd else None
            s_labels = labels if opt.do_seg_mmd else None

            if opt.do_cls_mmd or opt.do_seg_mmd:
                tgt_inputs = _next_target_inputs().to(device)
            else:
                tgt_inputs = None

            # zero the parameter gradients
            optimizer.zero_grad()
            model.zero_grad()
            with torch.set_grad_enabled(True):
                src_cls_logits, loss_cls_lmmd, src_seg_logits, loss_seg_lmmd, _ = model(
                    inputs,
                    tgt_img=tgt_inputs,
                    c_label=c_labels,
                    s_label=s_labels)

                if opt.do_cls and opt.do_cls_mmd:
                    loss_cls_lmmd = lambd * loss_cls_lmmd * opt.lambda_cls_mmd
                    loss_cls_lmmd_item = loss_cls_lmmd.item()
                else:
                    loss_cls_lmmd = 0
                    loss_cls_lmmd_item = 0

                if opt.do_seg and opt.do_seg_mmd:
                    # This option was previously read as the misspelled
                    # ``lmabda_seg_mmd``; prefer the correctly spelled name and
                    # fall back for existing configs.
                    seg_mmd_weight = (opt.lambda_seg_mmd if hasattr(opt, "lambda_seg_mmd")
                                      else opt.lmabda_seg_mmd)
                    loss_seg_lmmd = lambd * loss_seg_lmmd * seg_mmd_weight
                    loss_seg_lmmd_item = loss_seg_lmmd.item()
                else:
                    loss_seg_lmmd = 0
                    loss_seg_lmmd_item = 0

                if opt.do_seg:
                    loss_seg = criterion(
                        labels,
                        src_seg_logits,
                        class_weights=opt.seg_class_weights) * opt.lambda_seg
                    loss_seg_item = loss_seg.item()
                else:
                    loss_seg = 0
                    loss_seg_item = 0

                if opt.do_cls:
                    loss_cls = F.cross_entropy(src_cls_logits,
                                               tag_labels) * opt.lambda_cls
                    loss_cls_item = loss_cls.item()
                else:
                    loss_cls = 0
                    loss_cls_item = 0

                loss = loss_seg + loss_cls + loss_seg_lmmd + loss_cls_lmmd
                loss_item = loss.item()

                loss.backward()
                optimizer.step()

            if batch_idx % print_step == 0:  # print info
                logger.info(
                    "Train E{:>03} B{:>05} LR:{:.8f} Loss: {:.4f} LSeg: {:.4f} SegMmd: {:.4f} LCls: {:.4f} ClsMmd: {:.4f}"
                    .format(epoch, batch_idx, learning_rate, loss_item,
                            loss_seg_item, loss_seg_lmmd_item, loss_cls_item,
                            loss_cls_lmmd_item))

        scheduler.step()
        weight_path = os.path.join(log_dir, "latest.pth")
        torch.save(model.state_dict(), weight_path)

        is_last_epoch = epoch + 1 == num_epochs

        if ((epoch + 1) % opt.eval_times == 0 or is_last_epoch) and opt.do_seg:
            _eval_lung_segmentation(epoch)

        if (epoch + 1) % opt.eval_cls_times == 0 or is_last_epoch:
            _eval_infection(epoch)

    time_elapsed = time.time() - since
    logger.info("Training complete in {:.0f}m {:.0f}s".format(
        time_elapsed // 60, time_elapsed % 60))
Exemplo n.º 9
0
def do_epoch(args, mode: str, net: Any, device: Any, loader: DataLoader, epc: int,
             loss_fns: List[Callable], loss_weights: List[float], loss_fns_source: List[Callable],
             loss_weights_source: List[float], new_w: int, num_steps: int, C: int, metric_axis: List[int],
             savedir: str = "", optimizer: Any = None, target_loader: Any = None):
    """Run one train/val epoch over paired (source, target) batches.

    Batches are drawn with ``zip(loader, target_loader)``, so the epoch stops at
    the shorter of the two loaders. Each batch is a list
    ``[filenames, image, gt, *labels, *bounds]`` holding ``len(loss_fns)`` labels.
    The optimised loss is the sum of ``w * loss_fn(pred_probs, label, bound)``
    over the positive-weight target terms, plus the same over the positive-weight
    source terms (``loss_fns_source`` / ``loss_weights_source``).

    Args:
        args: Run configuration; reads ``dtype``, ``source_metrics``, ``mix`` and
            ``lin_aug_w``.
        mode: ``"train"`` or ``"val"``.
        new_w: When > 0, target probabilities, labels and ground truth are passed
            through ``resize(..., new_w)`` before losses/metrics are computed.
        num_steps: Unused (kept for interface compatibility).
        metric_axis: Class indices averaged for the reported Dice scores.
        savedir: When non-empty, predicted target classes are saved there.
        optimizer: When given, ``loss.backward()`` and ``optimizer.step()`` run
            for every batch.

    Returns:
        ``(losses_vec, target_vec, source_vec)``:
        ``losses_vec`` holds the epoch means of the first three weighted loss
        terms (0 for missing terms); ``target_vec`` is the two values returned
        by ``dice3dn`` on the target domain followed by the mean 2D Dice;
        ``source_vec`` is the two ``dice3dn`` values on the source domain when
        ``args.source_metrics`` is True, else ``[0, 0]``.
    """
    assert mode in ["train", "val"]
    L: int = len(loss_fns)
    indices = torch.tensor(metric_axis, device=device)
    if mode == "train":
        net.train()
        desc = f">> Training   ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_it_s, total_images_s = len(loader), len(loader.dataset)
    total_it_t, total_images_t = len(target_loader), len(target_loader.dataset)
    total_iteration = max(total_it_s, total_it_t)
    # Buffers are sized for the larger dataset, but zip() below stops at the
    # shorter loader: only the first `done` rows are ever written, so every
    # epoch-level reduction at the end is restricted to those rows.
    total_images = max(total_images_s, total_images_t)

    pho = 1
    # NOTE(review): eval() of a config string (e.g. "torch.float32"); only safe
    # while args come from a trusted command line.
    dtype = eval(args.dtype)

    def _zeros(*shape: int) -> Tensor:
        """Zero buffer of the epoch dtype on the target device."""
        return torch.zeros(shape, dtype=dtype, device=device)

    def _group_ids(filenames) -> Tensor:
        """(len(filenames), C) tensor of volume ids: 2nd '_'-separated field of each name."""
        ids = [int(re.split('_', name)[1]) for name in filenames]
        return torch.tensor(ids, dtype=dtype, device=device)[:, None].expand(len(ids), C)

    all_dices = _zeros(total_images, C)
    all_inter_card = _zeros(total_images, C)
    all_card_gt = _zeros(total_images, C)
    all_card_pred = _zeros(total_images, C)
    all_grp = _zeros(total_images, C)  # volume id per row, consumed by dice3dn
    loss_log = _zeros(total_images)    # total loss of the batch each row belongs to
    loss_inf = _zeros(total_images)
    loss_cons = _zeros(total_images)
    loss_fs = _zeros(total_images)

    if args.source_metrics == True:
        all_dices_s = _zeros(total_images, C)
        all_inter_card_s = _zeros(total_images, C)
        all_card_gt_s = _zeros(total_images, C)
        all_card_pred_s = _zeros(total_images, C)
        all_grp_s = _zeros(total_images, C)

    # Loss-weight schedule. With pho == 1 every multiplier is 1; the first
    # weight is never rescaled.
    n_warmup = 0
    mult_lw = [pho ** (epc - n_warmup + 1)] * len(loss_weights)
    if mult_lw:
        mult_lw[0] = 1
    loss_weights = [a * b for a, b in zip(loss_weights, mult_lw)]

    tq_iter = tqdm_(zip(loader, target_loader), total=total_iteration, desc=desc)
    done: int = 0
    nice_dict = {}
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for source_data, target_data in tq_iter:
            source_data[1:] = [e.to(device) for e in source_data[1:]]  # Move all tensors to device
            filenames_source, source_image, source_gt = source_data[:3]
            target_data[1:] = [e.to(device) for e in target_data[1:]]  # Move all tensors to device
            filenames_target, target_image, target_gt = target_data[:3]
            labels = target_data[3:3 + L]
            labels_source = source_data[3:3 + L]
            bounds = target_data[3 + L:]
            bounds_source = source_data[3 + L:]

            assert len(labels) == len(bounds), len(bounds)
            if args.mix == False:
                # Unmixed runs expect source and target batches to be aligned.
                assert filenames_source == filenames_target
            B = len(target_image)

            if optimizer:
                optimizer.zero_grad()

            # Forward (graph only built in training mode)
            with torch.set_grad_enabled(mode == "train"):
                pred_logits: Tensor = net(target_image)
                pred_logits_source: Tensor = net(source_image)
                pred_probs: Tensor = F.softmax(pred_logits, dim=1)
                pred_probs_source: Tensor = F.softmax(pred_logits_source, dim=1)
                if new_w > 0:
                    pred_probs = resize(pred_probs, new_w)
                    labels = [resize(label, new_w) for label in labels]
                    # Resize the ground truth too so the Dice below compares
                    # tensors of the same size (was an undefined `target`).
                    target_gt = resize(target_gt, new_w)
                predicted_mask: Tensor = probs2one_hot(pred_probs)  # Used only for dice computation
                predicted_mask_source: Tensor = probs2one_hot(pred_probs_source)  # Used only for dice computation

            assert len(bounds) == len(loss_fns) == len(loss_weights)
            if epc < n_warmup:
                loss_weights = [0] * len(loss_weights)
            loss: Tensor = torch.zeros(1, requires_grad=True).to(device)
            # Detached weighted terms: positive-weight target terms in order,
            # then positive-weight source terms.
            loss_k = []
            for loss_fn, label, w, bound in zip(loss_fns, labels, loss_weights, bounds):
                if w > 0:
                    if args.lin_aug_w and epc < 70:
                        w = w * (epc + 1) / 70  # linear ramp-up over the first 70 epochs
                    loss_b = loss_fn(pred_probs, label, bound)
                    loss = loss + w * loss_b
                    loss_k.append(w * loss_b.detach())
            for loss_fn, label, w, bound in zip(loss_fns_source, labels_source, loss_weights_source, bounds_source):
                if w > 0:
                    loss_b = loss_fn(pred_probs_source, label, bound)
                    loss = loss + w * loss_b
                    loss_k.append(w * loss_b.detach())

            # Backward
            if optimizer:
                loss.backward()
                optimizer.step()

            # Per-image metrics on the target domain
            dices, inter_card, card_gt, card_pred = dice_coef(predicted_mask.detach(), target_gt.detach())
            assert dices.shape == (B, C), (dices.shape, B, C)

            sm_slice = slice(done, done + B)  # Values only for current batch
            all_dices[sm_slice, ...] = dices
            all_grp[sm_slice, ...] = _group_ids(filenames_target)
            all_inter_card[sm_slice, ...] = inter_card
            all_card_gt[sm_slice, ...] = card_gt
            all_card_pred[sm_slice, ...] = card_pred

            # 3D dice bookkeeping on the source domain
            if args.source_metrics == True:
                dices_s, inter_card_s, card_gt_s, card_pred_s = dice_coef(predicted_mask_source.detach(),
                                                                          source_gt.detach())
                all_dices_s[sm_slice, ...] = dices_s
                all_grp_s[sm_slice, ...] = _group_ids(filenames_source)
                all_inter_card_s[sm_slice, ...] = inter_card_s
                all_card_gt_s[sm_slice, ...] = card_gt_s
                all_card_pred_s[sm_slice, ...] = card_pred_s

            loss_log[sm_slice] = loss.detach()
            loss_inf[sm_slice] = loss_k[0] if len(loss_k) > 0 else 0
            loss_cons[sm_slice] = loss_k[1] if len(loss_k) > 1 else 0
            loss_fs[sm_slice] = loss_k[2] if len(loss_k) > 2 else 0

            # Save images
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    warnings.simplefilter("ignore")
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, filenames_target, savedir, mode, epc, True)

            # Logging: running means over the rows filled so far
            big_slice = slice(0, done + B)  # Value for current and previous batches
            stat_dict = {"dice": torch.index_select(all_dices[big_slice], 1, indices).mean(),
                         "loss": loss_log[big_slice].mean()}
            nice_dict = {k: f"{v:.4f}" for (k, v) in stat_dict.items()}

            done += B
            tq_iter.set_postfix(nice_dict)
        print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    filled = slice(0, done)  # rows actually written this epoch
    dice_3d_log, dice_3d_sd_log = dice3dn(all_grp[filled], all_inter_card[filled], all_card_gt[filled],
                                          all_card_pred[filled], metric_axis, True)
    if args.source_metrics == True:
        dice_3d_s_log, dice_3d_s_sd_log = dice3dn(all_grp_s[filled], all_inter_card_s[filled],
                                                  all_card_gt_s[filled], all_card_pred_s[filled],
                                                  metric_axis, True)
        source_vec = [dice_3d_s_log, dice_3d_s_sd_log]
    else:
        source_vec = [0, 0]
    print("mean 3d_dice over all patients:", dice_3d_log)
    dice_2d = torch.index_select(all_dices[filled], 1, indices).mean().cpu().numpy()
    target_vec = [dice_3d_log, dice_3d_sd_log, dice_2d]
    losses_vec = [loss_inf[filled].mean().item(), loss_cons[filled].mean().item(), loss_fs[filled].mean().item()]
    return losses_vec, target_vec, source_vec
def do_epoch(mode: str, net: Any, device: Any, loader: DataLoader, epc: int,
             loss_fns: List[Callable], loss_weights: List[float], C: int,
             savedir: str = "", optimizer: Any = None,
             metric_axis: List[int] = [1], compute_haussdorf: bool = False) \
        -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Run one training or validation epoch over ``loader``.

    Each batch is a list ``[filenames, image, target, *labels, *bounds]`` with
    ``len(loss_fns)`` labels. The loss is ``sum(w * loss_fn(pred_probs, label, bound))``
    over the zipped loss functions, weights, labels and bounds; when
    ``optimizer`` is given it is back-propagated and a step is taken.

    Returns:
        loss_log: ``(total_iteration,)`` total loss of each batch.
        all_dices: ``(total_images, C)`` per-image Dice from ``dice_coef``.
        batch_dices: ``(total_iteration, C)`` per-batch Dice from ``dice_batch``;
            only filled in ``"val"`` mode for batches of more than one image,
            other rows stay zero.
        haussdorf_log: ``(total_images, C)`` Hausdorff distances when
            ``compute_haussdorf`` is set, zeros otherwise.
    """
    assert mode in ["train", "val"]
    L: int = len(loss_fns)

    if mode == "train":
        net.train()
        desc = f">> Training   ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration, total_images = len(loader), len(loader.dataset)
    all_dices: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    batch_dices: Tensor = torch.zeros((total_iteration, C), dtype=torch.float32, device=device)
    loss_log: Tensor = torch.zeros((total_iteration), dtype=torch.float32, device=device)
    haussdorf_log: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)

    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
    done: int = 0
    nice_dict = {}  # stays defined for the final print even if the loader is empty
    for j, data in tq_iter:
        data[1:] = [e.to(device) for e in data[1:]]  # Move all tensors to device
        filenames, image, target = data[:3]
        labels = data[3:3 + L]
        bounds = data[3 + L:]
        assert len(labels) == len(bounds)

        B = len(image)

        # Reset gradients
        if optimizer:
            optimizer.zero_grad()

        # Forward
        pred_logits: Tensor = net(image)
        pred_probs: Tensor = F.softmax(pred_logits, dim=1)
        predicted_mask: Tensor = probs2one_hot(pred_probs.detach())  # Used only for dice computation

        assert len(bounds) == len(loss_fns) == len(loss_weights)
        ziped = zip(loss_fns, labels, loss_weights, bounds)
        losses = [w * loss_fn(pred_probs, label, bound) for loss_fn, label, w, bound in ziped]
        loss = reduce(add, losses)
        assert loss.shape == (), loss.shape

        # Backward
        if optimizer:
            loss.backward()
            optimizer.step()

        # Compute and log metrics
        loss_log[j] = loss.detach()

        sm_slice = slice(done, done + B)  # Values only for current batch

        dices: Tensor = dice_coef(predicted_mask, target.detach())
        assert dices.shape == (B, C), (dices.shape, B, C)
        all_dices[sm_slice, ...] = dices

        with_batch_dice: bool = B > 1 and mode == "val"
        if with_batch_dice:
            batch_dice: Tensor = dice_batch(predicted_mask, target.detach())
            assert batch_dice.shape == (C, ), (batch_dice.shape, B, C)
            batch_dices[j] = batch_dice

        if compute_haussdorf:
            haussdorf_res: Tensor = haussdorf(predicted_mask.detach(), target.detach())
            assert haussdorf_res.shape == (B, C)
            haussdorf_log[sm_slice] = haussdorf_res

        # Save images
        if savedir:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                predicted_class: Tensor = probs2class(pred_probs)
                save_images(predicted_class, filenames, savedir, mode, epc)

        # Logging: running means that include the current batch. Per-batch
        # buffers were just written at index j, hence the `j + 1` bound (a `:j`
        # bound drops the current batch and is an empty, NaN mean at j == 0).
        big_slice = slice(0, done + B)  # Value for current and previous batches
        seen_batches: int = j + 1

        dsc_dict = {f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis}
        hauss_dict = {f"HD{n}": haussdorf_log[big_slice, n].mean() for n in metric_axis} \
            if compute_haussdorf else {}
        batch_dict = {f"bDSC{n}": batch_dices[:seen_batches, n].mean() for n in metric_axis} \
            if with_batch_dice else {}

        if len(metric_axis) > 1:
            mean_dict = {"DSC": all_dices[big_slice, metric_axis].mean()}
            if compute_haussdorf:  # HD buffer is all zeros otherwise
                mean_dict["HD"] = haussdorf_log[big_slice, metric_axis].mean()
        else:
            mean_dict = {}

        stat_dict = {**dsc_dict, **hauss_dict, **mean_dict, **batch_dict,
                     "loss": loss_log[:seen_batches].mean()}
        nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

        tq_iter.set_postfix(nice_dict)
        done += B
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    return loss_log, all_dices, batch_dices, haussdorf_log
Exemplo n.º 11
0
def do_epoch(mode: str, net: Any, device: Any, loaders: List[DataLoader], epc: int,
             list_loss_fns: List[List[Callable]], list_loss_weights: List[List[float]], C: int,
             savedir: str = "", optimizer: Any = None,
             metric_axis: List[int] = [1], compute_haussdorf: bool = False, compute_miou: bool = False,
             temperature: float = 1) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tuple[None, Tensor]]:
    """Run one epoch over several loaders, each with its own losses.

    Loaders are consumed one after another; ``loaders[i]`` is paired with
    ``list_loss_fns[i]`` and ``list_loss_weights[i]``. Each batch is a list
    ``[filenames, image, target, *labels, *bounds]`` and probabilities are
    ``softmax(temperature * logits)``.

    Returns:
        loss_log: ``(total_iteration, n_loss)`` detached weighted value of each
            loss term per batch; columns beyond a loader's number of losses
            stay zero.
        all_dices: ``(total_images, C)`` per-image Dice.
        batch_dices: ``(total_iteration, C)`` per-batch Dice, filled only in
            ``"val"`` mode for batches of more than one image.
        haussdorf_log: ``(total_images, C)`` Hausdorff distances when
            ``compute_haussdorf`` is set, zeros otherwise.
        mIoUs: ``(C,)`` dataset-level IoU when ``compute_miou`` is set, else None.
    """
    assert mode in ["train", "val"]

    if mode == "train":
        net.train()
        desc = f">> Training   ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    n_loss: int = max(map(len, list_loss_fns))

    all_dices: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    batch_dices: Tensor = torch.zeros((total_iteration, C), dtype=torch.float32, device=device)
    loss_log: Tensor = torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device)
    haussdorf_log: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    iiou_log: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    intersections: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)
    unions: Tensor = torch.zeros((total_images, C), dtype=torch.float32, device=device)

    # Per-class entries are only shown in the progress bar for a few classes.
    few_axis: bool = len(metric_axis) <= 3

    done_img: int = 0
    done_batch: int = 0
    nice_dict = {}  # stays defined for the final print even with no batches
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(zip(loaders, list_loss_fns, list_loss_weights)):
        L: int = len(loss_fns)

        for data in loader:
            data[1:] = [e.to(device) for e in data[1:]]  # Move all tensors to device
            filenames, image, target = data[:3]
            assert not target.requires_grad
            labels = data[3:3 + L]
            bounds = data[3 + L:]
            assert len(labels) == len(bounds)

            B = len(image)

            # Reset gradients
            if optimizer:
                optimizer.zero_grad()

            # Forward
            pred_logits: Tensor = net(image)
            pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)
            predicted_mask: Tensor = probs2one_hot(pred_probs.detach())  # Used only for dice computation
            assert not predicted_mask.requires_grad

            assert len(bounds) == len(loss_fns) == len(loss_weights) == len(labels)
            ziped = zip(loss_fns, labels, loss_weights, bounds)
            losses = [w * loss_fn(pred_probs, label, bound) for loss_fn, label, w, bound in ziped]
            loss = reduce(add, losses)
            assert loss.shape == (), loss.shape

            # Backward
            if optimizer:
                loss.backward()
                optimizer.step()

            # Compute and log metrics
            for j, term in enumerate(losses):
                loss_log[done_batch, j] = term.detach()

            sm_slice = slice(done_img, done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(predicted_mask, target)
            assert dices.shape == (B, C), (dices.shape, B, C)
            all_dices[sm_slice, ...] = dices

            with_batch_dice: bool = B > 1 and mode == "val"
            if with_batch_dice:
                batch_dice: Tensor = dice_batch(predicted_mask, target)
                assert batch_dice.shape == (C,), (batch_dice.shape, B, C)
                batch_dices[done_batch] = batch_dice

            if compute_haussdorf:
                haussdorf_res: Tensor = haussdorf(predicted_mask, target)
                assert haussdorf_res.shape == (B, C)
                haussdorf_log[sm_slice] = haussdorf_res
            if compute_miou:
                IoUs: Tensor = iIoU(predicted_mask, target)
                assert IoUs.shape == (B, C), IoUs.shape
                iiou_log[sm_slice] = IoUs
                intersections[sm_slice] = inter_sum(predicted_mask, target)
                unions[sm_slice] = union_sum(predicted_mask, target)

            # Save images
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, filenames, savedir, mode, epc)

            # Logging: running means that include the current batch. Per-batch
            # rows were just written at index done_batch, hence `done_batch + 1`
            # (a `:done_batch` bound is an empty, NaN mean on the first batch).
            big_slice = slice(0, done_img + B)  # Value for current and previous batches
            seen_batches: int = done_batch + 1

            dsc_dict = {f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis} if few_axis else {}

            hauss_dict = {f"HD{n}": haussdorf_log[big_slice, n].mean() for n in metric_axis} \
                if compute_haussdorf and few_axis else {}

            batch_dict = {f"bDSC{n}": batch_dices[:seen_batches, n].mean() for n in metric_axis} \
                if with_batch_dice and few_axis else {}

            miou_dict = {f"iIoU": iiou_log[big_slice, metric_axis].mean(),
                         f"mIoU": (intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10)).mean()} \
                if compute_miou else {}

            if len(metric_axis) > 1:
                mean_dict = {"DSC": all_dices[big_slice, metric_axis].mean()}
                if compute_haussdorf:
                    mean_dict["HD"] = haussdorf_log[big_slice, metric_axis].mean()
            else:
                mean_dict = {}

            stat_dict = {**miou_dict, **dsc_dict, **hauss_dict, **mean_dict, **batch_dict,
                         "loss": loss_log[:seen_batches].mean()}
            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    if compute_miou:
        mIoUs: Tensor = (intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10))
        assert mIoUs.shape == (C,), mIoUs.shape
    else:
        mIoUs = None

    if not few_axis and False:  # verbose per-class report, currently disabled
        print(f"DSC: {[f'{all_dices[:, n].mean():.3f}' for n in metric_axis]}")
        print(f"iIoU: {[f'{iiou_log[:, n].mean():.3f}' for n in metric_axis]}")
        # `is not None`: truth-testing a (C,) tensor raises for C > 1.
        if mIoUs is not None:
            print(f"mIoU: {[f'{mIoUs[n]:.3f}' for n in metric_axis]}")

    return loss_log, all_dices, batch_dices, haussdorf_log, mIoUs
Exemplo n.º 12
0
def do_epoch(
    mode: str,
    net: Any,
    device: Any,
    loaders: List[DataLoader],
    epc: int,
    list_loss_fns: List[List[Callable]],
    list_loss_weights: List[List[float]],
    K: int,
    savedir: str = "",
    optimizer: Any = None,
    metric_axis: List[int] = [1],
    compute_hausdorff: bool = False,
    compute_miou: bool = False,
    compute_3d_dice: bool = False,
    temperature: float = 1
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor],
           Optional[Tensor]]:
    """Run one epoch over several loaders of dict batches, with optional sub-sampling.

    Loaders are consumed one after another; ``loaders[i]`` is paired with
    ``list_loss_fns[i]`` / ``list_loss_weights[i]``. Each batch dict provides
    ``"images"``, ``"gt"``, ``"spacings"`` (kept on CPU), ``"labels"``,
    ``"bounds"``, ``"box_priors"``, ``"samplings"`` and ``"filenames"``.
    Every sampling of the first batch item is forwarded (and, with an
    optimizer, back-propagated and stepped) on its own; the predicted masks are
    reassembled into a full-size receptacle before metrics are computed.

    Returns:
        CPU tensors ``(loss_log, all_dices, hausdorff_log, mIoUs, three_d_dices)``.
        ``loss_log`` is ``(total_iteration, n_loss)`` with each loss term summed
        over the sub-samplings of a batch; ``all_dices`` is
        ``(total_images, K)``; the last three are None unless
        ``compute_hausdorff`` / ``compute_miou`` / ``compute_3d_dice`` is set.
    """
    assert mode in ["train", "val"]

    if mode == "train":
        net.train()
        desc = f">> Training   ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    n_loss: int = max(map(len, list_loss_fns))

    all_dices: Tensor = torch.zeros((total_images, K),
                                    dtype=torch.float32,
                                    device=device)
    loss_log: Tensor = torch.zeros((total_iteration, n_loss),
                                   dtype=torch.float32,
                                   device=device)

    iiou_log: Optional[Tensor]
    intersections: Optional[Tensor]
    unions: Optional[Tensor]
    if compute_miou:
        iiou_log = torch.zeros((total_images, K), dtype=torch.float32, device=device)
        intersections = torch.zeros((total_images, K), dtype=torch.float32, device=device)
        unions = torch.zeros((total_images, K), dtype=torch.float32, device=device)
    else:
        iiou_log = None
        intersections = None
        unions = None

    three_d_dices: Optional[Tensor]
    if compute_3d_dice:
        three_d_dices = torch.zeros((total_iteration, K), dtype=torch.float32, device=device)
    else:
        three_d_dices = None

    hausdorff_log: Optional[Tensor]
    if compute_hausdorff:
        hausdorff_log = torch.zeros((total_images, K), dtype=torch.float32, device=device)
    else:
        hausdorff_log = None

    # Per-class entries are only shown in the progress bar for a few classes.
    few_axis: bool = len(metric_axis) <= 3

    done_img: int = 0
    done_batch: int = 0
    nice_dict: Dict = {}  # stays defined for the final print even with no batches
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(
            zip(loaders, list_loss_fns, list_loss_weights)):
        for data in loader:
            image: Tensor = data["images"].to(device)
            target: Tensor = data["gt"].to(device)
            spacings: Tensor = data["spacings"]  # Keep that one on CPU
            assert not target.requires_grad
            labels: List[Tensor] = [e.to(device) for e in data["labels"]]
            bounds: List[Tensor] = [e.to(device) for e in data["bounds"]]
            # One more nesting level for the batch; `priors` avoids shadowing B.
            box_priors: List[List[Tuple[Tensor, Tensor]]] = [
                [(m.to(device), b.to(device)) for (m, b) in priors]
                for priors in data["box_priors"]
            ]
            assert len(labels) == len(bounds)

            B, C, *_ = image.shape

            samplings: List[List[Tuple[slice]]] = data["samplings"]
            assert len(samplings) == B
            assert len(samplings[0][0]) == len(
                image[0, 0].shape), (samplings[0][0], image[0, 0].shape)

            probs_receptacle: Tensor = -torch.ones_like(
                target, dtype=torch.float32)  # -1 for unfilled
            mask_receptacle: Tensor = -torch.ones_like(
                target, dtype=torch.int32)  # -1 for unfilled

            # Use the sampling coordinates of the first batch item
            assert not (len(samplings[0]) > 1 and
                        B > 1), samplings  # No subsampling if batch size > 1
            loss_sub_log: Tensor = torch.zeros(
                (len(samplings[0]), len(loss_fns)),
                dtype=torch.float32,
                device=device)
            for k, sampling in enumerate(samplings[0]):
                # Explicit tuples: indexing with a *list* of slices relies on a
                # legacy "sequence treated as tuple" rule.
                img_sampling = (slice(0, B), slice(0, C), *sampling)
                label_sampling = (slice(0, B), slice(0, K), *sampling)
                assert len(img_sampling) == len(image.shape), (img_sampling,
                                                               image.shape)
                sub_img = image[img_sampling]

                # Reset gradients
                if optimizer:
                    optimizer.zero_grad()

                # Forward
                pred_logits: Tensor = net(sub_img)
                pred_probs: Tensor = F.softmax(temperature * pred_logits,
                                               dim=1)
                predicted_mask: Tensor = probs2one_hot(
                    pred_probs.detach())  # Used only for dice computation
                assert not predicted_mask.requires_grad

                # Detached: the receptacle is a record, not part of the loss graph.
                probs_receptacle[label_sampling] = pred_probs.detach()
                mask_receptacle[label_sampling] = predicted_mask[...]

                assert len(bounds) == len(loss_fns) == len(
                    loss_weights) == len(labels)
                ziped = zip(loss_fns, labels, loss_weights, bounds)
                losses = [
                    w * loss_fn(pred_probs, label[label_sampling], bound,
                                box_priors)
                    for loss_fn, label, w, bound in ziped
                ]
                loss = reduce(add, losses)
                assert loss.shape == (), loss.shape

                # Backward
                if optimizer:
                    loss.backward()
                    optimizer.step()

                # Compute and log metrics
                for j, term in enumerate(losses):
                    loss_sub_log[k, j] = term.detach()
            reduced_loss_sublog: Tensor = loss_sub_log.sum(dim=0)
            assert reduced_loss_sublog.shape == (len(loss_fns), ), (
                reduced_loss_sublog.shape, len(loss_fns))
            loss_log[done_batch, ...] = reduced_loss_sublog[...]
            del loss_sub_log

            sm_slice = slice(done_img,
                             done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(mask_receptacle, target)
            assert dices.shape == (B, K), (dices.shape, B, K)
            all_dices[sm_slice, ...] = dices

            if compute_3d_dice:
                three_d_DSC: Tensor = dice_batch(mask_receptacle, target)
                assert three_d_DSC.shape == (K, )

                three_d_dices[done_batch] = three_d_DSC  # type: ignore

            if compute_hausdorff:
                hausdorff_res: Tensor
                try:
                    hausdorff_res = hausdorff(mask_receptacle, target,
                                              spacings)
                except RuntimeError:
                    # Deliberate best effort: log zeros rather than abort the epoch.
                    hausdorff_res = torch.zeros((B, K), device=device)
                assert hausdorff_res.shape == (B, K)
                hausdorff_log[sm_slice] = hausdorff_res  # type: ignore
            if compute_miou:
                IoUs: Tensor = iIoU(mask_receptacle, target)
                assert IoUs.shape == (B, K), IoUs.shape
                iiou_log[sm_slice] = IoUs  # type: ignore
                intersections[sm_slice] = inter_sum(mask_receptacle,
                                                    target)  # type: ignore
                unions[sm_slice] = union_sum(mask_receptacle,
                                             target)  # type: ignore

            # if False and target[0, 1].sum() > 0:  # Useful template for quick and dirty inspection
            #     import matplotlib.pyplot as plt
            #     from pprint import pprint
            #     from mpl_toolkits.axes_grid1 import ImageGrid
            #     from utils import soft_length

            #     print(data["filenames"])
            #     pprint(data["bounds"])
            #     pprint(soft_length(mask_receptacle))

            #     fig = plt.figure()
            #     fig.clear()

            #     grid = ImageGrid(fig, 211, nrows_ncols=(1, 2))

            #     grid[0].imshow(data["images"][0, 0], cmap="gray")
            #     grid[0].contour(data["gt"][0, 1], cmap='jet', alpha=.75, linewidths=2)

            #     grid[1].imshow(data["images"][0, 0], cmap="gray")
            #     grid[1].contour(mask_receptacle[0, 1], cmap='jet', alpha=.75, linewidths=2)
            #     plt.show()

            # Save images
            # NOTE(review): uses the probabilities of the last sub-sampling only.
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, data["filenames"], savedir,
                                mode, epc)

            # Logging: running means that include the current batch. Per-batch
            # rows were just written at index done_batch, hence `done_batch + 1`
            # (a `:done_batch` bound is an empty, NaN mean on the first batch).
            big_slice = slice(0, done_img +
                              B)  # Value for current and previous batches
            seen_batches: int = done_batch + 1

            dsc_dict: Dict
            if few_axis:
                dsc_dict = {
                    **{
                        f"DSC{n}": all_dices[big_slice, n].mean()
                        for n in metric_axis
                    },
                    **({
                        f"3d_DSC{n}": three_d_dices[:seen_batches, n].mean()
                        for n in metric_axis
                    } if three_d_dices is not None else {})
                }
            else:
                dsc_dict = {}

            hauss_dict = {f"HD{n}": hausdorff_log[big_slice, n].mean() for n in metric_axis} \
                if hausdorff_log is not None and few_axis else {}

            miou_dict = {f"iIoU": iiou_log[big_slice, metric_axis].mean(),
                         f"mIoU": (intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10)).mean()} \
                if iiou_log is not None and intersections is not None and unions is not None else {}

            if len(metric_axis) > 1:
                mean_dict = {"DSC": all_dices[big_slice, metric_axis].mean()}
                # `is not None`: truth-testing a multi-element tensor raises.
                if hausdorff_log is not None:
                    mean_dict["HD"] = hausdorff_log[big_slice,
                                                    metric_axis].mean()
            else:
                mean_dict = {}

            stat_dict = {
                **miou_dict,
                **dsc_dict,
                **hauss_dict,
                **mean_dict, "loss": loss_log[:seen_batches].mean()
            }
            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()
    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    mIoUs: Optional[Tensor]
    if intersections is not None and unions is not None:
        mIoUs = (intersections.sum(dim=0) / (unions.sum(dim=0) + 1e-10))
        assert mIoUs.shape == (K, ), mIoUs.shape
    else:
        mIoUs = None

    if not few_axis and False:  # verbose per-class report, currently disabled
        print(f"DSC: {[f'{all_dices[:, n].mean():.3f}' for n in metric_axis]}")
        print(f"iIoU: {[f'{iiou_log[:, n].mean():.3f}' for n in metric_axis]}")
        if mIoUs is not None:
            print(f"mIoU: {[f'{mIoUs[n]:.3f}' for n in metric_axis]}")

    return (
        loss_log.detach().cpu(), all_dices.detach().cpu(),
        hausdorff_log.detach().cpu() if hausdorff_log is not None else None,
        mIoUs.detach().cpu() if mIoUs is not None else None,
        three_d_dices.detach().cpu() if three_d_dices is not None else None)
Exemplo n.º 13
0
def do_epoch(mode: str, args, net, device, use_cuda, loader, optimizer,
             num_classes, epoch):
    """Run one training or validation epoch and log per-image metrics.

    Each batch from ``loader`` is unpacked as six image channels
    (``image_f``, ``image_i``, ``image_d``, ``image_o``, ``image_w``,
    ``image_c``), a label map and the image names. The six channels are
    concatenated along dim 1 and divided by 65535. The training target stacks
    ``1 - labels`` and ``labels`` along dim 1, and the loss is
    ``crossEntropy_f(probs, targets) + args.lam * kl(target_aver, probs_aver)``.

    Args:
        mode: ``"train"`` (weights are updated) or ``"val"`` (no update, and no
            autograd graph is built).
        args: configuration; reads ``dtype`` (a string that is ``eval``-ed to
            a torch dtype) and ``lam`` (weight of the KL term).
        net: segmentation network.
        device: device of the metric buffers; the batch is also moved there
            when ``use_cuda`` is true.
        use_cuda: whether to move inputs and targets to ``device``.
        loader: data loader yielding the 8-tuple described above.
        optimizer: optimizer stepped in ``"train"`` mode.
        num_classes: number of output classes.
        epoch: epoch index, only used in the progress-bar description.

    Returns:
        ``(loss_log, entropy_log, KL_log, all_dices, batch_dices)``: per-image
        total / entropy / KL losses (1-D, one entry per dataset image),
        per-image dice ``(n_images, num_classes)`` and per-batch dice
        ``(n_batches, num_classes)``.

    Raises:
        AssertionError: if ``mode`` is not ``"train"`` or ``"val"``.
    """
    assert mode in ("train", "val"), mode

    if mode == "train":
        net.train()
        desc = f">> Training   ({epoch})"
    else:
        net.eval()
        desc = f">> Validation ({epoch})"

    # NOTE(review): args.dtype is eval'ed; it must come from trusted configuration.
    dtype = eval(args.dtype)
    total_iteration, total_images = len(loader), len(loader.dataset)
    all_dices: Tensor = torch.zeros((total_images, num_classes), dtype=dtype, device=device)
    batch_dices: Tensor = torch.zeros((total_iteration, num_classes), dtype=dtype, device=device)
    loss_log: Tensor = torch.zeros(total_images, dtype=dtype, device=device)
    entropy_log: Tensor = torch.zeros(total_images, dtype=dtype, device=device)
    KL_log: Tensor = torch.zeros(total_images, dtype=dtype, device=device)

    tq_iter = tqdm_(enumerate(loader), total=total_iteration, desc=desc)
    done: int = 0

    for j, data in tq_iter:
        image_f, image_i, image_d, image_o, image_w, image_c, labels, img_names = data
        # Stack the six inputs as channels and rescale from the 16-bit range.
        MRI = torch.cat((image_f, image_i, image_d, image_o, image_w, image_c), dim=1)
        MRI = MRI.type(torch.FloatTensor) / 65535.0
        targets = torch.cat((1 - labels, labels), dim=1)
        B = len(image_f)
        if use_cuda:
            MRI, targets = MRI.to(device), targets.to(device)

        # Only build the autograd graph when the weights are going to be updated.
        with torch.set_grad_enabled(mode == "train"):
            outputs = net(MRI)
            pred_probs = F.softmax(outputs, dim=1)
            predicted_mask = probs2one_hot(pred_probs)

            entropy = crossEntropy_f(pred_probs, targets)

            # Per-image class sums of prediction and target, both divided by
            # the summed target of the whole batch, compared with a KL term.
            target_total = torch.sum(targets).float()
            pred_probs_aver: Tensor = torch.sum(pred_probs, dim=(2, 3)) / target_total
            target_aver: Tensor = torch.sum(targets, dim=(2, 3)).float() / target_total
            KL_loss = args.lam * kl(target_aver, pred_probs_aver)

            loss = entropy + KL_loss

        if mode == "train":
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Compute and log metrics. The integer cast keeps the targets on their
        # current device so CPU-only runs work as well as CUDA runs.
        int_targets = targets.int().detach()
        dices: Tensor = dice_coef(predicted_mask.detach(), int_targets)
        batch_dice: Tensor = dice_batch(predicted_mask.detach(), int_targets)
        assert batch_dice.shape == (num_classes, ) and dices.shape == (B, num_classes), \
            (batch_dice.shape, dices.shape, B, num_classes)

        sm_slice = slice(done, done + B)  # Values only for current batch
        all_dices[sm_slice, ...] = dices
        entropy_log[sm_slice] = entropy.detach()
        loss_log[sm_slice] = loss.detach()
        KL_log[sm_slice] = KL_loss.detach()
        batch_dices[j] = batch_dice

        # Logging
        big_slice = slice(0, done + B)  # Value for current and previous batches
        stat_dict = {
            "dice": all_dices[big_slice, -1].mean(),
            "total loss": loss_log[big_slice].mean(),
            "entropy loss": entropy_log[big_slice].mean(),
            "KL loss": KL_log[big_slice].mean(),
            "b dice": batch_dices[:j + 1, -1].mean()
        }
        nice_dict = {k: f"{v:.4f}" for (k, v) in stat_dict.items()}

        done += B
        tq_iter.set_postfix(nice_dict)

    return loss_log, entropy_log, KL_log, all_dices, batch_dices
Exemplo n.º 14
0
def do_epoch(args, mode: str, net: Any, device: Any, epc: int,
             loss_fns: List[Callable], loss_weights: List[float],
             new_w: int, C: int, metric_axis: List[int], savedir: str = "",
             optimizer: Any = None, target_loader: Any = None, best_dice3d_val: Any = None):
    """Run one epoch over ``target_loader`` and gather 2D and 3D segmentation statistics.

    Each batch is ``(filenames, images, gt, *labels, *bounds)`` with one label
    and one bound per entry of ``loss_fns``. Every tensor entry is moved to
    ``device``. When ``new_w > 0`` the probabilities, the labels and the ground
    truth are resized to ``new_w`` before losses and metrics are computed.
    Before epoch ``args.n_warmup`` every loss weight is forced to 0, so no loss
    term is evaluated.

    Args:
        args: run configuration; reads ``dtype`` and ``target_losses`` (both
            ``eval``-ed), ``debug``, ``n_warmup``, ``do_hd``, ``wh``,
            ``saveim``, ``entmap``, ``dice_3d``, ``pprint`` and ``do_asd``.
        mode: ``"train"`` or ``"val"``.
        net: segmentation network.
        device: device for the batch and most metric buffers.
        epc: current epoch index.
        loss_fns / loss_weights: target losses and their weights.
        new_w: target width for resizing, or ``<= 0`` to disable resizing.
        C: number of classes.
        metric_axis: class indices averaged in the reported metrics.
        savedir: output directory for images and the periodic sizes CSV.
        optimizer: when given, ``loss.backward()`` and ``optimizer.step()`` run.
        target_loader: loader over the target-domain data.
        best_dice3d_val: forwarded unchanged to ``dice3d``.

    Returns:
        ``(losses_vec, target_vec, source_vec)``:
        ``losses_vec = [mean loss_se, mean loss_cons, mean loss_tot, mean
        predicted size, mean non-empty predicted size, mean gt size, mean
        non-empty gt size]``; ``target_vec = [dice_3d, dice_3d_sd, asd_3d,
        asd_3d_sd, hd_3d, hd_3d_sd, dice_2d]`` (the 3D entries stay 0 unless
        ``args.dice_3d`` and ``mode == "val"``); ``source_vec`` is always empty.
    """
    assert mode in ["train", "val"]
    L: int = len(loss_fns)
    indices = torch.tensor(metric_axis, device=device)
    if mode == "train":
        net.train()
        desc = f">> Training   ({epc})"
    else:
        net.eval()
        desc = f">> Validation ({epc})"

    total_images = len(target_loader.dataset)
    # In debug mode only the progress-bar length changes; the loop still
    # consumes the whole loader.
    total_iteration = 10 if args.debug else len(target_loader)
    pho = 1
    # NOTE(review): args.dtype and args.target_losses are eval'ed; they must
    # come from trusted configuration.
    dtype = eval(args.dtype)
    # Only the name of the first configured target loss selects the loss path.
    is_entklprop = eval(args.target_losses)[0][0] == "EntKLProp"

    def _zeros(*shape: int) -> Tensor:
        # Metric buffer on `device` with the configured dtype.
        return torch.zeros(shape, dtype=dtype, device=device)

    all_dices: Tensor = _zeros(total_images, C)
    all_sizes: Tensor = _zeros(total_images, C)      # sizes estimated by EntKLProp
    all_gt_sizes: Tensor = _zeros(total_images, C)   # ground-truth pixel counts
    all_sizes2: Tensor = _zeros(total_images, C)     # predicted-mask pixel counts
    all_inter_card: Tensor = _zeros(total_images, C)
    all_card_gt: Tensor = _zeros(total_images, C)
    all_card_pred: Tensor = _zeros(total_images, C)
    all_grp: Tensor = _zeros(total_images, C)
    all_pnames = np.zeros([total_images]).astype('U256')
    # Class maps for the 3D distance metrics; CPU tensors only when args.do_hd,
    # otherwise empty lists are forwarded to dice3d.
    all_gt: Any = []
    all_pred: Any = []
    if args.do_hd:
        all_gt = torch.zeros((total_images, args.wh, args.wh), dtype=dtype)
        all_pred = torch.zeros((total_images, args.wh, args.wh), dtype=dtype)
    loss_cons: Tensor = _zeros(total_images)
    loss_se: Tensor = _zeros(total_images)
    loss_tot: Tensor = _zeros(total_images)
    dice_3d_log, dice_3d_sd_log = 0, 0
    hd_3d_log, asd_3d_log, hd_3d_sd_log, asd_3d_sd_log = 0, 0, 0, 0
    tq_iter = tqdm_(enumerate(target_loader), total=total_iteration, desc=desc)
    done: int = 0
    n_warmup = args.n_warmup
    mult_lw = [pho ** (epc - n_warmup + 1)] * len(loss_weights)
    mult_lw[0] = 1
    loss_weights = [a * b for a, b in zip(loss_weights, mult_lw)]
    if epc < n_warmup:
        # Warm-up: every loss term is disabled.
        loss_weights = [0] * len(loss_weights)
    source_vec: list = []
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for j, target_data in tq_iter:
            target_data[1:] = [e.to(device) for e in target_data[1:]]  # Move all tensors to device
            filenames_target, target_image, target_gt = target_data[:3]
            labels = target_data[3:3 + L]
            bounds = target_data[3 + L:]
            filenames_target = [f.split('.nii')[0] for f in filenames_target]
            assert len(labels) == len(bounds), len(bounds)
            B = len(target_image)
            # Reset gradients
            if optimizer:
                optimizer.zero_grad()

            # Forward
            with torch.set_grad_enabled(mode == "train"):
                pred_logits: Tensor = net(target_image)
                pred_probs: Tensor = F.softmax(pred_logits, dim=1)
                if new_w > 0:
                    pred_probs = resize(pred_probs, new_w)
                    labels = [resize(label, new_w) for label in labels]
                    # The ground truth must match the resized prediction for the metrics.
                    target_gt = resize(target_gt, new_w)
                predicted_mask: Tensor = probs2one_hot(pred_probs)  # Used only for dice computation
            assert len(bounds) == len(loss_fns) == len(loss_weights)

            loss: Tensor = torch.zeros(1, requires_grad=True).to(device)
            loss_kw = []
            est_prop = None
            for loss_fn, label, w, bound in zip(loss_fns, labels, loss_weights, bounds):
                if w > 0:
                    # NOTE(review): `loss` is overwritten rather than accumulated,
                    # so only the last active term is backpropagated, and the
                    # EntKLProp branch ignores `w` — confirm this is intended.
                    if is_entklprop:
                        loss_1, loss_cons_prior, est_prop = loss_fn(pred_probs, label, bound)
                        loss = loss_1 + loss_cons_prior
                    else:
                        loss = w * loss_fn(pred_probs, label, bound)
                        loss_1 = loss
                    loss_kw.append(loss_1.detach())
            # Backward
            if optimizer:
                loss.backward()
                optimizer.step()
            dices, inter_card, card_gt, card_pred = dice_coef(predicted_mask.detach(), target_gt.detach())
            assert dices.shape == (B, C), (dices.shape, B, C)
            sm_slice = slice(done, done + B)  # Values only for current batch
            all_dices[sm_slice, ...] = dices
            if is_entklprop and est_prop is not None:
                # Estimated proportions converted to pixel counts.
                all_sizes[sm_slice, ...] = torch.round(
                    est_prop.detach() * target_image.shape[2] * target_image.shape[3])
            all_sizes2[sm_slice, ...] = torch.sum(predicted_mask, dim=(2, 3))
            all_gt_sizes[sm_slice, ...] = torch.sum(target_gt, dim=(2, 3))
            all_grp[sm_slice, ...] = torch.FloatTensor(get_subj_nb(filenames_target)).unsqueeze(1).repeat(1, C)
            all_pnames[sm_slice] = filenames_target
            all_inter_card[sm_slice, ...] = inter_card
            all_card_gt[sm_slice, ...] = card_gt
            all_card_pred[sm_slice, ...] = card_pred
            if args.do_hd:
                all_pred[sm_slice, ...] = probs2class(predicted_mask).cpu().detach()
                all_gt[sm_slice, ...] = probs2class(target_gt).detach()
            # No loss term is active during warm-up: log zeros instead of
            # indexing an empty list.
            loss_se[sm_slice] = loss_kw[0] if loss_kw else 0
            if len(loss_kw) > 1:
                loss_cons[sm_slice] = loss_kw[1]
                loss_tot[sm_slice] = loss_kw[1] + loss_kw[0]
            else:
                loss_cons[sm_slice] = 0
                loss_tot[sm_slice] = loss_kw[0] if loss_kw else 0
            # # Save images
            if savedir and args.saveim and mode == "val":
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    warnings.simplefilter("ignore")
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, filenames_target, savedir, mode, epc, False)
                    if args.entmap:
                        ent_map = torch.einsum("bcwh,bcwh->bwh", [-pred_probs, (pred_probs + 1e-10).log()])
                        save_images_ent(ent_map, filenames_target, savedir, 'ent_map', epc)

            # Logging (source-domain metrics are not computed in this function).
            big_slice = slice(0, done + B)  # Value for current and previous batches
            stat_dict = {**{f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis},
                         **{f"SZ{n}": all_sizes[big_slice, n].mean() for n in metric_axis}}
            nice_dict = {k: f"{v:.4f}" for (k, v) in stat_dict.items()}
            done += B
            tq_iter.set_postfix(nice_dict)
    if args.dice_3d and (mode == 'val'):
        dice_3d_log, dice_3d_sd_log, asd_3d_log, asd_3d_sd_log, hd_3d_log, hd_3d_sd_log = dice3d(
            all_grp, all_inter_card, all_card_gt, all_card_pred, all_pred, all_gt, all_pnames,
            metric_axis, args.pprint, args.do_hd, args.do_asd, best_dice3d_val)
    dice_2d = torch.index_select(all_dices, 1, indices).mean().cpu().numpy().item()
    target_vec = [dice_3d_log, dice_3d_sd_log, asd_3d_log, asd_3d_sd_log, hd_3d_log, hd_3d_sd_log, dice_2d]
    pred_sizes = torch.index_select(all_sizes2, 1, indices)
    gt_sizes = torch.index_select(all_gt_sizes, 1, indices)
    size_mean = pred_sizes.mean(dim=0).cpu().numpy()
    size_gt_mean = gt_sizes.mean(dim=0).cpu().numpy()
    # Mean size restricted to images where the mask (resp. ground truth) is non-empty.
    size_mean_pos = pred_sizes.sum(dim=0).cpu().numpy() / (pred_sizes != 0).sum(dim=0).cpu().numpy()
    gt_size_mean_pos = gt_sizes.sum(dim=0).cpu().numpy() / (gt_sizes != 0).sum(dim=0).cpu().numpy()
    losses_vec = [loss_se.mean().item(), loss_cons.mean().item(), loss_tot.mean().item(),
                  size_mean.mean(), size_mean_pos.mean(), size_gt_mean.mean(), gt_size_mean_pos.mean()]
    if not epc % 50:
        # NOTE(review): all_sizes2 is (n_images, C); recent pandas versions
        # reject 2-D values for a single column — confirm with the pandas in use.
        df_t = pd.DataFrame({
            "val_ids": all_pnames,
            "proposal_size": all_sizes2.cpu()})
        df_t.to_csv(Path(savedir, mode + str(epc) + "sizes.csv"), float_format="%.4f", index_label="epoch")
    return losses_vec, target_vec, source_vec
Exemplo n.º 15
0
def for_back_step_comb(optimizer, mode, source_image, target_image, gt_source, labels,
                       net, loss_fns, loss_weights, loss_fns_source, loss_weights_source, new_w, device, bounds,
                       model_D, optimizer_D, lambda_adv_target):
    """One combined step: supervised source loss, target losses and optional adversarial training.

    The segmentation network is run on both ``source_image`` and
    ``target_image``. Its loss is the weighted source losses against
    ``gt_source``, the weighted target losses against ``labels``/``bounds``,
    and, when enabled, an adversarial term from ``model_D`` on the target
    probabilities. The discriminator is then trained on detached source and
    target probabilities.

    The adversarial parts run only when ``lambda_adv_target > 0`` and the first
    target image's predicted class-1 mask is non-empty.

    Args:
        optimizer: optimizer of ``net``; when falsy, no backward/step is done for
            either network.
        mode: forward passes track gradients only when ``mode == "train"``.
        new_w: when ``> 0``, probabilities, labels and ``gt_source`` are resized.
        model_D / optimizer_D: discriminator and its optimizer.
        lambda_adv_target: weight of the adversarial loss.

    Returns:
        ``(probs_source, probs_target, loss, v0, v1, v2, v3, v4)`` where
        ``v0..v4`` are the first five entries of the scalar log, which is built
        as [source losses..., adversarial loss, target losses..., D loss]; the
        meaning of each position therefore depends on how many source/target
        losses are configured. ``probs_source``/``probs_target`` are detached
        when the discriminator was trained.
    """
    source_label = 0
    target_label = 1

    # Reset gradients
    if optimizer:
        optimizer.zero_grad()

    optimizer_D.zero_grad()

    # don't accumulate grads in D
    for param in model_D.parameters():
        param.requires_grad = False

    # Forward
    with torch.set_grad_enabled(mode == "train"):
        pred_logits_source: Tensor = net(source_image)
        probs_source: Tensor = F.softmax(pred_logits_source, dim=1)

        pred_logits_target: Tensor = net(target_image)
        probs_target: Tensor = F.softmax(pred_logits_target, dim=1)
        predicted_mask_target: Tensor = probs2one_hot(probs_target)

        if new_w > 0:
            probs_source = resize(probs_source, new_w)
            probs_target = resize(probs_target, new_w)
            if labels[0].shape[3] != new_w:
                labels = [resize(label, new_w) for label in labels]
            gt_source = resize(gt_source, new_w)

    assert len(bounds) == len(loss_fns) == len(loss_weights)

    # Evaluated once: predicted_mask_target does not change during this step.
    # NOTE(review): the intent is to skip "a pair of negative images", yet only
    # the first target prediction is checked — confirm whether a source-side
    # mask was meant to take part in this test.
    use_adv = lambda_adv_target > 0 and predicted_mask_target[0, 1, ...].sum().item() > 0

    loss_adv_target = torch.zeros(1, requires_grad=True).to(device)

    loss_vec = []

    # Losses on source
    ziped = zip(loss_fns_source, [gt_source], loss_weights_source)
    losses = [w * loss_fn(probs_source, label, torch.randn(1)) for loss_fn, label, w in ziped]
    loss_vec.extend([loss.item() for loss in losses])
    loss_source = reduce(add, losses)

    # add adversarial loss of target im if not a pair of negative images
    if use_adv:
        D_out = model_D(probs_target)
        loss_adv_target = d_loss_calc(D_out, source_label).to(device=device) * lambda_adv_target
        loss_vec.append(loss_adv_target.item())
    else:
        loss_vec.append(0)

    # Constraint loss on target images (and eventually also cross entropy with FGT)
    ziped = zip(loss_fns, labels, loss_weights, bounds)
    losses = [w * loss_fn(probs_target, label, bound) for loss_fn, label, w, bound in ziped]
    loss_vec.extend([loss.item() for loss in losses])
    loss_target = reduce(add, losses)

    loss = loss_source + loss_adv_target + loss_target

    # Backward
    if optimizer:
        loss.backward()
        optimizer.step()

    # train D
    if use_adv:
        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with source
        probs_source = probs_source.detach()
        D_out = model_D(probs_source)
        loss_D_s = d_loss_calc(D_out, source_label).to(device=device) / 2
        if optimizer:
            loss_D_s.backward()

        # train with target
        probs_target = probs_target.detach()
        D_out_t = model_D(probs_target)
        loss_D_t = d_loss_calc(D_out_t, target_label).to(device=device) / 2
        if optimizer:
            loss_D_t.backward()
            optimizer_D.step()
        loss_vec.append(loss_D_s.item() + loss_D_t.item())
    else:
        loss_vec.append(0)

    return probs_source, probs_target, loss, loss_vec[0], loss_vec[1], loss_vec[2], loss_vec[3], loss_vec[4]
Exemplo n.º 16
0
def do_epoch(
        mode: str,
        net: Any,
        device: Any,
        loaders: list[DataLoader],
        epc: int,
        list_loss_fns: list[list[Callable]],
        list_loss_weights: list[list[float]],
        K: int,
        savedir: str = "",
        optimizer: Any = None,
        metric_axis: Optional[list[int]] = None,
        compute_3d_dice: bool = False,
        temperature: float = 1) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
    """Run one epoch over several loaders, each paired with its own losses.

    Loader ``i`` is used with ``list_loss_fns[i]`` and ``list_loss_weights[i]``.
    The batch loss is the weighted sum of ``loss_fn(softmax(temperature *
    logits), label)`` over that loader's losses. When ``optimizer`` is given,
    the loss is backpropagated and the optimizer stepped after each batch.

    Args:
        mode: ``"train"``, ``"val"`` or ``"dual"`` (the last two put ``net`` in
            eval mode).
        loaders: data loaders yielding dicts with ``"images"``, ``"gt"``,
            ``"filenames"`` and ``"labels"``.
        K: number of classes.
        savedir: when non-empty, predicted class maps are saved every batch.
        metric_axis: class indices shown in the progress bar; defaults to ``[1]``.
        compute_3d_dice: also record ``dice_batch`` for every batch.
        temperature: factor applied to the logits before the softmax.

    Returns:
        Detached CPU tensors ``(loss_log, all_dices, three_d_dices)``:
        per-batch loss values ``(n_batches, max_losses)``, per-image dice
        ``(n_images, K)`` and per-batch dice ``(n_batches, K)`` or ``None``.
    """
    assert mode in ["train", "val", "dual"]
    if metric_axis is None:
        metric_axis = [1]

    if mode == "train":
        net.train()
        desc = f">> Training   ({epc})"
    elif mode == "val":
        net.eval()
        desc = f">> Validation ({epc})"
    else:  # "dual"
        net.eval()
        desc = f">> Dual       ({epc})"

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    n_loss: int = max(map(len, list_loss_fns))

    all_dices: Tensor = torch.zeros((total_images, K), dtype=torch.float32, device=device)
    # One row per batch; loaders with fewer losses leave trailing columns at 0.
    loss_log: Tensor = torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device)
    three_d_dices: Optional[Tensor] = (
        torch.zeros((total_iteration, K), dtype=torch.float32, device=device)
        if compute_3d_dice else None)

    done_img: int = 0
    done_batch: int = 0
    nice_dict: dict = {}  # stays empty (and is printed as such) if no batch is seen
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    for i, (loader, loss_fns, loss_weights) in enumerate(
            zip(loaders, list_loss_fns, list_loss_weights)):
        for data in loader:
            image: Tensor = data["images"].to(device)
            target: Tensor = data["gt"].to(device)
            filenames: list[str] = data["filenames"]
            assert not target.requires_grad
            labels: list[Tensor] = [e.to(device) for e in data["labels"]]
            B, _, *_ = image.shape

            # Reset gradients
            if optimizer:
                optimizer.zero_grad()

            # Forward
            pred_logits: Tensor = net(image)
            pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)
            predicted_mask: Tensor = probs2one_hot(pred_probs.detach())  # Used only for dice computation
            assert not predicted_mask.requires_grad

            assert len(loss_fns) == len(loss_weights) == len(labels)
            losses = [w * loss_fn(pred_probs, label)
                      for loss_fn, label, w in zip(loss_fns, labels, loss_weights)]
            loss = reduce(add, losses)
            assert loss.shape == (), loss.shape

            # Backward
            if optimizer:
                loss.backward()
                optimizer.step()

            # Compute and log metrics
            for j, term in enumerate(losses):
                loss_log[done_batch, j] = term.detach()

            sm_slice = slice(done_img, done_img + B)  # Values only for current batch

            dices: Tensor = dice_coef(predicted_mask, target)
            assert dices.shape == (B, K), (dices.shape, B, K)
            all_dices[sm_slice, ...] = dices

            if three_d_dices is not None:
                three_d_DSC: Tensor = dice_batch(predicted_mask, target)
                assert three_d_DSC.shape == (K, )
                three_d_dices[done_batch] = three_d_DSC

            # Save images
            if savedir:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=UserWarning)
                    predicted_class: Tensor = probs2class(pred_probs)
                    save_images(predicted_class, filenames, savedir, mode, epc)

            # Logging. Row `done_batch` was just written and the counter is
            # only incremented below, so include it with `done_batch + 1`;
            # otherwise the first batch reports a NaN mean and the final
            # summary misses the last batch.
            big_slice = slice(0, done_img + B)  # Value for current and previous batches
            seen_batches = done_batch + 1

            dsc_dict: dict = {f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis}
            if three_d_dices is not None:
                dsc_dict |= {f"3d_DSC{n}": three_d_dices[:seen_batches, n].mean() for n in metric_axis}

            mean_losses = loss_log[:seen_batches].mean(dim=0)
            loss_dict = {f"loss_{k}": mean_losses[k] for k in range(n_loss)}

            stat_dict = dsc_dict | loss_dict
            nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()}

            done_img += B
            done_batch += 1
            tq_iter.set_postfix({**nice_dict, "loader": str(i)})
            tq_iter.update(1)
    tq_iter.close()

    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    return (loss_log.detach().cpu(), all_dices.detach().cpu(),
            three_d_dices.detach().cpu()
            if three_d_dices is not None else None)
Exemplo n.º 17
0
fold_all_H2.write(f"file, dice, haussdorf,connecterror \n")
fold_clean_H2.write(f"file, dice, haussdorf,connecterror \n")


# For every array in <root>/in_npy, run `net` on it and compare with the file
# of the same name in <root>/gt_npy. For classes 1 and 2, print the Dice score,
# the Hausdorff value and the absolute difference in the number of unique
# `label` values between prediction and ground truth.
for _, _, files in os.walk(os.path.join(root, 'in_npy')):
    for file in files:
        image = np.load(os.path.join(root, 'in_npy', file))
        gt = np.load(os.path.join(root, 'gt_npy', file))

        image = torch.tensor(image.reshape(-1, 1, 256, 256), dtype=torch.float)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pred = F.softmax(net(image), dim=1)
        predicted_output = probs2one_hot(pred.detach())
        gt_one_hot = class2one_hot(torch.tensor(gt), n_classes)
        dice = dice_coef(predicted_output, gt_one_hot)
        hauss = haussdorf(predicted_output, gt_one_hot)

        # Difference in unique `label` counts on the first image of the batch.
        errors = {}
        for k in (1, 2):
            pred_label = len(np.unique(label(predicted_output[0][k])))
            gt_label = len(np.unique(label(gt_one_hot[0][k])))
            errors[k] = np.abs(pred_label - gt_label)

        # Builtin float(): the np.float alias was removed in NumPy 1.24.
        for k in (1, 2):
            print(f"{file}, {float(dice[0][k])}, {float(hauss[0][k])},{float(errors[k])} \n")
Exemplo n.º 18
0
def do_epoch(mode: str, net: Any, device: Any, loaders: list[DataLoader], epc: int,
             list_loss_fns: list[list[Callable]], list_loss_weights: list[list[float]], K: int,
             savedir: Path = None, optimizer: Any = None,
             metric_axis: list[int] = None, requested_metrics: list[str] = None,
             temperature: float = 1) -> dict[str, Tensor]:
    """Run one epoch over several loaders, optionally training, and collect metrics.

    Every batch may be split into spatial sub-samplings (``data["samplings"]``). The
    network runs on each sub-sampling, and the weighted losses are summed. When an
    ``optimizer`` is given, a zero_grad/backward/step cycle is done per sub-sampling.
    Predictions are stitched back into full-size receptacles, which are used for the
    DSC metrics and for saving images.

    Args:
        mode: ``"train"`` (net.train()), ``"val"`` or ``"dual"`` (both net.eval()).
            Also used as the output sub-directory when saving.
        net: network called as ``net(sub_img)``; its output goes through a softmax.
        device: device the batch tensors are moved to.
        loaders: iterated in order. Each item is a dict with keys ``images``, ``gt``,
            ``filenames``, ``labels``, ``bounds`` and ``samplings``.
        epc: epoch index, used in the progress description and the save path.
        list_loss_fns: one list of losses per loader. Each loss is called as
            ``loss_fn(probs, label, bound, filenames)``.
        list_loss_weights: one list of weights per loader, matching ``list_loss_fns``.
        K: number of classes.
        savedir: if given, predicted classes are saved under ``savedir/iterXXX/mode``.
        optimizer: if given, gradients are reset, back-propagated and applied.
        metric_axis: class indices shown in the progress bar (default ``[1]``).
        requested_metrics: extra metrics; ``"3d_dsc"`` is supported.
        temperature: multiplier applied to the network output before the softmax.

    Returns:
        Dict of detached CPU tensors:
        - ``"dice"``: shape (n_images, K).
        - ``"loss"``: shape (n_batches, n_loss).
        - ``"3d_dsc"`` (optional): shape (n_batches, K).
    """
    assert mode in ["train", "val", "dual"]
    if metric_axis is None:
        metric_axis = [1]
    if requested_metrics is None:
        requested_metrics = []

    if mode == "train":
        net.train()
    else:
        net.eval()
    desc = {"train": f">> Training   ({epc})",
            "val": f">> Validation ({epc})",
            "dual": f">> Dual       ({epc})"}[mode]

    total_iteration: int = sum(len(loader) for loader in loaders)  # U
    total_images: int = sum(len(loader.dataset) for loader in loaders)  # D
    n_loss: int = max(map(len, list_loss_fns))

    epoch_metrics: dict[str, Tensor] = {
        "dice": torch.zeros((total_images, K), dtype=torch.float32, device=device),
        "loss": torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device),
    }
    if "3d_dsc" in requested_metrics:
        epoch_metrics["3d_dsc"] = torch.zeros((total_iteration, K), dtype=torch.float32, device=device)

    done_img: int = 0
    done_batch: int = 0
    # Initialised so the final summary print works even if no batch is processed.
    nice_dict: dict[str, str] = {}
    tq_iter = tqdm_(total=total_iteration, desc=desc)
    try:
        for i, (loader, loss_fns, loss_weights) in enumerate(zip(loaders, list_loss_fns, list_loss_weights)):
            for data in loader:
                image: Tensor = data["images"].to(device)
                target: Tensor = data["gt"].to(device)
                filenames: list[str] = data["filenames"]
                assert not target.requires_grad

                labels: list[Tensor] = [e.to(device) for e in data["labels"]]
                bounds: list[Tensor] = [e.to(device) for e in data["bounds"]]
                assert len(bounds) == len(loss_fns) == len(loss_weights) == len(labels)

                B, C, *_ = image.shape

                samplings: list[list[Tuple[slice]]] = data["samplings"]
                assert len(samplings) == B
                assert len(samplings[0][0]) == len(image[0, 0].shape), (samplings[0][0], image[0, 0].shape)
                # The first item's sampling coordinates are used for the whole batch,
                # so sub-sampling is only allowed with a batch size of one.
                assert not (len(samplings[0]) > 1 and B > 1), samplings

                # Full-size receptacles stitched from every sub-sampling; -1 marks unfilled.
                probs_receptacle: Tensor = - torch.ones_like(target, dtype=torch.float32)
                mask_receptacle: Tensor = - torch.ones_like(target, dtype=torch.int32)

                loss_sub_log: Tensor = torch.zeros((len(samplings[0]), len(loss_fns)),
                                                   dtype=torch.float32, device=device)
                for k, sampling in enumerate(samplings[0]):
                    # Tuples are the canonical multi-dimensional index form.
                    img_sampling = (slice(0, B), slice(0, C), *sampling)
                    label_sampling = (slice(0, B), slice(0, K), *sampling)
                    assert len(img_sampling) == len(image.shape), (img_sampling, image.shape)
                    sub_img = image[img_sampling]

                    if optimizer:
                        optimizer.zero_grad()

                    pred_logits: Tensor = net(sub_img)
                    pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1)

                    # Hard one-hot prediction, used only for the metrics.
                    predicted_mask: Tensor = probs2one_hot(pred_probs.detach())
                    assert not predicted_mask.requires_grad

                    # Detached so the receptacle does not retain this sub-sampling's graph.
                    probs_receptacle[label_sampling] = pred_probs.detach()
                    mask_receptacle[label_sampling] = predicted_mask

                    losses = [w * loss_fn(pred_probs, label[label_sampling], bound, filenames)
                              for loss_fn, label, w, bound in zip(loss_fns, labels, loss_weights, bounds)]
                    loss = reduce(add, losses)
                    assert loss.shape == (), loss.shape

                    if optimizer:
                        loss.backward()
                        optimizer.step()

                    for j, sub_loss in enumerate(losses):
                        loss_sub_log[k, j] = sub_loss.detach()

                reduced_loss_sublog: Tensor = loss_sub_log.sum(dim=0)
                assert reduced_loss_sublog.shape == (len(loss_fns),), (reduced_loss_sublog.shape, len(loss_fns))
                epoch_metrics["loss"][done_batch, ...] = reduced_loss_sublog

                sm_slice = slice(done_img, done_img + B)  # Values only for current batch
                dices: Tensor = dice_coef(mask_receptacle, target)
                assert dices.shape == (B, K), (dices.shape, B, K)
                epoch_metrics["dice"][sm_slice, ...] = dices

                if "3d_dsc" in requested_metrics:
                    three_d_DSC: Tensor = dice_batch(mask_receptacle, target)
                    assert three_d_DSC.shape == (K,)
                    epoch_metrics["3d_dsc"][done_batch] = three_d_DSC  # type: ignore

                if savedir:
                    with warnings.catch_warnings():
                        warnings.filterwarnings("ignore", category=UserWarning)
                        # Save the stitched prediction, not only the last sub-sampling.
                        predicted_class: Tensor = probs2class(probs_receptacle)
                        save_images(predicted_class, filenames, savedir / f"iter{epc:03d}" / mode)

                # Count the current batch *before* computing running stats so that every
                # statistic (DSC, 3d_DSC, losses) includes it.
                done_img += B
                done_batch += 1

                stat_dict = _running_stats(epoch_metrics, metric_axis, requested_metrics,
                                           done_img, done_batch, n_loss)
                nice_dict = {name: f"{value:.3f}" for (name, value) in stat_dict.items()}
                tq_iter.set_postfix({**nice_dict, "loader": str(i)})
                tq_iter.update(1)
    finally:
        tq_iter.close()

    print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items()))

    return {k: v.detach().cpu() for (k, v) in epoch_metrics.items()}


def _running_stats(epoch_metrics: dict[str, Tensor], metric_axis: list[int],
                   requested_metrics: list[str], n_imgs: int, n_batches: int,
                   n_loss: int) -> dict[str, Any]:
    """Return the running averages displayed in the progress bar.

    ``n_imgs`` / ``n_batches`` are the numbers of images / batches already written into
    ``epoch_metrics``, including the current one. The insertion order of the returned
    dict fixes the display order.
    """
    stats: dict[str, Any] = {}
    seen_dice = epoch_metrics["dice"][:n_imgs]

    # Per-class values are only shown when there are few classes to display.
    if len(metric_axis) <= 4:
        stats |= {f"DSC{n}": seen_dice[:, n].mean() for n in metric_axis}
        if "3d_dsc" in requested_metrics:
            stats |= {f"3d_DSC{n}": epoch_metrics["3d_dsc"][:n_batches, n].mean()
                      for n in metric_axis}

    if len(metric_axis) > 1:
        stats |= {"DSC": seen_dice[:, metric_axis].mean()}

    mean_losses = epoch_metrics["loss"][:n_batches].mean(dim=0)
    stats |= {f"loss_{j}": mean_losses[j] for j in range(n_loss)}
    return stats