Example #1
def get_description_widget(self):
    widget = self.DESCRIPTION_WIDGET_CLASS()
    for i in xrange(self.SIZE):
        value = tuple(get_random_color())
        self.sequence.append(value)
        widget.sequence_widget.add_widget(ColorTile(color=value))
    return widget
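
None of these listings include the helper itself. For the Kivy-based examples, which call tuple(get_random_color()) and later pass the value to get_hex_from_color or ColorTile, a minimal stand-in returning a Kivy-style color (a list of floats in [0, 1]) could look like the sketch below; the real project helpers are not shown here, so treat this as an assumption.

import random

def get_random_color(alpha=1.0):
    # Kivy-style color: [r, g, b, a], each channel a float in [0, 1].
    return [random.random(), random.random(), random.random(), alpha]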
Example #2
def MatrixElement(i, j, windows, elem=None):
    text = elem if elem and elem != '0' else 0
    return sg.Button(text,
                     size=(1, 1),
                     key=(i, j, windows),
                     pad=(1, 1),
                     button_color=('#F0F0F0', get_random_color(elem)))
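
Example #2 (and Example #5 below) call get_random_color(elem) and use the result as a PySimpleGUI button background, so here the helper apparently maps a cell value to a '#RRGGBB' string. A hedged stand-in is sketched below; the argument handling and the neutral grey for empty cells are assumptions based on how the buttons are built, not the projects' actual implementation.

import random

def get_random_color(elem=None):
    # Hypothetical stand-in: empty or zero cells keep the neutral background;
    # any other value gets a color seeded by the value itself, so equal
    # values are rendered with the same color.
    if not elem or elem == '0':
        return '#F0F0F0'
    rng = random.Random(str(elem))
    return '#{:02X}{:02X}{:02X}'.format(rng.randint(0, 255),
                                        rng.randint(0, 255),
                                        rng.randint(0, 255))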
Example #3
    def generate_alphabet(self):
        from kivy.utils import get_hex_from_color
        from utils import get_random_color

        alphabet = []
        for i in xrange(self.SIZE):
            value = tuple(get_random_color())
            alphabet.append(get_hex_from_color(value))
        return alphabet
Example #4
def __init__(self,
             xy_data,
             machine_id,
             title,
             label,
             line_color=None,
             point_color=None,
             fill=False,
             width=None,
             height=None):
    """
    :param xy_data: Function or LambdaFunction object that
        returns a tuple of two elements.
        First element -- an iterable that contains the X axis data.
        Second element -- an iterable that contains the Y axis data.
        Both elements must be the same length.
    :param machine_id: ID of the CNC machine
    :param title: Chart title
    :param label: Label of the Y axis
    :param line_color: Color of the line
    :param point_color: Color of the points on the line
    :param fill: if True, the area under the line will be filled;
        if False, it will not.
    :param width: width of the HTML canvas where the chart is rendered
    :param height: height of the HTML canvas where the chart is rendered
    """
    assert callable(xy_data)
    assert (isinstance(xy_data, types.FunctionType)
            or isinstance(xy_data, types.LambdaType))
    self.xy_data = xy_data
    self.machine_id = machine_id
    self.id = uuid.uuid4()
    self.title = title
    self.label = label
    self.line_color = line_color or get_random_color()
    self.point_color = point_color or get_random_color()
    self.axis_id = get_random_chars()
    self.fill = fill
    self.width = width or 60
    self.height = height or 30
Example #5
async def receive_signals(window, proc, key):
    number = 0
    while True:
        data = await proc.stderr.readline()
        message = data.decode('ascii').rstrip()
        if not message:
            break
        matrix = json.loads(message)
        for i, row in enumerate(matrix):
            for j, elem in enumerate(row):
                window[(i, j,
                        key)].update(elem,
                                     button_color=('#F0F0F0',
                                                   get_random_color(elem)))
        number += 1
        # maybe we can remove this; at the moment it is too fast
        await asyncio.sleep(0.1)
    # Wait for the subprocess exit.
    await proc.wait()
Example #6
def generate_alphabet(self):
    alphabet = []
    for i in xrange(self.SIZE**2):
        value = tuple(get_random_color())
        alphabet.append(get_hex_from_color(value))
    return alphabet
Example #7
def main():
    opt = Options(isTrain=False)
    opt.parse()
    opt.save_options()

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in opt.test['gpus'])

    img_dir = opt.test['img_dir']
    label_dir = opt.test['label_dir']
    save_dir = opt.test['save_dir']
    model_path = opt.test['model_path']
    save_flag = opt.test['save_flag']

    # data transforms
    test_transform = get_transforms(opt.transform['test'])

    model = ResUNet34(pretrained=opt.model['pretrained'])
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    cudnn.benchmark = True

    # ----- load trained model ----- #
    print("=> loading trained model")
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded model at epoch {}".format(checkpoint['epoch']))
    model = model.module

    # switch to evaluate mode
    model.eval()
    counter = 0
    print("=> Test begins:")

    img_names = os.listdir(img_dir)

    if save_flag:
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        strs = img_dir.split('/')
        prob_maps_folder = '{:s}/{:s}_prob_maps'.format(save_dir, strs[-1])
        seg_folder = '{:s}/{:s}_segmentation'.format(save_dir, strs[-1])
        if not os.path.exists(prob_maps_folder):
            os.mkdir(prob_maps_folder)
        if not os.path.exists(seg_folder):
            os.mkdir(seg_folder)

    metric_names = ['acc', 'p_F1', 'p_recall', 'p_precision', 'dice', 'aji']
    test_results = dict()
    all_result = utils.AverageMeter(len(metric_names))

    for img_name in img_names:
        # load test image
        print('=> Processing image {:s}'.format(img_name))
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        label_path = '{:s}/{:s}_label.png'.format(label_dir, name)
        gt = misc.imread(label_path)

        input = test_transform((img, ))[0].unsqueeze(0)

        print('\tComputing output probability maps...')
        prob_maps = get_probmaps(input, model, opt)
        pred = np.argmax(prob_maps, axis=0)  # prediction

        pred_labeled = measure.label(pred)
        pred_labeled = morph.remove_small_objects(pred_labeled,
                                                  opt.post['min_area'])
        pred_labeled = ndi_morph.binary_fill_holes(pred_labeled > 0)
        pred_labeled = measure.label(pred_labeled)

        print('\tComputing metrics...')
        metrics = compute_metrics(pred_labeled, gt, metric_names)

        # save result for each image
        test_results[name] = [
            metrics['acc'], metrics['p_F1'], metrics['p_recall'],
            metrics['p_precision'], metrics['dice'], metrics['aji']
        ]

        # update the average result
        all_result.update([
            metrics['acc'], metrics['p_F1'], metrics['p_recall'],
            metrics['p_precision'], metrics['dice'], metrics['aji']
        ])

        # save image
        if save_flag:
            print('\tSaving image results...')
            misc.imsave('{:s}/{:s}_pred.png'.format(prob_maps_folder, name),
                        pred.astype(np.uint8) * 255)
            misc.imsave('{:s}/{:s}_prob.png'.format(prob_maps_folder, name),
                        prob_maps[1, :, :])
            final_pred = Image.fromarray(pred_labeled.astype(np.uint16))
            final_pred.save('{:s}/{:s}_seg.tiff'.format(seg_folder, name))

            # save colored objects
            pred_colored_instance = np.zeros((ori_h, ori_w, 3))
            for k in range(1, pred_labeled.max() + 1):
                pred_colored_instance[pred_labeled == k, :] = np.array(
                    utils.get_random_color())
            filename = '{:s}/{:s}_seg_colored.png'.format(seg_folder, name)
            misc.imsave(filename, pred_colored_instance)

        counter += 1
        if counter % 10 == 0:
            print('\tProcessed {:d} images'.format(counter))

    print('=> Processed all {:d} images'.format(counter))
    print('Average Acc: {r[0]:.4f}\nF1: {r[1]:.4f}\nRecall: {r[2]:.4f}\n'
          'Precision: {r[3]:.4f}\nDice: {r[4]:.4f}\nAJI: {r[5]:.4f}\n'.format(
              r=all_result.avg))

    header = metric_names
    utils.save_results(header, all_result.avg, test_results,
                       '{:s}/test_results.txt'.format(save_dir))
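
Examples #7-#10 also rely on a get_probmaps(input, model, opt) helper that is not listed. Judging from its usage (np.argmax(prob_maps, axis=0), slicing prob_maps[1, :, :]), it returns a (num_classes, H, W) numpy array of class probabilities. A minimal stand-in under that assumption is sketched below; the real helper presumably uses opt, e.g. for patch-wise inference, which is omitted here.

import torch
import torch.nn.functional as F

def get_probmaps(input, model, opt):
    # Assumed behaviour: run the network on a single (1, C, H, W) tensor and
    # return per-class probabilities as a (num_classes, H, W) numpy array.
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(input.to(device))
    return F.softmax(logits, dim=1).squeeze(0).cpu().numpy()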
Example #8
def main():
    opt = Options(isTrain=False)
    opt.parse()
    opt.save_options()
    opt.print_options()

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in opt.test['gpu'])

    img_dir = opt.test['img_dir']
    label_dir = opt.test['label_dir']
    save_dir = opt.test['save_dir']
    model_path = opt.test['model_path']
    save_flag = opt.test['save_flag']
    tta = opt.test['tta']

    # check if it is needed to compute accuracies
    eval_flag = True if label_dir else False

    # data transforms
    test_transform = get_transforms(opt.transform['test'])

    # load model
    model = FullNet(opt.model['in_c'],
                    opt.model['out_c'],
                    n_layers=opt.model['n_layers'],
                    growth_rate=opt.model['growth_rate'],
                    drop_rate=opt.model['drop_rate'],
                    dilations=opt.model['dilations'],
                    is_hybrid=opt.model['is_hybrid'],
                    compress_ratio=opt.model['compress_ratio'],
                    layer_type=opt.model['layer_type'])
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    # ----- load trained model ----- #
    print("=> loading trained model")
    best_checkpoint = torch.load(model_path)
    model.load_state_dict(best_checkpoint['state_dict'])
    print("=> loaded model at epoch {}".format(best_checkpoint['epoch']))
    model = model.module

    # switch to evaluate mode
    model.eval()
    counter = 0
    print("=> Test begins:")

    img_names = os.listdir(img_dir)

    # pixel_accu, recall, precision, F1, dice, iou, haus, (AJI)
    num_metrics = 8 if opt.dataset == 'MultiOrgan' else 7
    avg_results = utils.AverageMeter(num_metrics)
    all_results = dict()

    if save_flag:
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        strs = img_dir.split('/')
        prob_maps_folder = '{:s}/{:s}_prob_maps'.format(save_dir, strs[-1])
        seg_folder = '{:s}/{:s}_segmentation'.format(save_dir, strs[-1])
        if not os.path.exists(prob_maps_folder):
            os.mkdir(prob_maps_folder)
        if not os.path.exists(seg_folder):
            os.mkdir(seg_folder)

    for img_name in img_names:
        # load test image
        print('=> Processing image {:s}'.format(img_name))
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        if eval_flag:
            if opt.dataset == 'MultiOrgan':
                label_path = '{:s}/{:s}.png'.format(label_dir, name)
            else:
                label_path = '{:s}/{:s}_anno.bmp'.format(label_dir, name)
            label_img = misc.imread(label_path)

        input = test_transform((img, ))[0].unsqueeze(0)

        print('\tComputing output probability maps...')
        prob_maps = get_probmaps(input, model, opt)
        if tta:
            img_hf = img.transpose(Image.FLIP_LEFT_RIGHT)  # horizontal flip
            img_vf = img.transpose(Image.FLIP_TOP_BOTTOM)  # vertical flip
            img_hvf = img_hf.transpose(
                Image.FLIP_TOP_BOTTOM)  # horizontal and vertical flips

            input_hf = test_transform(
                (img_hf, ))[0].unsqueeze(0)  # horizontal flip input
            input_vf = test_transform(
                (img_vf, ))[0].unsqueeze(0)  # vertical flip input
            input_hvf = test_transform((img_hvf, ))[0].unsqueeze(
                0)  # horizontal and vertical flip input

            prob_maps_hf = get_probmaps(input_hf, model, opt)
            prob_maps_vf = get_probmaps(input_vf, model, opt)
            prob_maps_hvf = get_probmaps(input_hvf, model, opt)

            # re flip
            prob_maps_hf = np.flip(prob_maps_hf, 2)
            prob_maps_vf = np.flip(prob_maps_vf, 1)
            prob_maps_hvf = np.flip(np.flip(prob_maps_hvf, 1), 2)

            # rotation 90 and flips
            img_r90 = img.rotate(90, expand=True)
            img_r90_hf = img_r90.transpose(
                Image.FLIP_LEFT_RIGHT)  # horizontal flip
            img_r90_vf = img_r90.transpose(
                Image.FLIP_TOP_BOTTOM)  # vertical flip
            img_r90_hvf = img_r90_hf.transpose(
                Image.FLIP_TOP_BOTTOM)  # horizontal and vertical flips

            input_r90 = test_transform((img_r90, ))[0].unsqueeze(0)
            input_r90_hf = test_transform(
                (img_r90_hf, ))[0].unsqueeze(0)  # horizontal flip input
            input_r90_vf = test_transform(
                (img_r90_vf, ))[0].unsqueeze(0)  # vertical flip input
            input_r90_hvf = test_transform((img_r90_hvf, ))[0].unsqueeze(
                0)  # horizontal and vertical flip input

            prob_maps_r90 = get_probmaps(input_r90, model, opt)
            prob_maps_r90_hf = get_probmaps(input_r90_hf, model, opt)
            prob_maps_r90_vf = get_probmaps(input_r90_vf, model, opt)
            prob_maps_r90_hvf = get_probmaps(input_r90_hvf, model, opt)

            # re flip
            prob_maps_r90 = np.rot90(prob_maps_r90, k=3, axes=(1, 2))
            prob_maps_r90_hf = np.rot90(np.flip(prob_maps_r90_hf, 2),
                                        k=3,
                                        axes=(1, 2))
            prob_maps_r90_vf = np.rot90(np.flip(prob_maps_r90_vf, 1),
                                        k=3,
                                        axes=(1, 2))
            prob_maps_r90_hvf = np.rot90(np.flip(np.flip(prob_maps_r90_hvf, 1),
                                                 2),
                                         k=3,
                                         axes=(1, 2))

            prob_maps = (prob_maps + prob_maps_hf + prob_maps_vf +
                         prob_maps_hvf + prob_maps_r90 + prob_maps_r90_hf +
                         prob_maps_r90_vf + prob_maps_r90_hvf) / 8

        pred = np.argmax(prob_maps, axis=0)  # prediction
        pred_inside = pred == 1
        pred2 = morph.remove_small_objects(
            pred_inside, opt.post['min_area'])  # remove small object

        if 'scale' in opt.transform['test']:
            pred2 = misc.imresize(pred2.astype(np.uint8) * 255, (ori_h, ori_w),
                                  interp='bilinear')
            pred2 = (pred2 > 127.5)

        pred_labeled = measure.label(pred2)  # connected component labeling
        pred_labeled = morph.dilation(pred_labeled,
                                      selem=morph.selem.disk(
                                          opt.post['radius']))

        if eval_flag:
            print('\tComputing metrics...')
            result = utils.accuracy_pixel_level(
                np.expand_dims(pred_labeled > 0, 0),
                np.expand_dims(label_img > 0, 0))
            pixel_accu = result[0]

            if opt.dataset == 'MultiOrgan':
                result_object = utils.nuclei_accuracy_object_level(
                    pred_labeled, label_img)
            else:
                result_object = utils.gland_accuracy_object_level(
                    pred_labeled, label_img)

            all_results[name] = tuple([pixel_accu, *result_object])

            # update values
            avg_results.update([pixel_accu, *result_object])

        # save image
        if save_flag:
            print('\tSaving image results...')
            misc.imsave(
                '{:s}/{:s}_prob_inside.png'.format(prob_maps_folder, name),
                prob_maps[1, :, :])
            misc.imsave(
                '{:s}/{:s}_prob_contour.png'.format(prob_maps_folder, name),
                prob_maps[2, :, :])
            final_pred = Image.fromarray(pred_labeled.astype(np.uint16))
            final_pred.save('{:s}/{:s}_seg.tiff'.format(seg_folder, name))

            # save colored objects
            pred_colored = np.zeros((ori_h, ori_w, 3))
            for k in range(1, pred_labeled.max() + 1):
                pred_colored[pred_labeled == k, :] = np.array(
                    utils.get_random_color())
            filename = '{:s}/{:s}_seg_colored.png'.format(seg_folder, name)
            misc.imsave(filename, pred_colored)

        counter += 1
        if counter % 10 == 0:
            print('\tProcessed {:d} images'.format(counter))

    print('=> Processed all {:d} images'.format(counter))
    if eval_flag:
        print('Average of all images:\n'
              'pixel_accu: {r[0]:.4f}\n'
              'recall: {r[1]:.4f}\n'
              'precision: {r[2]:.4f}\n'
              'F1: {r[3]:.4f}\n'
              'dice: {r[4]:.4f}\n'
              'iou: {r[5]:.4f}\n'
              'haus: {r[6]:.4f}'.format(r=avg_results.avg))
        if opt.dataset == 'MultiOrgan':
            print('AJI: {r[7]:.4f}'.format(r=avg_results.avg))

        strs = img_dir.split('/')
        header = [
            'pixel_acc', 'recall', 'precision', 'F1', 'Dice', 'IoU',
            'Hausdorff'
        ]
        if opt.dataset == 'MultiOrgan':
            header.append('AJI')
        save_results(header, avg_results.avg, all_results,
                     '{:s}/{:s}_test_result.txt'.format(save_dir, strs[-1]))
Example #9
def main():
    params = Params()
    img_dir = params.test['img_dir']
    label_dir = params.test['label_dir']
    save_dir = params.test['save_dir']
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    model_path = params.test['model_path']
    save_flag = params.test['save_flag']
    tta = params.test['tta']

    params.save_params('{:s}/test_params.txt'.format(params.test['save_dir']),
                       test=True)

    # check if it is needed to compute accuracies
    eval_flag = True if label_dir else False
    if eval_flag:
        test_results = dict()
        # recall, precision, F1, dice, iou, haus
        tumor_result = utils.AverageMeter(7)
        lym_result = utils.AverageMeter(7)
        stroma_result = utils.AverageMeter(7)
        all_result = utils.AverageMeter(7)
        conf_matrix = np.zeros((3, 3))

    # data transforms
    test_transform = get_transforms(params.transform['test'])

    model_name = params.model['name']
    if model_name == 'ResUNet34':
        model = ResUNet34(params.model['out_c'],
                          fixed_feature=params.model['fix_params'])
    elif params.model['name'] == 'UNet':
        model = UNet(3, params.model['out_c'])
    else:
        raise NotImplementedError()
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    cudnn.benchmark = True

    # ----- load trained model ----- #
    print("=> loading trained model")
    best_checkpoint = torch.load(model_path)
    model.load_state_dict(best_checkpoint['state_dict'])
    print("=> loaded model at epoch {}".format(best_checkpoint['epoch']))
    model = model.module

    # switch to evaluate mode
    model.eval()
    counter = 0
    print("=> Test begins:")

    img_names = os.listdir(img_dir)

    if save_flag:
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        strs = img_dir.split('/')
        prob_maps_folder = '{:s}/{:s}_prob_maps'.format(save_dir, strs[-1])
        seg_folder = '{:s}/{:s}_segmentation'.format(save_dir, strs[-1])
        if not os.path.exists(prob_maps_folder):
            os.mkdir(prob_maps_folder)
        if not os.path.exists(seg_folder):
            os.mkdir(seg_folder)

    # img_names = ['193-adca-5']
    # total_time = 0.0
    for img_name in img_names:
        # load test image
        print('=> Processing image {:s}'.format(img_name))
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        if eval_flag:
            label_path = '{:s}/{:s}_label.png'.format(label_dir, name)
            gt = misc.imread(label_path)

        input = test_transform((img, ))[0].unsqueeze(0)

        print('\tComputing output probability maps...')
        prob_maps = get_probmaps(input, model, params)
        if tta:
            img_hf = img.transpose(Image.FLIP_LEFT_RIGHT)  # horizontal flip
            img_vf = img.transpose(Image.FLIP_TOP_BOTTOM)  # vertical flip
            img_hvf = img_hf.transpose(
                Image.FLIP_TOP_BOTTOM)  # horizontal and vertical flips

            input_hf = test_transform(
                (img_hf, ))[0].unsqueeze(0)  # horizontal flip input
            input_vf = test_transform(
                (img_vf, ))[0].unsqueeze(0)  # vertical flip input
            input_hvf = test_transform((img_hvf, ))[0].unsqueeze(
                0)  # horizontal and vertical flip input

            prob_maps_hf = get_probmaps(input_hf, model, params)
            prob_maps_vf = get_probmaps(input_vf, model, params)
            prob_maps_hvf = get_probmaps(input_hvf, model, params)

            # re flip
            prob_maps_hf = np.flip(prob_maps_hf, 2)
            prob_maps_vf = np.flip(prob_maps_vf, 1)
            prob_maps_hvf = np.flip(np.flip(prob_maps_hvf, 1), 2)

            # rotation 90 and flips
            img_r90 = img.rotate(90, expand=True)
            img_r90_hf = img_r90.transpose(
                Image.FLIP_LEFT_RIGHT)  # horizontal flip
            img_r90_vf = img_r90.transpose(
                Image.FLIP_TOP_BOTTOM)  # vertical flip
            img_r90_hvf = img_r90_hf.transpose(
                Image.FLIP_TOP_BOTTOM)  # horizontal and vertical flips

            input_r90 = test_transform((img_r90, ))[0].unsqueeze(0)
            input_r90_hf = test_transform(
                (img_r90_hf, ))[0].unsqueeze(0)  # horizontal flip input
            input_r90_vf = test_transform(
                (img_r90_vf, ))[0].unsqueeze(0)  # vertical flip input
            input_r90_hvf = test_transform((img_r90_hvf, ))[0].unsqueeze(
                0)  # horizontal and vertical flip input

            prob_maps_r90 = get_probmaps(input_r90, model, params)
            prob_maps_r90_hf = get_probmaps(input_r90_hf, model, params)
            prob_maps_r90_vf = get_probmaps(input_r90_vf, model, params)
            prob_maps_r90_hvf = get_probmaps(input_r90_hvf, model, params)

            # re flip
            prob_maps_r90 = np.rot90(prob_maps_r90, k=3, axes=(1, 2))
            prob_maps_r90_hf = np.rot90(np.flip(prob_maps_r90_hf, 2),
                                        k=3,
                                        axes=(1, 2))
            prob_maps_r90_vf = np.rot90(np.flip(prob_maps_r90_vf, 1),
                                        k=3,
                                        axes=(1, 2))
            prob_maps_r90_hvf = np.rot90(np.flip(np.flip(prob_maps_r90_hvf, 1),
                                                 2),
                                         k=3,
                                         axes=(1, 2))

            # utils.show_figures((np.array(img), np.array(img_r90_hvf),
            #                     np.swapaxes(np.swapaxes(prob_maps_r90_hvf, 0, 1), 1, 2)))

            prob_maps = (prob_maps + prob_maps_hf + prob_maps_vf +
                         prob_maps_hvf + prob_maps_r90 + prob_maps_r90_hf +
                         prob_maps_r90_vf + prob_maps_r90_hvf) / 8

        pred = np.argmax(prob_maps, axis=0)  # prediction
        pred_inside = pred.copy()
        pred_inside[pred == 4] = 0  # set contours to background
        pred_nuclei_inside_labeled = measure.label(pred_inside > 0)

        pred_tumor_inside = pred_inside == 1
        pred_lym_inside = pred_inside == 2
        pred_stroma_inside = pred_inside == 3
        pred_3types_inside = pred_tumor_inside + pred_lym_inside * 2 + pred_stroma_inside * 3

        # find the correct class for each segmented nucleus
        N_nuclei = len(np.unique(pred_nuclei_inside_labeled))
        N_class = len(np.unique(pred_3types_inside))
        intersection = np.histogram2d(pred_nuclei_inside_labeled.flatten(),
                                      pred_3types_inside.flatten(),
                                      bins=(N_nuclei, N_class))[0]
        classes = np.argmax(intersection, axis=1)
        tumor_nuclei_indices = np.nonzero(classes == 1)
        lym_nuclei_indices = np.nonzero(classes == 2)
        stroma_nuclei_indices = np.nonzero(classes == 3)

        # solve the problem of one nucleus assigned with different labels
        pred_tumor_inside = np.isin(pred_nuclei_inside_labeled,
                                    tumor_nuclei_indices)
        pred_lym_inside = np.isin(pred_nuclei_inside_labeled,
                                  lym_nuclei_indices)
        pred_stroma_inside = np.isin(pred_nuclei_inside_labeled,
                                     stroma_nuclei_indices)

        # remove small objects
        pred_tumor_inside = morph.remove_small_objects(pred_tumor_inside,
                                                       params.post['min_area'])
        pred_lym_inside = morph.remove_small_objects(pred_lym_inside,
                                                     params.post['min_area'])
        pred_stroma_inside = morph.remove_small_objects(
            pred_stroma_inside, params.post['min_area'])

        # connected component labeling
        pred_tumor_inside_labeled = measure.label(pred_tumor_inside)
        pred_lym_inside_labeled = measure.label(pred_lym_inside)
        pred_stroma_inside_labeled = measure.label(pred_stroma_inside)
        pred_all_inside_labeled = pred_tumor_inside_labeled * 3 \
                                  + (pred_lym_inside_labeled * 3 - 2) * (pred_lym_inside_labeled>0) \
                                  + (pred_stroma_inside_labeled * 3 - 1) * (pred_stroma_inside_labeled>0)

        # dilation
        pred_tumor_labeled = morph.dilation(pred_tumor_inside_labeled,
                                            selem=morph.selem.disk(
                                                params.post['radius']))
        pred_lym_labeled = morph.dilation(pred_lym_inside_labeled,
                                          selem=morph.selem.disk(
                                              params.post['radius']))
        pred_stroma_labeled = morph.dilation(pred_stroma_inside_labeled,
                                             selem=morph.selem.disk(
                                                 params.post['radius']))
        pred_all_labeled = morph.dilation(pred_all_inside_labeled,
                                          selem=morph.selem.disk(
                                              params.post['radius']))

        # utils.show_figures([pred, pred2, pred_labeled])

        if eval_flag:
            print('\tComputing metrics...')
            gt_tumor = (gt % 3 == 0) * gt
            gt_lym = (gt % 3 == 1) * gt
            gt_stroma = (gt % 3 == 2) * gt

            tumor_detect_metrics = utils.accuracy_detection_clas(
                pred_tumor_labeled, gt_tumor, clas_flag=False)
            lym_detect_metrics = utils.accuracy_detection_clas(
                pred_lym_labeled, gt_lym, clas_flag=False)
            stroma_detect_metrics = utils.accuracy_detection_clas(
                pred_stroma_labeled, gt_stroma, clas_flag=False)
            all_detect_metrics = utils.accuracy_detection_clas(
                pred_all_labeled, gt, clas_flag=True)

            tumor_seg_metrics = utils.accuracy_object_level(
                pred_tumor_labeled, gt_tumor, hausdorff_flag=False)
            lym_seg_metrics = utils.accuracy_object_level(pred_lym_labeled,
                                                          gt_lym,
                                                          hausdorff_flag=False)
            stroma_seg_metrics = utils.accuracy_object_level(
                pred_stroma_labeled, gt_stroma, hausdorff_flag=False)
            all_seg_metrics = utils.accuracy_object_level(pred_all_labeled,
                                                          gt,
                                                          hausdorff_flag=True)

            tumor_metrics = [*tumor_detect_metrics[:-1], *tumor_seg_metrics]
            lym_metrics = [*lym_detect_metrics[:-1], *lym_seg_metrics]
            stroma_metrics = [*stroma_detect_metrics[:-1], *stroma_seg_metrics]
            all_metrics = [*all_detect_metrics[:-1], *all_seg_metrics]
            conf_matrix += np.array(all_detect_metrics[-1])

            # save result for each image
            test_results[name] = {
                'tumor': tumor_metrics,
                'lym': lym_metrics,
                'stroma': stroma_metrics,
                'all': all_metrics
            }

            # update the average result
            tumor_result.update(tumor_metrics)
            lym_result.update(lym_metrics)
            stroma_result.update(stroma_metrics)
            all_result.update(all_metrics)

        # save image
        if save_flag:
            print('\tSaving image results...')
            misc.imsave('{:s}/{:s}_pred.png'.format(prob_maps_folder, name),
                        pred.astype(np.uint8) * 50)
            misc.imsave(
                '{:s}/{:s}_prob_tumor.png'.format(prob_maps_folder, name),
                prob_maps[1, :, :])
            misc.imsave(
                '{:s}/{:s}_prob_lym.png'.format(prob_maps_folder, name),
                prob_maps[2, :, :])
            misc.imsave(
                '{:s}/{:s}_prob_stroma.png'.format(prob_maps_folder, name),
                prob_maps[3, :, :])
            # np.save('{:s}/{:s}_prob.npy'.format(prob_maps_folder, name), prob_maps)
            # np.save('{:s}/{:s}_seg.npy'.format(seg_folder, name), pred_all_labeled)
            final_pred = Image.fromarray(pred_all_labeled.astype(np.uint16))
            final_pred.save('{:s}/{:s}_seg.tiff'.format(seg_folder, name))

            # save colored objects
            pred_colored = np.zeros((ori_h, ori_w, 3))
            pred_colored_instance = np.zeros((ori_h, ori_w, 3))
            pred_colored[pred_tumor_labeled > 0] = np.array([255, 0, 0])
            pred_colored[pred_lym_labeled > 0] = np.array([0, 255, 0])
            pred_colored[pred_stroma_labeled > 0] = np.array([0, 0, 255])
            filename = '{:s}/{:s}_seg_colored_3types.png'.format(
                seg_folder, name)
            misc.imsave(filename, pred_colored)
            for k in range(1, pred_all_labeled.max() + 1):
                pred_colored_instance[pred_all_labeled == k, :] = np.array(
                    utils.get_random_color())
            filename = '{:s}/{:s}_seg_colored.png'.format(seg_folder, name)
            misc.imsave(filename, pred_colored_instance)

            # img_overlaid = utils.overlay_edges(label_img, pred_labeled2, img)
            # filename = '{:s}/{:s}_comparison.png'.format(seg_folder, name)
            # misc.imsave(filename, img_overlaid)

        counter += 1
        if counter % 10 == 0:
            print('\tProcessed {:d} images'.format(counter))

    # print('Time: {:4f}'.format(total_time/counter))

    print('=> Processed all {:d} images'.format(counter))
    if eval_flag:
        print(
            'Average: clas_acc\trecall\tprecision\tF1\tdice\tiou\thausdorff\n'
            'tumor: {t[0]:.4f}, {t[1]:.4f}, {t[2]:.4f}, {t[3]:.4f}, {t[4]:.4f}, {t[5]:.4f}, {t[6]:.4f}\n'
            'lym: {l[0]:.4f}, {l[1]:.4f}, {l[2]:.4f}, {l[3]:.4f}, {l[4]:.4f}, {l[5]:.4f}, {l[6]:.4f}\n'
            'stroma: {s[0]:.4f}, {s[1]:.4f}, {s[2]:.4f}, {s[3]:.4f}, {s[4]:.4f}, {s[5]:.4f}, {s[6]:.4f}\n'
            'all: {a[0]:.4f}, {a[1]:.4f}, {a[2]:.4f}, {a[3]:.4f}, {a[4]:.4f}, {a[5]:.4f}, {a[6]:.4f}'
            .format(t=tumor_result.avg,
                    l=lym_result.avg,
                    s=stroma_result.avg,
                    a=all_result.avg))

        header = [
            'clas_acc', 'recall', 'precision', 'F1', 'Dice', 'IoU', 'Hausdorff'
        ]
        save_results(header, tumor_result.avg, lym_result.avg,
                     stroma_result.avg, all_result.avg, test_results,
                     conf_matrix, '{:s}/test_result.txt'.format(save_dir))
Example #10
def val(img_dir, label_dir, model, transform, opt, tb_writer, epoch):
    model.eval()
    img_names = os.listdir(img_dir)
    metric_names = ['acc', 'p_F1', 'p_recall', 'p_precision', 'dice', 'aji']
    val_results = dict()
    all_results = utils.AverageMeter(len(metric_names))

    plot_num = 10  #len(img_names)
    for img_name in img_names:
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        label_path = '{:s}/{:s}_label.png'.format(label_dir, name)
        gt = misc.imread(label_path)

        input = transform((img, ))[0].unsqueeze(0)

        prob_maps = get_probmaps(input, model, opt)
        pred = np.argmax(prob_maps, axis=0)

        pred_labeled = measure.label(pred)
        pred_labeled = morph.remove_small_objects(pred_labeled,
                                                  opt.post['min_area'])
        pred_labeled = ndi_morph.binary_fill_holes(pred_labeled > 0)
        pred_labeled = measure.label(pred_labeled)

        metrics = compute_metrics(pred_labeled, gt, metric_names)

        if plot_num > 0:

            unNorm = get_transforms({
                'unnormalize':
                np.load('{:s}/mean_std.npy'.format(opt.train['data_dir']))
            })
            img_tensor = unNorm(input.squeeze(0))
            img_np = img_tensor.permute(1, 2, 0).numpy()
            font = ImageFont.truetype(
                '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf', 42)
            metrics_text = Image.new("RGB", (512, 512), (255, 255, 255))
            draw = ImageDraw.Draw(metrics_text)
            draw.text(
                (32, 128),
                'Acc: {:.4f}\nF1: {:.4f}\nRecall: {:.4f}\nPrecision: {:.4f}\nDice: {:.4f}\nAJI: {:.4f}'
                .format(metrics['acc'], metrics['p_F1'], metrics['p_recall'],
                        metrics['p_precision'], metrics['dice'],
                        metrics['aji']),
                fill='rgb(0,0,0)',
                font=font)
            #tb_writer.add_scalars('{:s}'.format(name), 'Acc: {:.4f}\nF1: {:.4f}\nRecall: {:.4f}\nPrecision: {:.4f}\nDice: {:.4f}\nAJI: {:.4f}'.format(metrics['acc'], metrics['p_F1'], metrics['p_recall'], metrics['p_precision'], metrics['dice'], metrics['aji']), epoch)
            metrics_text = metrics_text.resize((ori_w, ori_h), Image.ANTIALIAS)
            trans_to_tensor = transforms.Compose([
                transforms.ToTensor(),
            ])
            text_tensor = trans_to_tensor(metrics_text).float()
            colored_gt = np.zeros((ori_h, ori_w, 3))
            colored_pred = np.zeros((ori_h, ori_w, 3))
            img_w_colored_gt = img_np.copy()
            img_w_colored_pred = img_np.copy()
            alpha = 0.5
            for k in range(1, gt.max() + 1):
                colored_gt[gt == k, :] = np.array(
                    utils.get_random_color(seed=k))
                img_w_colored_gt[gt == k, :] = img_w_colored_gt[gt == k, :] * (
                    1 - alpha) + colored_gt[gt == k, :] * alpha
            for k in range(1, pred_labeled.max() + 1):
                colored_pred[pred_labeled == k, :] = np.array(
                    utils.get_random_color(seed=k))
                img_w_colored_pred[
                    pred_labeled ==
                    k, :] = img_w_colored_pred[pred_labeled == k, :] * (
                        1 - alpha) + colored_pred[pred_labeled == k, :] * alpha

            gt_tensor = torch.from_numpy(colored_gt).permute(2, 0, 1).float()
            pred_tensor = torch.from_numpy(colored_pred).permute(2, 0,
                                                                 1).float()
            img_w_gt_tensor = torch.from_numpy(img_w_colored_gt).permute(
                2, 0, 1).float()
            img_w_pred_tensor = torch.from_numpy(img_w_colored_pred).permute(
                2, 0, 1).float()
            tb_writer.add_image(
                '{:s}'.format(name),
                make_grid([
                    img_tensor, img_w_gt_tensor, img_w_pred_tensor,
                    text_tensor, gt_tensor, pred_tensor
                ],
                          nrow=3,
                          padding=10,
                          pad_value=1), epoch)
            plot_num -= 1

        # update the average result
        all_results.update([
            metrics['acc'], metrics['p_F1'], metrics['p_recall'],
            metrics['p_precision'], metrics['dice'], metrics['aji']
        ])
    logger.info('\t=> Val Avg: Acc {r[0]:.4f}'
                '\tF1 {r[1]:.4f}'
                '\tRecall {r[2]:.4f}'
                '\tPrecision {r[3]:.4f}'
                '\tDice {r[4]:.4f}'
                '\tAJI {r[5]:.4f}'.format(r=all_results.avg))

    return all_results.avg
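
Example #10 calls utils.get_random_color(seed=k) with the instance label as the seed, so each label maps to a stable color in both the ground-truth and prediction overlays. A stand-in under that assumption, returning (r, g, b) floats in [0, 1] as the float image buffers above expect, might look like this; the project's actual utils module is not shown here.

import random

def get_random_color(seed=None):
    # Seeding with the instance label makes the color reproducible:
    # the same label always maps to the same (r, g, b) triple.
    rng = random.Random(seed)
    return rng.random(), rng.random(), rng.random()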