Пример #1
0
def val(model, val_loader):
    """Evaluate ``model`` on every batch of ``val_loader``.

    Targets are converted with ``common.to_one_hot_3d`` before scoring.  For
    each batch, the ``metrics.DiceMeanLoss`` value and the ``metrics.dice``
    scores for classes 0, 1 and 2 are converted to Python floats and summed.
    Each sum is then divided by the number of batches, giving an unweighted
    per-batch mean.

    Args:
        model: network to evaluate; it is switched to eval mode.
        val_loader: iterable of ``(data, target)`` pairs that supports ``len``.

    Returns:
        OrderedDict with keys ``'Val Loss'``, ``'Val dice0'``,
        ``'Val dice1'`` and ``'Val dice2'``.
    """
    model.eval()
    loss_total = 0
    dice_totals = [0, 0, 0]  # one running sum per class index 0, 1, 2
    with torch.no_grad():
        batches = tqdm(enumerate(val_loader), total=len(val_loader))
        for _, (data, target) in batches:
            target = common.to_one_hot_3d(target.long())
            data = data.float().to(device)
            target = target.float().to(device)
            output = model(data)

            batch_loss = metrics.DiceMeanLoss()(output, target)
            batch_dice = [metrics.dice(output, target, cls) for cls in range(3)]

            loss_total += float(batch_loss)
            for cls, score in enumerate(batch_dice):
                dice_totals[cls] += float(score)

    n_batches = len(val_loader)
    return OrderedDict({
        'Val Loss': loss_total / n_batches,
        'Val dice0': dice_totals[0] / n_batches,
        'Val dice1': dice_totals[1] / n_batches,
        'Val dice2': dice_totals[2] / n_batches,
    })
Пример #2
0
def val(model, val_loader, epoch, logger):
    """Evaluate ``model`` on ``val_loader``, then log and print the averages.

    For every batch, the ``metrics.DiceMeanLoss`` value and the
    ``metrics.dice`` scores for classes 0, 1 and 2 are summed as Python
    floats, and each sum is divided by ``len(val_loader)``.  The four
    averages are sent to ``logger.scalar_summary`` at step ``epoch`` and
    printed.  Targets are used as the loader delivers them (only cast to
    float).

    Args:
        model: network to evaluate; switched to eval mode.
        val_loader: iterable of ``(data, target)`` pairs that supports ``len``.
        epoch: step value passed to the logger.
        logger: object exposing ``scalar_summary(tag, value, step)``.
    """
    model.eval()
    sums = [0, 0, 0, 0]  # loss, dice0, dice1, dice2
    with torch.no_grad():
        for data, target in val_loader:
            data = data.float().to(device)
            target = target.float().to(device)
            output = model(data)

            scores = [
                metrics.DiceMeanLoss()(output, target),
                metrics.dice(output, target, 0),
                metrics.dice(output, target, 1),
                metrics.dice(output, target, 2),
            ]
            for slot, value in enumerate(scores):
                sums[slot] += float(value)

    n_batches = len(val_loader)
    val_loss, val_dice0, val_dice1, val_dice2 = [s / n_batches for s in sums]

    for tag, value in (('val_loss', val_loss), ('val_dice0', val_dice0),
                       ('val_dice1', val_dice1), ('val_dice2', val_dice2)):
        logger.scalar_summary(tag, value, epoch)
    print('\nVal set: Average loss: {:.6f}, dice0: {:.6f}\tdice1: {:.6f}\tdice2: {:.6f}\t\n'.format(
        val_loss, val_dice0, val_dice1, val_dice2))
Пример #3
0
def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print the averaged metrics.

    For each batch, the ``metrics.DiceMeanLoss`` value and the
    ``metrics.dice`` scores for classes 0, 1 and 2 are accumulated as Python
    floats.  They are then divided by ``len(test_loader)`` and printed.
    Nothing is returned.

    Args:
        model: network to evaluate; switched to eval mode.
        test_loader: iterable of ``(data, target)`` pairs that supports ``len``.
    """
    print("Evaluation of Testset Starting...")
    model.eval()
    loss_acc = 0
    dice_acc = [0, 0, 0]  # running sums for classes 0, 1, 2
    with torch.no_grad():
        for data, target in tqdm(test_loader):
            data = data.float().to(device)
            target = target.float().to(device)
            output = model(data)

            batch_loss = metrics.DiceMeanLoss()(output, target)
            batch_dice = [metrics.dice(output, target, cls) for cls in (0, 1, 2)]

            loss_acc += float(batch_loss)
            for cls in range(3):
                dice_acc[cls] += float(batch_dice[cls])

    n_batches = len(test_loader)
    mean_loss = loss_acc / n_batches
    mean_dice = [total / n_batches for total in dice_acc]

    print('\nTest set: Average loss: {:.6f}, dice0: {:.6f}\tdice1: {:.6f}\tdice2: {:.6f}\t\n'.format(
        mean_loss, mean_dice[0], mean_dice[1], mean_dice[2]))
Пример #4
0
def cal_perfer(preds, masks, tb_dict):
    """Store batch-mean Dice and Hausdorff distance per structure in ``tb_dict``.

    Channel indices used: 1 -> LV, 2 -> MYO, 3 -> RV.  For each sample ``i``
    along the first axis of ``preds``, ``dice`` and
    ``cal_hausdorff_distance`` are evaluated on ``preds[i, c, :, :]`` versus
    ``masks[i, c, :, :]``.  The per-sample values are averaged with
    ``np.mean`` and written into ``tb_dict`` under ``"<name>_dice"`` and
    ``"<name>_hausdorff"``.

    Args:
        preds, masks: 4-D arrays/tensors indexed as ``[sample, channel, :, :]``.
        tb_dict: dict-like object updated in place.

    Returns:
        None.
    """
    # (dice key, hausdorff key, channel index); order fixes evaluation order.
    structures = (
        ("LV_dice", "LV_hausdorff", 1),
        ("RV_dice", "RV_hausdorff", 3),
        ("MYO_dice", "MYO_hausdorff", 2),
    )
    dice_scores = {dice_key: [] for dice_key, _, _ in structures}
    hd_scores = {hd_key: [] for _, hd_key, _ in structures}

    for sample in range(preds.shape[0]):
        for dice_key, _, ch in structures:
            dice_scores[dice_key].append(
                dice(preds[sample, ch, :, :], masks[sample, ch, :, :]))
        for _, hd_key, ch in structures:
            hd_scores[hd_key].append(
                cal_hausdorff_distance(preds[sample, ch, :, :],
                                       masks[sample, ch, :, :]))

    for dice_key, _, _ in structures:
        tb_dict.update({dice_key: np.mean(dice_scores[dice_key])})
    for _, hd_key, _ in structures:
        tb_dict.update({hd_key: np.mean(hd_scores[hd_key])})
Пример #5
0
def test(model, dataset, save_path, filename):
    """Predict a full volume patch by patch, score it, and save the label map.

    Patches from ``dataset`` are run through ``model`` in batches of 4 and
    handed to ``Recompone_tool``, which reassembles them with
    ``recompone_overlap()``.  The reassembled prediction is compared with the
    one-hot encoded ``dataset.label_np`` using ``metrics.DiceMeanLoss`` and
    ``metrics.dice`` for classes 0-2.  The argmax over dim 1 is written as a
    uint8 image to ``os.path.join(save_path, filename)`` with SimpleITK.

    Returns:
        ``(val_loss, val_dice0, val_dice1, val_dice2)`` exactly as returned by
        the metric functions.
    """
    patch_loader = DataLoader(dataset=dataset, batch_size=4, num_workers=0,
                              shuffle=False)
    model.eval()
    recomposer = Recompone_tool(save_path, filename, dataset.ori_shape,
                                dataset.new_shape, dataset.cut)
    label_batch = np.expand_dims(dataset.label_np, axis=0)
    target = to_one_hot_3d(torch.from_numpy(label_batch).long())

    with torch.no_grad():
        for patch in tqdm(patch_loader, total=len(patch_loader)):
            # Insert a channel axis at dim 1 before the forward pass.
            patch = patch.unsqueeze(1).float().to(device)
            prediction = model(patch)
            recomposer.add_result(prediction.detach().cpu())

    pred = torch.unsqueeze(recomposer.recompone_overlap(), dim=0)
    val_loss = metrics.DiceMeanLoss()(pred, target)
    val_dice0, val_dice1, val_dice2 = [metrics.dice(pred, target, cls)
                                       for cls in (0, 1, 2)]

    label_map = torch.argmax(pred, dim=1).numpy()
    volume = np.squeeze(np.array(label_map, dtype='uint8'), axis=0)
    sitk.WriteImage(sitk.GetImageFromArray(volume),
                    os.path.join(save_path, filename))

    print(
        '\nAverage loss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}\t\n'
        .format(val_loss, val_dice0, val_dice1, val_dice2))
    return val_loss, val_dice0, val_dice1, val_dice2
Пример #6
0
 def metric(self, logit, truth, nonempty_only=False, logit_clf=None):
     """Return the Dice score of ``logit`` against ``truth`` (early-stopping metric).

     ``nonempty_only`` and ``logit_clf`` are forwarded unchanged as keyword
     arguments to ``dice``.
     """
     options = {'nonempty_only': nonempty_only, 'logit_clf': logit_clf}
     return dice(logit, truth, **options)
Пример #7
0
def cal_perfer(preds, masks, tb_dict):
    """Compute batch-mean Dice for LV, RV and MYO and append them to ``tb_dict``.

    Channel indices used: 1 -> LV, 2 -> MYO, 3 -> RV.  For each sample ``i``
    along the first axis of ``preds``, ``dice`` is evaluated on
    ``preds[i, c, :, :]`` versus ``masks[i, c, :, :]``.  The per-sample
    scores are averaged with ``np.mean``.

    Args:
        preds, masks: 4-D arrays/tensors indexed as ``[sample, channel, :, :]``.
        tb_dict: mapping whose ``"LV_dice"``, ``"RV_dice"`` and ``"MYO_dice"``
            entries are lists; one mean is appended to each.

    Returns:
        ``(lv_mean, rv_mean, myo_mean)``.
    """
    # (tb_dict key, channel index); order fixes evaluation and append order.
    channel_of = (("LV_dice", 1), ("RV_dice", 3), ("MYO_dice", 2))
    per_structure = {key: [] for key, _ in channel_of}

    for sample in range(preds.shape[0]):
        for key, ch in channel_of:
            per_structure[key].append(
                dice(preds[sample, ch, :, :], masks[sample, ch, :, :]))

    means = []
    for key, _ in channel_of:
        mean_score = np.mean(per_structure[key])
        tb_dict[key].append(mean_score)
        means.append(mean_score)
    return tuple(means)
Пример #8
0
def train(model, train_loader, optimizer, epoch, logger):
    """Train ``model`` for one epoch, then print and log epoch-averaged metrics.

    Each batch is squeezed along dim 0, cast to float, moved to ``device``,
    and optimised with ``metrics.DiceMeanLoss``.  The loss and the
    ``metrics.dice`` scores for classes 0, 1 and 2 are averaged over all
    batches, printed, and sent to ``logger.scalar_summary`` at step ``epoch``.

    Args:
        model: network to train; switched to train mode.
        train_loader: iterable of ``(data, target)`` pairs that supports ``len``.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number used in the banner, the print and as logger step.
        logger: object exposing ``scalar_summary(tag, value, step)``.
    """
    print("=======Epoch:{}=======".format(epoch))
    model.train()
    n_batches = len(train_loader)
    train_loss = 0.0
    train_dice0 = 0.0
    train_dice1 = 0.0
    train_dice2 = 0.0
    for data, target in tqdm(train_loader, total=n_batches):
        # Drop dim 0 (torch.squeeze only removes it when its size is 1);
        # presumably the loader wraps each sample in an extra batch axis.
        data = torch.squeeze(data, dim=0)
        target = torch.squeeze(target, dim=0)
        data, target = data.float(), target.float()
        data, target = data.to(device), target.to(device)
        output = model(data)

        optimizer.zero_grad()
        loss = metrics.DiceMeanLoss()(output, target)
        loss.backward()
        optimizer.step()

        # Accumulate plain Python floats. Summing the tensors themselves
        # (``train_loss += loss``) chains every batch's autograd graph into the
        # running total, so none of them can be freed until the epoch ends.
        train_loss += float(loss)
        train_dice0 += float(metrics.dice(output, target, 0))
        train_dice1 += float(metrics.dice(output, target, 1))
        train_dice2 += float(metrics.dice(output, target, 2))
    train_loss /= n_batches
    train_dice0 /= n_batches
    train_dice1 /= n_batches
    train_dice2 /= n_batches

    print(
        'Train Epoch: {} \tLoss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}'
        .format(epoch, train_loss, train_dice0, train_dice1, train_dice2))

    logger.scalar_summary('train_loss', float(train_loss), epoch)
    logger.scalar_summary('train_dice0', float(train_dice0), epoch)
    logger.scalar_summary('train_dice1', float(train_dice1), epoch)
    logger.scalar_summary('train_dice2', float(train_dice2), epoch)
Пример #9
0
def train(model, train_loader, optimizer=None, epoch=None):
    """Train ``model`` for one epoch and return epoch-averaged metrics.

    Each batch is squeezed along dim 0, cast to float, moved to ``device``,
    and optimised with ``metrics.DiceMeanLoss``.  The loss and the
    ``metrics.dice`` scores for classes 0, 1 and 2 are accumulated as Python
    floats and divided by ``len(train_loader)``.

    Args:
        model: network to train; switched to train mode.
        train_loader: iterable of ``(data, target)`` pairs that supports ``len``.
        optimizer: optimizer to step.  If omitted, the module-level
            ``optimizer`` global is used (the original signature silently
            relied on that global).
        epoch: epoch number shown in the banner.  If omitted, the
            module-level ``epoch`` global is used.

    Returns:
        OrderedDict with keys ``'Train Loss'``, ``'Train dice0'``,
        ``'Train dice1'`` and ``'Train dice2'``.

    Raises:
        NameError: if an argument is omitted and no matching module-level
            global exists.
    """
    # Backward compatibility: callers of the two-argument form depend on the
    # module-level ``optimizer`` / ``epoch`` globals being picked up.
    module_scope = globals()
    if optimizer is None:
        if 'optimizer' not in module_scope:
            raise NameError("train() needs an optimizer: pass optimizer= or "
                            "define a module-level 'optimizer'")
        optimizer = module_scope['optimizer']
    if epoch is None:
        if 'epoch' not in module_scope:
            raise NameError("train() needs an epoch: pass epoch= or "
                            "define a module-level 'epoch'")
        epoch = module_scope['epoch']

    print("=======Epoch:{}=======".format(epoch))
    model.train()
    train_loss = 0
    train_dice0 = 0
    train_dice1 = 0
    train_dice2 = 0
    for data, target in tqdm(train_loader, total=len(train_loader)):
        # Drop dim 0 (only removed when its size is 1).
        data = torch.squeeze(data, dim=0)
        target = torch.squeeze(target, dim=0)
        data, target = data.float(), target.float()
        data, target = data.to(device), target.to(device)
        output = model(data)

        optimizer.zero_grad()
        loss = metrics.DiceMeanLoss()(output, target)
        loss.backward()
        optimizer.step()

        # float() detaches, so no autograd graph outlives its batch.
        train_loss += float(loss)
        train_dice0 += float(metrics.dice(output, target, 0))
        train_dice1 += float(metrics.dice(output, target, 1))
        train_dice2 += float(metrics.dice(output, target, 2))
    train_loss /= len(train_loader)
    train_dice0 /= len(train_loader)
    train_dice1 /= len(train_loader)
    train_dice2 /= len(train_loader)

    return OrderedDict({
        'Train Loss': train_loss,
        'Train dice0': train_dice0,
        'Train dice1': train_dice1,
        'Train dice2': train_dice2
    })
Пример #10
0
def train(model, train_loader, optimizer, epoch, logger):
    """Train ``model`` for one epoch, printing per batch and logging epoch means.

    Each batch is squeezed along dim 0, cast to float, moved to ``device``,
    and optimised with ``metrics.DiceMeanLoss``.  A progress line is printed
    for every batch.  It shows the batch loss, the ``metrics.dice`` scores for
    classes 0-2, and ``metrics.T`` / ``metrics.P`` / ``metrics.TP``.

    After the epoch, the mean loss and mean Dice scores over all batches are
    sent to ``logger.scalar_summary`` at step ``epoch``.  Previously only the
    last batch's values were logged, so the per-epoch curve reflected a
    single batch.  An empty loader logs 0.0 for every metric, as before.

    Args:
        model: network to train; switched to train mode.
        train_loader: iterable of ``(data, target)`` pairs that supports ``len``.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number for the printout and logger step.
        logger: object exposing ``scalar_summary(tag, value, step)``.
    """
    model.train()
    n_batches = len(train_loader)
    loss_sum = 0.0
    dice_sums = [0.0, 0.0, 0.0]  # classes 0, 1, 2
    for batch_idx, (data, target) in enumerate(train_loader):
        # Drop dim 0 (only removed when its size is 1).
        data = torch.squeeze(data, dim=0)
        target = torch.squeeze(target, dim=0)
        data, target = data.float(), target.float()
        data, target = data.to(device), target.to(device)
        output = model(data)

        optimizer.zero_grad()
        loss = metrics.DiceMeanLoss()(output, target)
        loss.backward()
        optimizer.step()

        batch_dice = [metrics.dice(output, target, cls) for cls in range(3)]
        print(
            'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tdice0: {:.6f}\tdice1: {:.6f}\tdice2: {:.6f}\tT: {:.6f}\tP: {:.6f}\tTP: {:.6f}'
            .format(epoch, batch_idx,
                    n_batches, 100. * batch_idx / n_batches,
                    loss.item(), batch_dice[0], batch_dice[1], batch_dice[2],
                    metrics.T(output, target), metrics.P(output, target),
                    metrics.TP(output, target)))

        # Accumulate plain floats so no autograd graph outlives its batch.
        loss_sum += loss.item()
        for cls in range(3):
            dice_sums[cls] += float(batch_dice[cls])

    # max(..., 1) keeps the empty-loader case at 0.0 instead of dividing by zero.
    divisor = max(n_batches, 1)
    logger.scalar_summary('train_loss', loss_sum / divisor, epoch)
    logger.scalar_summary('train_dice0', dice_sums[0] / divisor, epoch)
    logger.scalar_summary('train_dice1', dice_sums[1] / divisor, epoch)
    logger.scalar_summary('train_dice2', dice_sums[2] / divisor, epoch)
Пример #11
0
def validate(valid_dataset, model_multi_views, save_dir):
    """Run multi-view inference on each validation volume, save outputs, report Dice.

    ``valid_dataset`` is loaded with batch size 1.  Each loaded sample is a
    sequence of views: view ``i``'s ``'volume'`` is moved to ``cuda:i``, and
    all views are passed to ``inference``.  The prediction is scored against
    view 0's ``'label'`` with ``metrics.dice`` over ``label_list``.  The
    prediction and ``util.fuse_pred_label``'s fusion are both saved into
    ``save_dir`` via ``save_volume``.

    Args:
        valid_dataset: dataset whose items collate to a sequence of per-view
            dicts with ``'volume'``, ``'label'`` and ``'name'`` entries.
        model_multi_views: passed through to ``inference``.
        save_dir: output directory; created if it does not exist.

    Returns:
        ``np.average`` of the per-volume average Dice scores.  The previous
        version computed this value but returned ``None``; callers that
        ignored the return value are unaffected.
    """
    label_list = [0, 1]
    valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
    # Create once up front; exist_ok also avoids the exists()/makedirs() race.
    os.makedirs(save_dir, exist_ok=True)

    final_score = []
    for valid_index, valid_batch_sample in enumerate(valid_dataloader):
        volume_input_views = [
            sample_batch['volume'].to("cuda:%d" % i).float()
            for i, sample_batch in enumerate(valid_batch_sample)
        ]
        valid_seg = valid_batch_sample[0]['label'].to("cuda:0").float()
        valid_name = valid_batch_sample[0]['name']

        pred_seg = inference(model_multi_views, volume_input_views)
        pred_seg = pred_seg.cpu().detach().numpy()
        # Bring the label to the host and drop axis 1 (presumably a singleton
        # channel axis — np.squeeze raises if it is not size 1).
        valid_seg = np.squeeze(valid_seg.cpu().detach().numpy(), axis=1)

        _, average_dice = metrics.dice(pred_seg, valid_seg, label_list)
        print("validating %2d/%d, valid_name: %-10s, final score: %f" %
              (valid_index + 1, len(valid_dataset), valid_name, average_dice))

        fuse_seg = util.fuse_pred_label(valid_seg, pred_seg)

        # valid_name is indexed with [0] — presumably the collated list of
        # names for this size-1 batch; the part after 'img' names the outputs.
        name_suffix = valid_name[0].split('img')[-1]
        save_volume(pred_seg, save_dir, 'pred' + name_suffix + '.nii.gz')
        save_volume(fuse_seg, save_dir, 'fuse' + name_suffix + '.nii.gz')

        final_score.append(average_dice)

    valid_average_score = np.average(final_score)
    print("validating done")
    print('validation score :%f' % valid_average_score)

    return valid_average_score
Пример #12
0
def val(model, val_loader, epoch, logger):
    """Evaluate ``model`` on ``val_loader``, then log and print the averages.

    Each batch is squeezed along dim 0, cast to float and moved to
    ``device``.  The ``metrics.DiceMeanLoss`` value and the ``metrics.dice``
    scores for classes 0, 1 and 2 are summed as Python floats and divided by
    ``len(val_loader)``.  The averages go to ``logger.scalar_summary`` at
    step ``epoch`` and are printed.

    Args:
        model: network to evaluate; switched to eval mode.
        val_loader: iterable of ``(data, target)`` pairs that supports ``len``.
        epoch: step value passed to the logger.
        logger: object exposing ``scalar_summary(tag, value, step)``.
    """
    model.eval()
    n_batches = len(val_loader)
    loss_acc = dice0_acc = dice1_acc = dice2_acc = 0
    with torch.no_grad():
        for _, (data, target) in tqdm(enumerate(val_loader), total=n_batches):
            # Drop dim 0 (torch.squeeze only removes it when its size is 1).
            data = torch.squeeze(data, dim=0).float().to(device)
            target = torch.squeeze(target, dim=0).float().to(device)
            output = model(data)

            loss_acc += float(metrics.DiceMeanLoss()(output, target))
            dice0_acc += float(metrics.dice(output, target, 0))
            dice1_acc += float(metrics.dice(output, target, 1))
            dice2_acc += float(metrics.dice(output, target, 2))

    val_loss = loss_acc / n_batches
    val_dice0 = dice0_acc / n_batches
    val_dice1 = dice1_acc / n_batches
    val_dice2 = dice2_acc / n_batches

    for tag, value in (('val_loss', val_loss), ('val_dice0', val_dice0),
                       ('val_dice1', val_dice1), ('val_dice2', val_dice2)):
        logger.scalar_summary(tag, value, epoch)
    print(
        'Val performance: Average loss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}\t\n'
        .format(val_loss, val_dice0, val_dice1, val_dice2))
Пример #13
0
    def test(self):
        """Evaluate the registration network as a segmenter for the current pair.

        Runs ``self.netReg`` in eval mode, without gradients, on the fixed
        image, the moving image and the fixed atlas.  It then compares the
        returned segmentation with ``self.moving_atlas`` using
        ``metrics.dice`` and prints the NaN-ignoring mean Dice.  Finally it
        writes the segmentation and the deformation field with SimpleITK into
        ``./checkpoints/<folder_names>/output_<folder_names>/``.

        Returns:
            ``(mean_dice, dsc)``: ``np.nanmean(dsc)`` (also stored in
            ``self.metric_mean_dcs``) and the per-label Dice values returned
            by ``metrics.dice``.
        """
        # ~~~~~~~~~~~~~~~~~ Test for segmentation instead of registration ~~~~~~~~~~~~~~~~~
        self.netReg.eval()

        with torch.no_grad():
            self.input_moving = Variable(self.moving).cuda()
            self.input_fixed = Variable(self.fixed).cuda()
            self.input_moving_atlas = Variable(self.moving_atlas).cuda()
            self.input_fixed_atlas = Variable(self.fixed_atlas).cuda()

            _, segmentation_result, flow, _, _, _, _ = self.netReg(
                self.input_fixed, self.input_moving, self.input_fixed_atlas)

            # ~~~~~~~~~~~~~~~~~ Evaluation. ~~~~~~~~~~~~~~~~~
            seg_np = segmentation_result.data.int().cpu().numpy()
            dsc, _ = metrics.dice(seg_np, self.moving_atlas.cpu().int().numpy())
            self.metric_mean_dcs = np.nanmean(dsc)
            print('Total_Mask_Dice={:0.4f}. '.format(self.metric_mean_dcs))

            # Bug fix: the directory used to be created relative to the working
            # directory, but the files were written to a hard-coded absolute
            # path (/home/ll610/...).  The two only matched on the original
            # author's machine.  Use a single path for both.
            output_dir = './checkpoints/{}/output_{}'.format(
                self.folder_names, self.folder_names)
            os.makedirs(output_dir, exist_ok=True)
            moving_name = self.moving_paths[0].split('/')[-1]

            sitk.WriteImage(
                sitk.GetImageFromArray(seg_np.squeeze()),
                '{}/seg_of_moving_{}'.format(output_dir, moving_name))

            # Move axis 0 of the squeezed flow to the last position before
            # writing (assumes the field is stored components-first — TODO
            # confirm against netReg's output layout).
            sitk.WriteImage(
                sitk.GetImageFromArray(
                    np.transpose(flow.data.float().squeeze().cpu().numpy(),
                                 (1, 2, 3, 0))),
                '{}/deformation_field_{}'.format(output_dir, moving_name))

            return self.metric_mean_dcs, dsc
 def metric(self, logit_mask, truth_mask, logit_clf, truth_clf):
     """Define metrics for evaluation, especially for early stopping.

     Returns ``dice(logit_mask, truth_mask)``.  ``logit_clf`` and
     ``truth_clf`` are accepted to keep the signature unchanged but are not
     used by this metric.
     """
     # Bug fix: the body previously returned ``dice(logit, truth)``.  Neither
     # name exists in this scope, so every call raised NameError.
     return dice(logit_mask, truth_mask)
 def metric(self, logit, truth):
     """Return the Dice score of ``logit`` against ``truth``.

     Used as the evaluation metric, e.g. for early stopping.
     """
     score = dice(logit, truth)
     return score
Пример #16
0
            # Supervised losses on the labeled part of the batch (the first
            # ``labeled_bs`` samples), using channel 0 of the outputs.
            loss_seg = ce_loss(
                outputs[:labeled_bs, 0, ...], label_batch[:labeled_bs].float())
            loss_seg_dice = losses.dice_loss(
                outputs_soft[:labeled_bs, 0, :, :, :], label_batch[:labeled_bs] == 1)
            # A sigmoid with a very steep slope (-1500) acts almost like a
            # step function.  It turns the tanh regression head into a soft
            # mask that can be compared with outputs_soft (presumably a
            # signed-distance-style output — confirm).
            dis_to_mask = torch.sigmoid(-1500*outputs_tanh)

            # Consistency between the two task heads, over the whole batch
            # (no ``[:labeled_bs]`` slice here).
            consistency_loss = F.mse_loss(dis_to_mask, outputs_soft)
            # NOTE(review): loss_seg (cross-entropy) is computed and logged
            # below but is not part of the optimized loss — confirm intended.
            supervised_loss = loss_seg_dice + args.beta * loss_sdf

            # NOTE(review): the consistency term is weighted by a hard-coded
            # 0.1, while ``consistency_weight`` is what gets logged below —
            # confirm which weight is meant to be used.
            loss = supervised_loss + 0.1 * consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Training-time Dice on the labeled samples.
            # NOTE(review): the ``[:, 0, ...]`` indexing above suggests
            # outputs_soft may have a single channel; if so, argmax over dim=1
            # is always 0 and this Dice is not meaningful — confirm.
            dc = metrics.dice(torch.argmax(
                outputs_soft[:labeled_bs], dim=1), label_batch[:labeled_bs])

            iter_num = iter_num + 1
            # Per-iteration scalars sent to ``writer``.
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_dice', loss_seg_dice, iter_num)
            writer.add_scalar('loss/loss_hausdorff', loss_sdf, iter_num)
            writer.add_scalar('loss/consistency_weight',
                              consistency_weight, iter_num)
            writer.add_scalar('loss/consistency_loss',
                              consistency_loss, iter_num)

            logging.info(
                'iteration %d : loss : %f, loss_consis: %f, loss_haus: %f, loss_seg: %f, loss_dice: %f' %
                (iter_num, loss.item(), consistency_loss.item(), loss_sdf.item(),
Пример #17
0
            # model loss
            # Both models' losses add the same weighted consistency term.
            model_dis_loss = loss_dis_dice + consistency_weight * consistency_loss
            model_seg_loss = loss_seg + consistency_weight * consistency_loss


            optimizer_seg.zero_grad()
            optimizer_dis.zero_grad()

            # consistency_loss is shared by both losses.  retain_graph=True
            # keeps that part of the graph alive for the second backward().
            # NOTE(review): both backward passes accumulate into .grad before
            # either optimizer steps.  Any parameter reachable from the shared
            # consistency term therefore gets gradients from both losses —
            # confirm this coupling is intended.
            model_seg_loss.backward(retain_graph=True)
            model_dis_loss.backward()

            optimizer_seg.step()
            optimizer_dis.step()


            # Training-time Dice of softmask_seg (argmax over dim 1) on the
            # labeled samples.
            dc = metrics.dice(torch.argmax(
                softmask_seg[:labeled_bs], dim=1), label_batch[:labeled_bs])

            iter_num = iter_num + 1
            # Per-iteration scalars sent to ``writer``.
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss_seg', model_seg_loss, iter_num)
            writer.add_scalar('loss/loss_dis', model_dis_loss, iter_num)
            writer.add_scalar('loss/loss_dice', loss_dice, iter_num)
            writer.add_scalar('loss/consistency_weight',
                              consistency_weight, iter_num)
            writer.add_scalar('loss/consistency_loss',
                              consistency_loss, iter_num)

            # NOTE(review): model_dis_loss is passed without .item(), unlike
            # the other values; the %f conversion relies on float(tensor).
            logging.info(
                'iteration %d : loss_seg : %f, loss_dis: %f, loss_consistency: %f, loss_dice: %f' %
                (iter_num, model_seg_loss.item(), model_dis_loss,
                consistency_loss.item(), loss_dice.item()))