Example #1
0
def CalculateWTTCET(wtpbregion, wtmaskregion, tcpbregion, tcmaskregion,
                    etpbregion, etmaskregion):
    """Score one case's WT, TC and ET regions.

    For each region, computes Dice, PPV, Hausdorff distance and
    sensitivity between the prediction (`*pbregion`) and the ground
    truth (`*maskregion`), appending each value to the corresponding
    module-level result list (wt_dices, wt_ppvs, ... — assumed to be
    defined at module scope by the caller).
    """
    regions = (
        (wtpbregion, wtmaskregion,
         wt_dices, wt_ppvs, wt_Hausdorf, wt_sensitivities),
        (tcpbregion, tcmaskregion,
         tc_dices, tc_ppvs, tc_Hausdorf, tc_sensitivities),
        (etpbregion, etmaskregion,
         et_dices, et_ppvs, et_Hausdorf, et_sensitivities),
    )
    for pb, mask, dice_list, ppv_list, hd_list, sens_list in regions:
        dice_list.append(dice_coef(pb, mask))
        ppv_list.append(ppv(pb, mask))
        # hausdorff_distance is called with (mask, prediction) order here.
        hd_list.append(hausdorff_distance(mask, pb))
        sens_list.append(sensitivity(pb, mask))
Example #2
0
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
        """Dice loss: reduce `1 - dice_coef(...)` by mean or sum.

        The reduction mode comes from `self.reduction`; `self.epsilon`
        is forwarded to `dice_coef` as its smoothing term.
        """
        per_item = 1 - dice_coef(inputs, targets, self.epsilon)
        reduce = torch.mean if self.reduction == 'mean' else torch.sum
        return reduce(per_item)
Example #3
0
def validate(config, val_loader, model, criterion):
    """Run one validation pass and return averaged loss/iou/dice.

    Expects `val_loader` to yield (ori_img, input, target, targets, _)
    tuples; only `input` and `target` are used here.
    """
    avg_meters = {
        'loss': AverageMeter(),
        'iou': AverageMeter(),
        'dice': AverageMeter()
    }

    # switch to evaluate mode
    model.eval()
    num_class = int(config['num_classes'])
    with torch.no_grad():
        pbar = tqdm(total=len(val_loader))
        for ori_img, input, target, targets, _ in val_loader:
            input = input.cuda()
            target = target.cuda()

            # compute output
            if config['deep_supervision']:
                # Average the criterion over every supervision head and
                # score only the final head.
                outputs = model(input)
                loss = 0
                for output in outputs:
                    loss += criterion(output, target)
                loss /= len(outputs)
                iou = iou_score(outputs[-1], target)
                dice = dice_coef(outputs[-1], target)
            else:
                #input[torch.isnan(input)] = 0
                output = model(input)
                # Zero any NaNs in the prediction so the metrics below
                # don't propagate them.
                output[torch.isnan(output)] = 0
                # Metrics are computed on channels 1..num_class-1 only —
                # presumably channel 0 is background; confirm against the
                # dataset's channel layout. The loss still uses all channels.
                out_m = output[:, 1:num_class, :, :].clone()
                tar_m = target[:, 1:num_class, :, :].clone()
                loss = criterion(output, target)
                #loss = criterion(out_m, tar_m)
                iou = iou_score(out_m, tar_m)
                dice = dice_coef(out_m, tar_m)
                # iou = iou_score(output, target)
                # dice = dice_coef(output, target)

            avg_meters['loss'].update(loss.item(), input.size(0))
            avg_meters['iou'].update(iou, input.size(0))
            avg_meters['dice'].update(dice, input.size(0))

            postfix = OrderedDict([
                ('loss', avg_meters['loss'].avg),
                ('iou', avg_meters['iou'].avg),
                ('dice', avg_meters['dice'].avg),
            ])
            pbar.set_postfix(postfix)
            pbar.update(1)
        pbar.close()

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg),
                        ('dice', avg_meters['dice'].avg)])
Example #4
0
def train(config, train_loader, model, criterion, optimizer):
    """Run one training epoch; returns an OrderedDict of averaged
    loss / iou / dice over the epoch."""
    avg_meters = {key: AverageMeter() for key in ('loss', 'iou', 'dice')}

    model.train()

    pbar = tqdm(total=len(train_loader))
    for images, masks, _ in train_loader:
        images = images.cuda()
        masks = masks.cuda()

        if config['deep_supervision']:
            # Average the loss over every supervision head; score the last.
            outputs = model(images)
            loss = sum(criterion(out, masks) for out in outputs)
            loss /= len(outputs)
            iou = iou_score(outputs[-1], masks)
            dice = dice_coef(outputs[-1], masks)
        else:
            output = model(images)
            loss = criterion(output, masks)
            iou = iou_score(output, masks)
            dice = dice_coef(output, masks)

        # compute gradient and do optimizing step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        avg_meters['loss'].update(loss.item(), batch_size)
        avg_meters['iou'].update(iou, batch_size)
        avg_meters['dice'].update(dice, batch_size)

        pbar.set_postfix(OrderedDict([
            ('loss', avg_meters['loss'].avg),
            ('iou', avg_meters['iou'].avg),
            ('dice', avg_meters['dice'].avg),
        ]))
        pbar.update(1)
    pbar.close()

    return OrderedDict([(key, avg_meters[key].avg)
                        for key in ('loss', 'iou', 'dice')])
Example #5
0
def test(config, test_loader, model):
    """Evaluate `model` on `test_loader` and return averaged iou/dice.

    Inputs are deliberately left on the CPU (the .cuda() calls were
    disabled in the original).
    """
    avg_meters = {name: AverageMeter() for name in ('iou', 'dice')}

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        pbar = tqdm(total=len(test_loader))
        for images, masks, meta in test_loader:
            # Use the last head under deep supervision, else the single output.
            if config['deep_supervision']:
                output = model(images)[-1]
            else:
                output = model(images)

            iou = iou_score(output, masks)
            dice = dice_coef(output, masks)

            batch_size = images.size(0)
            avg_meters['iou'].update(iou, batch_size)
            avg_meters['dice'].update(dice, batch_size)

            pbar.set_postfix(OrderedDict([('iou', avg_meters['iou'].avg),
                                          ('dice', avg_meters['dice'].avg)]))
            pbar.update(1)
        pbar.close()

    return OrderedDict([('iou', avg_meters['iou'].avg),
                        ('dice', avg_meters['dice'].avg)])
def validate(val_loader, model, criterion):
    """Run one validation pass; returns averaged loss/iou/dice."""
    avg_meters = {name: AverageMeter() for name in ('loss', 'iou', 'dice')}

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        pbar = tqdm(total=len(val_loader))
        for images, masks in val_loader:
            images = images.cuda()
            masks = masks.cuda()

            output = model(images)
            loss = criterion(output, masks)
            iou = iou_score(output, masks)
            dice = dice_coef(output, masks)

            batch_size = images.size(0)
            avg_meters['loss'].update(loss.item(), batch_size)
            avg_meters['iou'].update(iou, batch_size)
            avg_meters['dice'].update(dice, batch_size)

            pbar.set_postfix(OrderedDict([('loss', avg_meters['loss'].avg),
                                          ('iou', avg_meters['iou'].avg),
                                          ('dice', avg_meters['dice'].avg)]))
            pbar.update(1)
        pbar.close()

    return OrderedDict([(name, avg_meters[name].avg)
                        for name in ('loss', 'iou', 'dice')])
Example #7
0
 def training_step_end(self, batch_parts):
     """Aggregate per-device outputs of training_step.

     If the trainer already reduced to a bare loss tensor, just average
     it across devices; otherwise unpack (x, y, y_hat), recompute the
     loss via self.criterion and return it.
     """
     logging.info(f"In training_step_end(): device={torch.cuda.current_device() if torch.cuda.is_available() else 0}")
     # isinstance (not `type(...) is`) so torch.Tensor subclasses
     # (e.g. nn.Parameter) are handled too.
     if isinstance(batch_parts, torch.Tensor):
         return batch_parts.mean()

     x, y, y_hat = batch_parts
     loss = self.criterion(x, y_hat, y)
     # NOTE(review): iou/dice are computed but never logged or returned —
     # presumably intended for logging; confirm before removing.
     iou = iou_score(y_hat, y)
     dice = dice_coef(y_hat, y)
     return loss
Example #8
0
def train(config, train_loader, model, criterion, optimizer):
    """Run one training epoch; returns averaged loss/iou/dice.

    Note: this variant's criterion takes (input, output, target).
    """
    avg_meters = {name: AverageMeter() for name in ('loss', 'iou', 'dice')}

    model.train()

    pbar = tqdm(total=len(train_loader))
    for images, masks, _ in train_loader:
        # images: bz * channel(3) * h * w; masks: bz * 1 * h * w
        images = images.cuda()
        masks = masks.cuda()

        if config['deep_supervision']:
            # Average the loss over all heads; keep the last for metrics.
            outputs = model(images)
            loss = sum(criterion(images, out, masks) for out in outputs)
            loss /= len(outputs)
            output = outputs[-1]
        else:
            output = model(images)
            loss = criterion(images, output, masks)

        # compute gradient and do optimizing step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iou = iou_score(output, masks)
        dice = dice_coef(output, masks)

        batch_size = images.size(0)
        avg_meters['loss'].update(loss.item(), batch_size)
        avg_meters['iou'].update(iou, batch_size)
        avg_meters['dice'].update(dice, batch_size)

        pbar.set_postfix(OrderedDict([
            ('loss', avg_meters['loss'].avg),
            ('iou', avg_meters['iou'].avg),
            ('dice', avg_meters['dice'].avg)
        ]))
        pbar.update(1)
    pbar.close()

    return OrderedDict([
        ('loss', avg_meters['loss'].avg),
        ('iou', avg_meters['iou'].avg),
        ('dice', avg_meters['dice'].avg)
    ])
 def func(y_true, y_pred):
     """Combined loss: 0.5 * binary cross-entropy minus Dice coefficient.

     The Dice term is subtracted so maximizing overlap lowers the loss.
     The tf.Print wrappers (TF1.x-era, deprecated) log DC, CE and the
     input shape as graph side effects each time the loss is evaluated.
     """
     from metrics import dice_coef
     d = dice_coef(y_true, y_pred)
     b = K.mean(K.binary_crossentropy(y_true, y_pred))
     loss = .5 * b - d
     # tf.Print returns `loss` unchanged while printing the tracked tensors;
     # chaining preserves the DC -> CE -> Shape print order.
     loss = tf.Print(loss, [d], message='\nDC:\t')
     loss = tf.Print(loss, [b], message='CE:\t')
     loss = tf.Print(loss, [tf.shape(y_true)],
                     message='Shape:\t',
                     summarize=10)
     return loss
Example #10
0
def validate(args, val_loader, model, criterion, best_params):
    """Run one validation pass; returns OrderedDict(loss, metric).

    When `best_params` is None the plain dice_coef is used; otherwise
    the configured metric criterion is evaluated with those parameters.
    """
    avg_losses = AverageMeter()
    avg_metrics = AverageMeter()

    # switch to evaluate mode
    model.eval()
    metric_criterion = metrics.__dict__[args.metric]().cuda()
    with torch.no_grad():
        for inputs, target in tqdm(val_loader):
            inputs = inputs.float().cuda()
            target = target.float().cuda()

            # Compute the loss; `final` is the prediction used for metrics.
            if args.deepsupervision:
                outputs = model(inputs)
                loss = sum(criterion(out, target) for out in outputs)
                loss /= len(outputs)
                final = outputs[-1]
            else:
                final = model(inputs)
                loss = criterion(final, target)

            if best_params is None:
                metric = metrics.dice_coef(final, target)
            else:
                metric, _ = metric_criterion(final, target,
                                             best_params=best_params)

            avg_losses.update(loss.item(), inputs.size(0))
            avg_metrics.update(metric, inputs.size(0))

    return OrderedDict([
        ('loss', avg_losses.avg),
        ('metric', avg_metrics.avg),
    ])
Example #11
0
def predict_and_evaluate(img_path, msk_path, model):
    """Predict a mask for one image, report Dice, and visualise the result."""
    image, true_mask = get_img_mask(img_path, msk_path)
    image_tensor, mask_tensor = preproc_to_model(image, true_mask)

    prediction = model.predict(image_tensor)
    print(prediction.shape)
    print()

    loss_and_dice = model.evaluate(image_tensor, mask_tensor)
    print(
        "\n======================================================================================\n"
    )
    print("Dice: {}\n".format(loss_and_dice[1]))

    # Sanity check: Dice of the ground truth against itself (should be ~1).
    dice_test = metrics.dice_coef(mask_tensor, mask_tensor)
    print("Метрика для ground true: {}".format(dice_test))

    predicted_mask = predToGrayImage(create_mask(prediction))
    visualise(image, true_mask, predicted_mask)
Example #12
0
def train(args,
          train_loader,
          model,
          criterion,
          optimizer,
          epoch,
          scheduler=None):
    """Train for one epoch and return averaged loss/iou/dice.

    `epoch` and `scheduler` are accepted for interface compatibility but
    are not used inside the loop.
    """
    losses = AverageMeter()
    ious = AverageMeter()
    dices = AverageMeter()
    model.train()

    for i, (input, target) in tqdm(enumerate(train_loader),
                                   total=len(train_loader)):
        input = input.cuda()
        target = target.cuda()

        # compute output
        if args.deepsupervision:
            outputs = model(input)
            loss = 0
            for output in outputs:
                loss += criterion(output, target)
            loss /= len(outputs)
            iou = iou_score(outputs[-1], target)
            # BUG FIX: `dice` was never assigned on this branch, so
            # dices.update() below raised NameError on the first batch (or
            # reused a stale value) whenever args.deepsupervision was set.
            dice = dice_coef(outputs[-1], target)
        else:
            output = model(input)
            loss = criterion(output, target)
            iou = iou_score(output, target)
            dice = dice_coef(output, target)

        losses.update(loss.item(), input.size(0))
        ious.update(iou, input.size(0))
        dices.update(dice, input.size(0))

        # compute gradient and do optimizing step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    log = OrderedDict([('loss', losses.avg), ('iou', ious.avg),
                       ('dice', dices.avg)])

    return log
def test(model, test_inputs, test_labels):
    """Evaluate the model on a full test set.

    :param model: tf.keras.Model inherited data type
        model being evaluated (must expose `batch_size`)
    :param test_inputs: Numpy Array - shape (num_images, imsize, imsize, channels)
        input images to test on
    :param test_labels: Numpy Array - shape (num_images, 2)
        ground truth labels one-hot encoded
    :return: float, float, float, float
        dice score, sensitivity (0.5 threshold), specificity
        (0.5 threshold), and precision, all in the range [0,1]
    """
    BATCH_SZ = model.batch_size
    # Collect per-batch logits and concatenate once at the end: the old
    # incremental np.concatenate inside the loop copied the accumulated
    # array every iteration (O(n^2) total copying).
    batches = []
    for i in range(0, test_labels.shape[0], BATCH_SZ):
        # Slicing clamps at the end, so a final partial batch is included
        # (the old comment's worry about sizes not divisible by BATCH_SZ
        # does not apply to slices).
        images = test_inputs[i:i + BATCH_SZ]
        batches.append(model(images))
    # Mirror the original's behavior: a single batch is returned as-is.
    all_logits = batches[0] if len(batches) == 1 else np.concatenate(batches,
                                                                     axis=0)

    sensitivity_val1 = sensitivity(test_labels, all_logits, threshold=0.15)
    sensitivity_val2 = sensitivity(test_labels, all_logits, threshold=0.3)
    sensitivity_val3 = sensitivity(test_labels, all_logits, threshold=0.5)
    specificity_val1 = specificity(test_labels, all_logits, threshold=0.15)
    specificity_val2 = specificity(test_labels, all_logits, threshold=0.3)
    specificity_val3 = specificity(test_labels, all_logits, threshold=0.5)

    dice = dice_coef(test_labels, all_logits)
    precision_val = precision(test_labels, all_logits)
    print(
        "Sensitivity 0.15: {}, Senstivity 0.3: {}, Senstivity 0.5: {}".format(
            sensitivity_val1, sensitivity_val2, sensitivity_val3))
    print("Specificity 0.15: {}, Specificity 0.3: {}, Specificity 0.5: {}".
          format(specificity_val1, specificity_val2, specificity_val3))
    print("DICE: {}, Precision: {}".format(dice, precision_val))

    return dice.numpy(), sensitivity_val3, specificity_val3, precision_val
def train(model, generator, verbose=False):
    """Train `model` for a single epoch.

    :param model: tf.keras.Model inherited data type
        model being trained
    :param generator: BalancedDataGenerator
        batch provider indexed by integer (generator[i] returns a batch
        of inputs and labels)
    :param verbose: boolean
        if True, print metrics on every 4th training batch
    :return: list
        per-batch training losses
    """
    BATCH_SZ = model.batch_size  # kept for parity; not used below
    steps = generator.steps_per_epoch
    loss_list = []
    for step in range(steps):
        images, labels = generator[step]
        with tf.GradientTape() as tape:
            logits = model(images)
            loss = model.loss_function(labels, logits)

        if verbose and step % 4 == 0:
            sensitivity_val = sensitivity(labels, logits)
            specificity_val = specificity(labels, logits)
            precision_val = precision(labels, logits)
            train_dice = dice_coef(labels, logits)
            print("Scores on training batch after {} training steps".format(step))
            print("Sensitivity1: {}, Specificity: {}".format(
                sensitivity_val, specificity_val))
            print("Precision: {}, DICE: {}\n".format(precision_val,
                                                     train_dice))

        loss_list.append(loss)
        gradients = tape.gradient(loss, model.trainable_variables)
        model.optimizer.apply_gradients(
            zip(gradients, model.trainable_variables))

    return loss_list
Example #15
0
def test_per_class(config, test_loader, model):
    """Evaluate iou/dice separately per output class.

    Runs on the CPU (the .cuda() calls were disabled in the original).
    Returns a list with one OrderedDict of averages per class.
    """
    num_classes = config['num_classes']
    avg_meters = [{'iou': AverageMeter(), 'dice': AverageMeter()}
                  for _ in range(num_classes)]

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for images, masks, _ in test_loader:
            # Use the last head under deep supervision, else the output.
            if config['deep_supervision']:
                output = model(images)[-1]
            else:
                output = model(images)

            batch_size = images.size(0)
            for class_id in range(output.shape[1]):
                # Keep a singleton channel dim when scoring one class.
                pred_c = torch.unsqueeze(output[:, class_id, :, :], 1)
                mask_c = torch.unsqueeze(masks[:, class_id, :, :], 1)

                avg_meters[class_id]['iou'].update(
                    iou_score(pred_c, mask_c), batch_size)
                avg_meters[class_id]['dice'].update(
                    dice_coef(pred_c, mask_c), batch_size)

    return [OrderedDict([('iou', avg_meters[c]['iou'].avg),
                         ('dice', avg_meters[c]['dice'].avg)])
            for c in range(num_classes)]
Example #16
0
def train(train_loader, model, criterion, optimizer):
    """Run one training epoch; returns averaged loss/iou/dice."""
    avg_meters = {name: AverageMeter() for name in ('loss', 'iou', 'dice')}

    model.train()

    pbar = tqdm(total=len(train_loader))
    for images, masks in train_loader:
        images = images.cuda()
        masks = masks.cuda()

        output = model(images)
        loss = criterion(output, masks)
        iou = iou_score(output, masks)
        dice = dice_coef(output, masks)

        # compute gradient and do optimizing step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        avg_meters['loss'].update(loss.item(), batch_size)
        avg_meters['iou'].update(iou, batch_size)
        avg_meters['dice'].update(dice, batch_size)

        pbar.set_postfix(OrderedDict([('loss', avg_meters['loss'].avg),
                                      ('iou', avg_meters['iou'].avg),
                                      ('dice', avg_meters['dice'].avg)]))
        pbar.update(1)
    pbar.close()

    return OrderedDict([(name, avg_meters[name].avg)
                        for name in ('loss', 'iou', 'dice')])
Example #17
0
def validate(args, val_loader, model, criterion):
    """Run one validation pass and return averaged loss/iou/dice."""
    losses = AverageMeter()
    ious = AverageMeter()
    dices = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for i, (input, target) in tqdm(enumerate(val_loader),
                                       total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()

            # compute output
            if args.deepsupervision:
                outputs = model(input)
                loss = 0
                for output in outputs:
                    loss += criterion(output, target)
                loss /= len(outputs)
                iou = iou_score(outputs[-1], target)
                # BUG FIX: `dice` was never assigned on this branch, so
                # dices.update() below raised NameError on the first batch
                # (or reused a stale value) with args.deepsupervision set.
                dice = dice_coef(outputs[-1], target)
            else:
                output = model(input)
                loss = criterion(output, target)
                iou = iou_score(output, target)
                dice = dice_coef(output, target)

            losses.update(loss.item(), input.size(0))
            ious.update(iou, input.size(0))
            dices.update(dice, input.size(0))

    log = OrderedDict([('loss', losses.avg), ('iou', ious.avg),
                       ('dice', dices.avg)])

    return log
Example #18
0
def main():
    """End-to-end evaluation driver for a BraTS-style 2D segmentation model.

    Loads the trained model named by --name, then depending on --mode:
      * "GetPicture": runs inference over the test set and saves the
        predictions (and ground-truth masks) as RGB label images.
      * "Calculate": reads those saved images back and reports Dice, PPV,
        sensitivity and Hausdorff distance for the WT/TC/ET regions.
    """
    val_args = parse_args()

    args = joblib.load('models/%s/args.pkl' % val_args.name)

    if not os.path.exists('output/%s' % args.name):
        os.makedirs('output/%s' % args.name)

    print('Config -----')
    for arg in vars(args):
        print('%s: %s' % (arg, getattr(args, arg)))
    print('------------')

    joblib.dump(args, 'models/%s/args.pkl' % args.name)

    # create model
    print("=> creating model %s" % args.arch)
    model = mymodel.__dict__[args.arch](args)

    model = model.cuda()

    # Data loading code
    img_paths = glob(
        r'D:\Project\CollegeDesign\dataset\Brats2018FoulModel2D\testImage\*')
    mask_paths = glob(
        r'D:\Project\CollegeDesign\dataset\Brats2018FoulModel2D\testMask\*')

    val_img_paths = img_paths
    val_mask_paths = mask_paths

    #train_img_paths, val_img_paths, train_mask_paths, val_mask_paths = \
    #   train_test_split(img_paths, mask_paths, test_size=0.2, random_state=41)

    model.load_state_dict(torch.load('models/%s/model.pth' % args.name))
    model.eval()

    val_dataset = Dataset(args, val_img_paths, val_mask_paths)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             drop_last=False)

    if val_args.mode == "GetPicture":
        # Generate and save the label images produced by the model.
        """
        获取并保存模型生成的标签图
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')

            with torch.no_grad():
                for i, (input, target) in tqdm(enumerate(val_loader),
                                               total=len(val_loader)):
                    input = input.cuda()
                    #target = target.cuda()

                    # compute output
                    if args.deepsupervision:
                        output = model(input)[-1]
                    else:
                        output = model(input)
                    #print("img_paths[i]:%s" % img_paths[i])
                    output = torch.sigmoid(output).data.cpu().numpy()
                    img_paths = val_img_paths[args.batch_size *
                                              i:args.batch_size * (i + 1)]
                    #print("output_shape:%s"%str(output.shape))

                    # NOTE(review): this inner loop shadows the outer
                    # enumerate index `i`; harmless only because enumerate
                    # reassigns `i` on the next batch.
                    for i in range(output.shape[0]):
                        """
                        生成灰色圖片
                        wtName = os.path.basename(img_paths[i])
                        overNum = wtName.find(".npy")
                        wtName = wtName[0:overNum]
                        wtName = wtName + "_WT" + ".png"
                        imsave('output/%s/'%args.name + wtName, (output[i,0,:,:]*255).astype('uint8'))
                        tcName = os.path.basename(img_paths[i])
                        overNum = tcName.find(".npy")
                        tcName = tcName[0:overNum]
                        tcName = tcName + "_TC" + ".png"
                        imsave('output/%s/'%args.name + tcName, (output[i,1,:,:]*255).astype('uint8'))
                        etName = os.path.basename(img_paths[i])
                        overNum = etName.find(".npy")
                        etName = etName[0:overNum]
                        etName = etName + "_ET" + ".png"
                        imsave('output/%s/'%args.name + etName, (output[i,2,:,:]*255).astype('uint8'))
                        """
                        npName = os.path.basename(img_paths[i])
                        overNum = npName.find(".npy")
                        rgbName = npName[0:overNum]
                        rgbName = rgbName + ".png"
                        rgbPic = np.zeros([160, 160, 3], dtype=np.uint8)
                        # Threshold each channel at 0.5 and paint: ch0 ->
                        # green, ch1 -> red, ch2 -> yellow (later channels
                        # overwrite earlier ones at the same pixel).
                        for idx in range(output.shape[2]):
                            for idy in range(output.shape[3]):
                                if output[i, 0, idx, idy] > 0.5:
                                    rgbPic[idx, idy, 0] = 0
                                    rgbPic[idx, idy, 1] = 128
                                    rgbPic[idx, idy, 2] = 0
                                if output[i, 1, idx, idy] > 0.5:
                                    rgbPic[idx, idy, 0] = 255
                                    rgbPic[idx, idy, 1] = 0
                                    rgbPic[idx, idy, 2] = 0
                                if output[i, 2, idx, idy] > 0.5:
                                    rgbPic[idx, idy, 0] = 255
                                    rgbPic[idx, idy, 1] = 255
                                    rgbPic[idx, idy, 2] = 0
                        imsave('output/%s/' % args.name + rgbName, rgbPic)

            torch.cuda.empty_cache()
        # Convert the validation-set GT .npy masks to images and save them.
        """
        将验证集中的GT numpy格式转换成图片格式并保存
        """
        print("Saving GT,numpy to picture")
        val_gt_path = 'output/%s/' % args.name + "GT/"
        if not os.path.exists(val_gt_path):
            os.mkdir(val_gt_path)
        for idx in tqdm(range(len(val_mask_paths))):
            mask_path = val_mask_paths[idx]
            name = os.path.basename(mask_path)
            overNum = name.find(".npy")
            name = name[0:overNum]
            rgbName = name + ".png"

            npmask = np.load(mask_path)

            GtColor = np.zeros([npmask.shape[0], npmask.shape[1], 3],
                               dtype=np.uint8)
            # NOTE(review): the inner loop shadows the outer `idx`; harmless
            # only because the outer loop reassigns `idx` each iteration.
            for idx in range(npmask.shape[0]):
                for idy in range(npmask.shape[1]):
                    # necrotic / non-enhancing tumor (NET, label 1) -> red
                    if npmask[idx, idy] == 1:
                        GtColor[idx, idy, 0] = 255
                        GtColor[idx, idy, 1] = 0
                        GtColor[idx, idy, 2] = 0
                    # peritumoral edema (ED, label 2) -> green
                    elif npmask[idx, idy] == 2:
                        GtColor[idx, idy, 0] = 0
                        GtColor[idx, idy, 1] = 128
                        GtColor[idx, idy, 2] = 0
                    # enhancing tumor (ET, label 4) -> yellow
                    elif npmask[idx, idy] == 4:
                        GtColor[idx, idy, 0] = 255
                        GtColor[idx, idy, 1] = 255
                        GtColor[idx, idy, 2] = 0

            #imsave(val_gt_path + rgbName, GtColor)
            imageio.imwrite(val_gt_path + rgbName, GtColor)
            """
            mask_path = val_mask_paths[idx]
            name = os.path.basename(mask_path)
            overNum = name.find(".npy")
            name = name[0:overNum]
            wtName = name + "_WT" + ".png"
            tcName = name + "_TC" + ".png"
            etName = name + "_ET" + ".png"

            npmask = np.load(mask_path)

            WT_Label = npmask.copy()
            WT_Label[npmask == 1] = 1.
            WT_Label[npmask == 2] = 1.
            WT_Label[npmask == 4] = 1.
            TC_Label = npmask.copy()
            TC_Label[npmask == 1] = 1.
            TC_Label[npmask == 2] = 0.
            TC_Label[npmask == 4] = 1.
            ET_Label = npmask.copy()
            ET_Label[npmask == 1] = 0.
            ET_Label[npmask == 2] = 0.
            ET_Label[npmask == 4] = 1.

            imsave(val_gt_path + wtName, (WT_Label * 255).astype('uint8'))
            imsave(val_gt_path + tcName, (TC_Label * 255).astype('uint8'))
            imsave(val_gt_path + etName, (ET_Label * 255).astype('uint8'))
            """
        print("Done!")

    if val_args.mode == "Calculate":
        # Compute the evaluation metrics: Dice, sensitivity, PPV, Hausdorff.
        """
        计算各种指标:Dice、Sensitivity、PPV
        """
        wt_dices = []
        tc_dices = []
        et_dices = []
        wt_sensitivities = []
        tc_sensitivities = []
        et_sensitivities = []
        wt_ppvs = []
        tc_ppvs = []
        et_ppvs = []
        wt_Hausdorf = []
        tc_Hausdorf = []
        et_Hausdorf = []

        wtMaskList = []
        tcMaskList = []
        etMaskList = []
        wtPbList = []
        tcPbList = []
        etPbList = []

        # NOTE(review): the "GT\*.png" pattern mixes Windows-style separators,
        # and pairing masks with predictions below relies on both globs
        # returning files in the same order — confirm on other platforms.
        maskPath = glob("output/%s/" % args.name + "GT\*.png")
        pbPath = glob("output/%s/" % args.name + "*.png")
        if len(maskPath) == 0:
            # "Please generate the pictures first!"
            print("请先生成图片!")
            return

        for myi in tqdm(range(len(maskPath))):
            mask = imread(maskPath[myi])
            pb = imread(pbPath[myi])

            wtmaskregion = np.zeros([mask.shape[0], mask.shape[1]],
                                    dtype=np.float32)
            wtpbregion = np.zeros([mask.shape[0], mask.shape[1]],
                                  dtype=np.float32)

            tcmaskregion = np.zeros([mask.shape[0], mask.shape[1]],
                                    dtype=np.float32)
            tcpbregion = np.zeros([mask.shape[0], mask.shape[1]],
                                  dtype=np.float32)

            etmaskregion = np.zeros([mask.shape[0], mask.shape[1]],
                                    dtype=np.float32)
            etpbregion = np.zeros([mask.shape[0], mask.shape[1]],
                                  dtype=np.float32)

            for idx in range(mask.shape[0]):
                for idy in range(mask.shape[1]):
                    # Any non-zero channel means the pixel is labeled, i.e.
                    # it belongs to the whole-tumor (WT) region.
                    if mask[idx, idy, :].any() != 0:
                        wtmaskregion[idx, idy] = 1
                    if pb[idx, idy, :].any() != 0:
                        wtpbregion[idx, idy] = 1
                    # First channel == 255 marks the TC region: red and
                    # yellow both have 255 in channel 0, unlike green.
                    if mask[idx, idy, 0] == 255:
                        tcmaskregion[idx, idy] = 1
                    if pb[idx, idy, 0] == 255:
                        tcpbregion[idx, idy] = 1
                    # Second channel == 128 (green) is treated as the ET
                    # region here — NOTE(review): green was drawn for edema
                    # above; confirm the intended WT/TC/ET channel mapping.
                    if mask[idx, idy, 1] == 128:
                        etmaskregion[idx, idy] = 1
                    if pb[idx, idy, 1] == 128:
                        etpbregion[idx, idy] = 1
            # compute WT metrics
            dice = dice_coef(wtpbregion, wtmaskregion)
            wt_dices.append(dice)
            ppv_n = ppv(wtpbregion, wtmaskregion)
            wt_ppvs.append(ppv_n)
            Hausdorff = hausdorff_distance(wtmaskregion, wtpbregion)
            wt_Hausdorf.append(Hausdorff)
            sensitivity_n = sensitivity(wtpbregion, wtmaskregion)
            wt_sensitivities.append(sensitivity_n)
            # compute TC metrics
            dice = dice_coef(tcpbregion, tcmaskregion)
            tc_dices.append(dice)
            ppv_n = ppv(tcpbregion, tcmaskregion)
            tc_ppvs.append(ppv_n)
            Hausdorff = hausdorff_distance(tcmaskregion, tcpbregion)
            tc_Hausdorf.append(Hausdorff)
            sensitivity_n = sensitivity(tcpbregion, tcmaskregion)
            tc_sensitivities.append(sensitivity_n)
            # compute ET metrics
            dice = dice_coef(etpbregion, etmaskregion)
            et_dices.append(dice)
            ppv_n = ppv(etpbregion, etmaskregion)
            et_ppvs.append(ppv_n)
            Hausdorff = hausdorff_distance(etmaskregion, etpbregion)
            et_Hausdorf.append(Hausdorff)
            sensitivity_n = sensitivity(etpbregion, etmaskregion)
            et_sensitivities.append(sensitivity_n)

        print('WT Dice: %.4f' % np.mean(wt_dices))
        print('TC Dice: %.4f' % np.mean(tc_dices))
        print('ET Dice: %.4f' % np.mean(et_dices))
        print("=============")
        print('WT PPV: %.4f' % np.mean(wt_ppvs))
        print('TC PPV: %.4f' % np.mean(tc_ppvs))
        print('ET PPV: %.4f' % np.mean(et_ppvs))
        print("=============")
        print('WT sensitivity: %.4f' % np.mean(wt_sensitivities))
        print('TC sensitivity: %.4f' % np.mean(tc_sensitivities))
        print('ET sensitivity: %.4f' % np.mean(et_sensitivities))
        print("=============")
        print('WT Hausdorff: %.4f' % np.mean(wt_Hausdorf))
        print('TC Hausdorff: %.4f' % np.mean(tc_Hausdorf))
        print('ET Hausdorff: %.4f' % np.mean(et_Hausdorf))
        print("=============")
Example #19
0
def main():
    """Run SN7 validation for a segmentation model.

    Loads paths and run settings from ../configs/config_SN7.json, restores
    either a GAN generator or a plain architecture checkpoint, evaluates it
    on the test split, appends per-image [img_id, IoU, Dice] rows to a CSV,
    and calls save_GT_RE_mask for each foreground class of each image.
    """
    args = parse_args()  # parsed for CLI side effects; value unused below
    config_file = "../configs/config_SN7.json"
    # Context manager closes the config file instead of leaking the handle.
    with open(config_file, 'rt') as cf:
        config_dict = json.loads(cf.read())

    file_dict = config_dict['file_path']
    val_config = config_dict['val_config']

    name = val_config['name']
    input_folder = file_dict['input_path']    # '../inputs'
    model_folder = file_dict['model_path']    # '../models'
    output_folder = file_dict['output_path']  # '../models'

    # Hard-coded switch: True evaluates the GAN generator, False a plain
    # architecture from `archs`.
    ss_unet_GAN = True
    # create model
    if not ss_unet_GAN:
        with open(os.path.join(model_folder, '%s/config.yml' % name), 'r') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)

        config['name'] = name
        print('-' * 20)
        for key in config.keys():
            print('%s: %s' % (key, str(config[key])))
        print('-' * 20)
        cudnn.benchmark = True
        print("=> creating model %s" % config['arch'])
        model = archs.__dict__[config['arch']](config['num_classes'],
                                               config['input_channels'],
                                               config['deep_supervision'])
        model = model.cuda()

        model_dict = torch.load(os.path.join(model_folder, '%s/model.pth' % config['name']))
        # Checkpoints may be wrapped in a 'state_dict' key and/or carry a
        # DataParallel 'module.' prefix; normalize both before loading.
        if "state_dict" in model_dict.keys():
            model_dict = remove_prefix(model_dict['state_dict'], 'module.')
        else:
            model_dict = remove_prefix(model_dict, 'module.')
        model.load_state_dict(model_dict, strict=False)
        model.eval()
    else:
        val_config = config_dict['val_config']
        generator_name = val_config['name']
        with open(os.path.join(model_folder, '%s/config.yml' % generator_name), 'r') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        generator = Generator(config)
        generator = generator.cuda()
        config['name'] = name
        model_dict = torch.load(os.path.join(model_folder, '%s/model.pth' % config['name']))
        # Same checkpoint normalization as the plain-model branch.
        if "state_dict" in model_dict.keys():
            model_dict = remove_prefix(model_dict['state_dict'], 'module.')
        else:
            model_dict = remove_prefix(model_dict, 'module.')
        generator.load_state_dict(model_dict, strict=False)
        generator.eval()

    # Data loading code
    img_ids = glob(os.path.join(input_folder, config['val_dataset'], 'images', 'test', '*' + config['img_ext']))
    val_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    # ImageNet normalization statistics.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    val_transform = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(mean=mean, std=std),
    ])

    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=os.path.join(input_folder, config['val_dataset'], 'images', 'test'),
        mask_dir=os.path.join(input_folder, config['val_dataset'], 'annotations', 'test'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        input_channels=config['input_channels'],
        transform=val_transform)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,  # evaluate one image at a time
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    avg_meters = {'iou': AverageMeter(),
                  'dice': AverageMeter()}

    num_classes = config['num_classes']
    for c in range(config['num_classes']):
        os.makedirs(os.path.join(output_folder, config['name'], str(c)), exist_ok=True)

    csv_save_name = os.path.join(output_folder, config['name'] + '_result' + '.csv')
    result_submission = []
    with torch.no_grad():
        pbar = tqdm(total=len(val_loader))
        for ori_img, input, target, targets, meta in val_loader:
            input = input.cuda()
            target = target.cuda()

            # compute output; deep supervision returns a list, keep last head
            if ss_unet_GAN:
                if config['deep_supervision']:
                    output = generator(input)[-1]
                else:
                    output = generator(input)
            else:
                if config['deep_supervision']:
                    output = model(input)[-1]
                else:
                    output = model(input)
            # Score only channels 1..num_classes-1 (channel 0 presumably
            # background — matches the idx_c > 0 save filter below).
            out_m = output[:, 1:num_classes, :, :].clone()
            tar_m = target[:, 1:num_classes, :, :].clone()
            iou = iou_score(out_m, tar_m)
            dice = dice_coef(out_m, tar_m)
            result_submission.append([meta['img_id'][0], iou, dice])

            avg_meters['iou'].update(iou, input.size(0))
            avg_meters['dice'].update(dice, input.size(0))
            output = torch.sigmoid(output).cpu().numpy()
            masks = target.cpu()
            for i in range(len(output)):
                for idx_c in range(num_classes):
                    tmp_mask = np.array(masks[i][idx_c])
                    mask = np.array(255 * tmp_mask).astype('uint8')
                    mask_out = np.array(255 * output[i][idx_c]).astype('uint8')
                    # Binarize the sigmoid output at 0.5 (127 of 255).
                    mask_output = np.zeros((mask_out.shape[0], mask_out.shape[1]))
                    mask_output = mask_output.astype('uint8')
                    mask_ = mask_out > 127
                    mask_output[mask_] = 255

                    if idx_c > 0:
                        save_GT_RE_mask(output_folder, config, meta, idx_c, i, ori_img, mask, mask_output)

            postfix = OrderedDict([
                ('iou', avg_meters['iou'].avg),
                ('dice', avg_meters['dice'].avg),
            ])
            pbar.set_postfix(postfix)
            pbar.update(1)
        pbar.close()

    result_save_to_csv_filename(csv_save_name, result_submission)
    print('IoU: %.4f' % avg_meters['iou'].avg)
    print('dice: %.4f' % avg_meters['dice'].avg)

    torch.cuda.empty_cache()
Example #20
0
def train(epoch, config, train_loader, generator, discriminator, criterion,
          adversarial_loss_criterion, content_loss_criterion, optimizer_g,
          optimizer_d):
    """Run one adversarial training epoch of the segmentation GAN.

    Per batch: (1) update the generator on segmentation loss +
    alpa * content loss + beta * adversarial loss, then (2) update the
    discriminator on real-vs-generated binary cross-entropy.

    Returns:
        OrderedDict with the epoch-average 'loss', 'iou' and 'dice'.
    """
    avg_meters = {
        'loss': AverageMeter(),
        'iou': AverageMeter(),
        'dice': AverageMeter()
    }
    # Loss-mixing weights and gradient-clip threshold.
    # NOTE(review): 'alpa' looks like a typo for 'alpha'; name kept as-is.
    alpa = 1e-4
    beta = 1e-3
    grad_clip = 0.8
    generator.train()
    discriminator.train()
    lr_val = optimizer_g.param_groups[0]['lr']
    print('generator learning rate {:d}: {:f}'.format(epoch, lr_val))
    pbar = tqdm(total=len(train_loader))
    num_class = int(config['num_classes'])

    for ori_img, input, target, targets, _ in train_loader:
        input = input.cuda()
        target = target.cuda()

        # ---- Generator forward ----
        generator_output = generator(input)
        # Zero out NaNs so the losses below stay finite.
        generator_output[torch.isnan(generator_output)] = 0
        # Metrics use only channels 1..num_class-1 (channel 0 presumably
        # background — TODO confirm against the dataset layout).
        out_m = generator_output[:, 1:num_class, :, :].clone()
        tar_m = target[:, 1:num_class, :, :].clone()

        loss = criterion(generator_output, target)
        content_loss = content_loss_criterion(generator_output, target)
        iou = iou_score(out_m, tar_m)
        dice = dice_coef(out_m, tar_m)

        # Generator's adversarial term: push the discriminator's score on
        # generated masks toward "real" (all ones).
        seg_discriminated = discriminator(generator_output)  # (N)

        adversarial_loss = adversarial_loss_criterion(
            seg_discriminated, torch.ones_like(seg_discriminated))
        perceptual_loss = loss + alpa * content_loss + beta * adversarial_loss
        # ---- Generator backward / step ----
        optimizer_g.zero_grad()
        perceptual_loss.backward()

        # Clip generator gradients, if configured.
        if grad_clip is not None:
            clip_gradient(optimizer_g, grad_clip)

        # Update generator
        optimizer_g.step()

        # ---- Discriminator step ----
        # detach() keeps discriminator gradients from flowing back into
        # the just-updated generator.
        hr_discriminated = discriminator(target)
        sr_discriminated = discriminator(generator_output.detach())

        # Binary cross-entropy: generated masks -> 0, ground truth -> 1.
        adversarial_loss = adversarial_loss_criterion(sr_discriminated, torch.zeros_like(sr_discriminated)) + \
                           adversarial_loss_criterion(hr_discriminated, torch.ones_like(hr_discriminated))

        # Back-prop discriminator.
        optimizer_d.zero_grad()
        adversarial_loss.backward()

        # Clip discriminator gradients, if configured.
        if grad_clip is not None:
            clip_gradient(optimizer_d, grad_clip)

        # Update discriminator
        optimizer_d.step()

        # Accumulate epoch averages weighted by batch size.
        avg_meters['loss'].update(loss.item(), input.size(0))
        avg_meters['iou'].update(iou, input.size(0))
        avg_meters['dice'].update(dice, input.size(0))

        postfix = OrderedDict([
            ('loss', avg_meters['loss'].avg),
            ('iou', avg_meters['iou'].avg),
            ('dice', avg_meters['dice'].avg),
        ])
        pbar.set_postfix(postfix)
        pbar.update(1)
    pbar.close()

    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg),
                        ('dice', avg_meters['dice'].avg)])
Example #21
0
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n


if __name__ == '__main__':
    # Smoke test: compare dice_coef against a hand-computed soft Dice.
    # The original computed both values but discarded them, producing no
    # output at all.
    predict = torch.tensor(
        ((0.01, 0.03, 0.02, 0.02), (0.05, 0.12, 0.09, 0.07),
         (0.89, 0.85, 0.88, 0.91), (0.99, 0.97, 0.95, 0.97)),
        dtype=torch.float)

    target = torch.tensor(
        ((0, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1), (1, 1, 1, 1)),
        dtype=torch.float)

    # Reference soft Dice: 2*sum(p*t) / (sum(p) + sum(t)).
    inter = torch.sum(predict * target)
    dice = 2 * inter / (torch.sum(predict) + torch.sum(target))
    print('reference dice: %.6f' % dice.item())

    # Batch dimension added to match the original unsqueeze(0) call —
    # presumably dice_coef expects batched input; confirm in metrics.
    print('dice_coef result:', dice_coef(predict.unsqueeze(0), target.unsqueeze(0)))
Example #22
0
# Average the accumulated predictions over the model ensemble.
YP = np.divide(YP, len(modelName))

# YP = unshuffle(YP,tidx)

# dc collects per-slice Dice lists (4-D volumes only); dc3D collects one
# whole-volume Dice per case.
dc = []
dc3D = []
for i in range(YP.shape[0]):
	yp = YP[i,...]   # prediction for case i
	ygt = Y[i,...]   # ground truth for case i

	# If the case has a slice axis (>3 dims), also score each slice.
	if len(yp.shape) > 3:
		dcSlice=[]
		for j in range(yp.shape[2]):
			cenp = yp[:,:,j,:]
			gtp = ygt[:,:,j,:]
			dcSlice.append(m.dice_coef(gtp,cenp,threshold=segThreshold))

		dc.append(dcSlice)

	dc3D.append(m.dice_coef(ygt, yp,threshold=segThreshold))

# NOTE(review): this check reuses `yp` from the LAST loop iteration to
# decide whether per-slice results exist — verify all cases share a rank.
if len(yp.shape) > 3:
	dc = np.array(dc)
	print("Mean DC : \n")
	print(np.mean(dc,axis=0))

dc3D = np.array(dc3D)
print("\nMean 3D DC : \n")
print(np.mean(dc3D))

'''
Example #23
0
def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus the Dice coefficient of prediction vs truth."""
    coefficient = dice_coef(y_true, y_pred)
    return 1 - coefficient
Example #24
0
def main():
    """Evaluate a UNet3+ checkpoint on the test set and print mean Dice.

    Loads the saved args and weights from models/<name>/, runs inference
    without gradients over the test images, and reports the mean tumor
    Dice score.
    """
    val_args = parse_args()

    args = joblib.load('models/%s/args.pkl' % val_args.name)
    # makedirs(exist_ok=True) avoids the exists-check/mkdir race.
    os.makedirs('output/%s' % args.name, exist_ok=True)
    print('Config -----')
    for arg in vars(args):
        print('%s: %s' % (arg, getattr(args, arg)))
    print('------------')
    joblib.dump(args, 'models/%s/args.pkl' % args.name)

    # create model
    print("=> creating model %s" % args.arch)
    model = UNet_3Plus.__dict__[args.arch](args)
    model = model.cuda()

    # Data loading code (hard-coded Windows test paths).
    img_paths = glob(r'E:\Code\GZ-UNet3+\testImage\*')
    mask_paths = glob(r'E:\Code\GZ-UNet3+\testMask\*')

    val_img_paths = img_paths
    val_mask_paths = mask_paths

    model.load_state_dict(torch.load('models/%s/model.pth' % args.name))
    model.eval()

    val_dataset = Dataset(args, val_img_paths, val_mask_paths)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             drop_last=False)

    savedir = 'output/%s/' % args.name
    os.makedirs(savedir, exist_ok=True)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')

        tumor_dice = []

        with torch.no_grad():
            # Unused `startFlag` and the enumerate index were dropped;
            # only the (input, target) pairs are needed.
            for input, target in tqdm(val_loader, total=len(val_loader)):
                input = input.cuda()
                target = target.cuda()
                output = model(input)
                dice = dice_coef(output, target)
                tumor_dice.append(dice)

        torch.cuda.empty_cache()

        print("=============")
        print('Tumor Dice: %.4f' % np.mean(tumor_dice))
        print("=============")
    '''
Example #25
0
def perform_validation(modelName, testNum, fileName):
    """Validate model `modelName` on the ax_crop_val_<testNum>_<testNum+1>
    dataset.

    Logs the config and the final IoU/Dice both to stdout and to
    batch_results_val/<fileName>, and saves each class's predicted
    probability map as an 8-bit JPEG under outputs/<name>/<class>/.
    """
    # Context manager guarantees the results file is flushed and closed
    # (the original opened it and never closed it).
    with open('batch_results_val/' + fileName, 'w') as fw:
        with open('models/%s/config.yml' % modelName, 'r') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)

        config['dataset'] = 'ax_crop_val_' + str(testNum) + '_' + str(testNum + 1)

        print('-' * 20)
        fw.write('-' * 20 + '\n')
        for key in config.keys():
            print('%s: %s' % (key, str(config[key])))
            fw.write('%s: %s' % (key, str(config[key])) + '\n')
        print('-' * 20)
        fw.write('-' * 20 + '\n')

        cudnn.benchmark = True

        # create model
        print("=> creating model %s" % config['arch'])
        fw.write("=> creating model %s" % config['arch'] + '\n')
        model = archs.__dict__[config['arch']](config['num_classes'],
                                               config['input_channels'],
                                               config['deep_supervision'])

        model = model.cuda()

        # Data loading code
        img_ids = glob(
            os.path.join('inputs', config['dataset'], 'images',
                         '*' + config['img_ext']))
        img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

        # Keep 99% of the images as the validation split; fixed seed for
        # reproducibility.
        _, val_img_ids = train_test_split(img_ids, test_size=0.99, random_state=41)

        model.load_state_dict(torch.load('models/%s/model.pth' % config['name']))
        model.eval()

        val_transform = Compose([
            transforms.Resize(config['input_h'], config['input_w']),
            transforms.Normalize(),
        ])

        val_dataset = Dataset(img_ids=val_img_ids,
                              img_dir=os.path.join('inputs', config['dataset'],
                                                   'images'),
                              mask_dir=os.path.join('inputs', config['dataset'],
                                                    'masks'),
                              img_ext=config['img_ext'],
                              mask_ext=config['mask_ext'],
                              num_classes=config['num_classes'],
                              transform=val_transform)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=config['batch_size'],
                                                 shuffle=False,
                                                 num_workers=config['num_workers'],
                                                 drop_last=False)

        avg_meter = AverageMeter()       # running IoU
        dice_avg_meter = AverageMeter()  # running Dice

        for c in range(config['num_classes']):
            os.makedirs(os.path.join('outputs', config['name'], str(c)),
                        exist_ok=True)
        with torch.no_grad():
            for input, target, meta in tqdm(val_loader, total=len(val_loader)):
                input = input.cuda()
                target = target.cuda()

                # compute output; deep supervision yields a list, use last head
                if config['deep_supervision']:
                    output = model(input)[-1]
                else:
                    output = model(input)

                iou = iou_score(output, target)
                avg_meter.update(iou, input.size(0))

                dice = dice_coef(output, target)
                dice_avg_meter.update(dice, input.size(0))

                output = torch.sigmoid(output).cpu().numpy()

                # Save every class's probability map as an 8-bit image.
                for i in range(len(output)):
                    for c in range(config['num_classes']):
                        cv2.imwrite(
                            os.path.join('outputs', config['name'], str(c),
                                         meta['img_id'][i] + '.jpg'),
                            (output[i, c] * 255).astype('uint8'))

        print('IoU: %.4f' % avg_meter.avg)
        # Newlines added so the two results don't run together in the file.
        fw.write('IoU: %.4f' % avg_meter.avg + '\n')
        print('Dice: %.4f' % dice_avg_meter.avg)
        fw.write('Dice: %.4f' % dice_avg_meter.avg + '\n')

    torch.cuda.empty_cache()
Example #26
0
def unet_loss(y_true, y_pred):
    """Combined loss: half binary cross-entropy minus the Dice coefficient."""
    from keras.losses import binary_crossentropy
    from metrics import dice_coef
    bce = binary_crossentropy(y_true, y_pred)
    dice = dice_coef(y_true, y_pred)
    return .5 * bce - dice
Example #27
0
 def forward(self, inputs, target):
     """Dice loss: one minus the smoothed Dice coefficient."""
     coef = dice_coef(inputs, target, smooth=self.smooth)
     return 1.0 - coef
Example #28
0
 def forward(self, inputs, target):
     """Negative-log Dice loss: -log(dice_coef(inputs, target))."""
     coef = dice_coef(inputs, target, self.smooth)
     return -torch.log(coef)
Example #29
0
    transform_path0 = os.path.join(result_path, 'TransformParameters.0.txt')
    transform_path1 = os.path.join(result_path, 'TransformParameters.1.txt')
    final_transform_path = os.path.join(result_path, 'transform_pathfinal.txt')

    # Change FinalBSplineInterpolationOrder to 0 for binary mask transformation
    TransformParameterFileEditor(transform_path1, transform_path0, final_transform_path).modify_transform_parameter_file()

    # Make a new transformix object tr with the CORRECT PATH to transformix
    tr = elastix.TransformixInterface(parameters=final_transform_path,
                                      transformix_path=TRANSFORMIX_PATH)

    transformed_pr_path = tr.transform_image(pr_image_path, output_dir=result_path)
    image_array_tpr = sitk.GetArrayFromImage(sitk.ReadImage(transformed_pr_path))

    log_path = os.path.join(result_path, 'IterationInfo.1.R3.txt')
    log = elastix.logfile(log_path)

    DSC.append(dice_coef(image_array_opr, image_array_tpr))
    SNS.append(sensitivity(image_array_opr, image_array_tpr))
    SPC.append(specificity(image_array_opr, image_array_tpr))
    finalMI.append(statistics.mean(log['metric'][-50:-1]))

# Scatter-plot each similarity metric against the final mutual-information
# values collected from the registration logs, side by side.
fig, (ax1,ax2,ax3) = plt.subplots(1, 3, figsize=(15, 5))
ax1.scatter(finalMI,DSC)
ax1.set_title("DSC")  # Dice similarity coefficient
ax2.scatter(finalMI,SNS)
ax2.set_title("SNS")  # sensitivity
ax3.scatter(finalMI,SPC)
ax3.set_title("SPC")  # specificity
plt.show()
def get_scores(y_true, y_predict):
    """Compute all evaluation metrics for one ground-truth/prediction pair.

    Returns a 6-tuple: (Dice, sensitivity, specificity, mean surface
    distance, mutual information, RMSE).
    """
    scores = (
        dice_coef(y_true, y_predict),
        sensitivity(y_true, y_predict),
        specificity(y_true, y_predict),
        MeanSurfaceDistance(y_true, y_predict),
        mutual_information(y_true, y_predict),
        rmse(y_true, y_predict),
    )
    return scores