Example #1
0
    def clusterSegmentationOutput(self, series_uid, ct, clean_ary):
        """Convert a cleaned segmentation mask into one NoduleInfoTuple per region.

        Connected regions of ``clean_ary`` are labelled with ``measure.label``
        (presumably ``scipy.ndimage`` — it returns ``(labels, count)``).  Each
        region's center is the center of mass of the CT values inside it,
        weighted by the clipped and shifted intensities described below.

        Args:
            series_uid: identifier copied into every returned tuple.
            ct: CT object providing ``ary``, ``origin_xyz``, ``vxSize_xyz``
                and ``direction_tup``.
            clean_ary: mask to label.

        Returns:
            List of ``NoduleInfoTuple(False, 0.0, series_uid, center_xyz)``,
            one per labelled region in label order; empty when there are none.
        """
        noduleLabel_ary, nodule_count = measure.label(clean_ary)
        if nodule_count == 0:
            return []

        # Clip to [-1000, 1000] before the +1001 shift so every voxel weighs at
        # least 1. Without the lower clip, values below -1001 give non-positive
        # weights; a region whose weights sum to zero yields a NaN center, and
        # negative weights can pull the center outside the region.
        centerIrc_list = measure.center_of_mass(
            np.clip(ct.ary, -1000, 1000) + 1001,
            labels=noduleLabel_ary,
            index=list(range(1, nodule_count + 1)),
        )
        # With a list ``index`` the result is already a list of center tuples,
        # even for a single label. Only a bare tuple (the scalar-index form)
        # needs wrapping; wrapping a list would nest the center one level deep.
        if isinstance(centerIrc_list, tuple):
            centerIrc_list = [centerIrc_list]

        noduleInfo_list = []
        for i, center_irc in enumerate(centerIrc_list):
            # Check the raw center first so a bad center is reported as an IRC
            # failure instead of surfacing from inside irc2xyz.
            assert np.all(np.isfinite(center_irc)), repr(['irc', center_irc, i, nodule_count])
            center_xyz = irc2xyz(
                center_irc,
                ct.origin_xyz,
                ct.vxSize_xyz,
                ct.direction_tup,
            )
            assert np.all(np.isfinite(center_xyz)), repr(['xyz', center_xyz])
            noduleInfo_list.append(
                NoduleInfoTuple(False, 0.0, series_uid, center_xyz)
            )

        return noduleInfo_list
Example #2
0
    def clusterSegmentationOutput(self, series_uid, ct, clean_g):
        """Build one candidate per connected region of a segmentation mask.

        ``clean_g`` is moved to host memory and labelled with
        ``measure.label``.  Each region's IRC center is the center of mass of
        ``ct.hu_a`` clipped to [-1000, 1000] and shifted by +1001, restricted to
        that region's label.

        Returns:
            A 3-tuple ``(candidateInfo_list, centerIrc_list, candidateLabel_a)``:
            the ``CandidateInfoTuple`` per region (diameter recorded as 0.0 and
            the first three fields set to None), the raw IRC centers, and the
            label volume.
        """
        mask_a = clean_g.cpu().numpy()
        label_a, region_count = measure.label(mask_a)

        # Shifted so every clipped voxel contributes a weight of at least 1.
        weight_a = ct.hu_a.clip(-1000, 1000) + 1001
        region_ids = list(range(1, region_count + 1))
        centerIrc_list = measure.center_of_mass(
            weight_a, labels=label_a, index=region_ids,
        )

        candidateInfo_list = []
        for region_ndx, irc in enumerate(centerIrc_list):
            assert np.isfinite(irc).all(), repr([
                series_uid, region_ndx, region_count,
                (ct.hu_a[label_a == region_ndx + 1]).sum(), irc
            ])
            xyz = irc2xyz(irc, ct.origin_xyz, ct.vxSize_xyz, ct.direction_a)
            # Diameter is not estimated here; 0.0 is recorded for every region.
            candidateInfo_list.append(
                CandidateInfoTuple(None, None, None, 0.0, series_uid, xyz)
            )

        return candidateInfo_list, centerIrc_list, label_a
Example #3
0
    def classifyCandidates(self, ct, candidateInfo_list):
        """Score every candidate with the nodule and malignancy classifiers.

        Batches come from ``self.initClassificationDl(candidateInfo_list)``.
        Column 1 of each model's second output is taken as the probability.
        When ``self.malignancy_model`` is None the malignancy probability is
        zero.

        Returns:
            A list of ``(prob_nodule, prob_mal, center_xyz, center_irc)``
            tuples, one per sample, in data-loader order.
        """
        out_list = []
        for batch_tup in self.initClassificationDl(candidateInfo_list):
            input_t, _label_t, _index_t, _series_list, irc_list = batch_tup
            batch_g = input_t.to(self.device)

            with torch.no_grad():
                _nodule_logits_g, nodule_prob_g = self.cls_model(batch_g)
                if self.malignancy_model is None:
                    mal_prob_g = torch.zeros_like(nodule_prob_g)
                else:
                    _mal_logits_g, mal_prob_g = self.malignancy_model(batch_g)

            nodule_probs = nodule_prob_g[:, 1].tolist()
            mal_probs = mal_prob_g[:, 1].tolist()
            for irc, p_nodule, p_mal in zip(irc_list, nodule_probs, mal_probs):
                xyz = irc2xyz(
                    irc,
                    origin_xyz=ct.origin_xyz,
                    vxSize_xyz=ct.vxSize_xyz,
                    direction_a=ct.direction_a,
                )
                out_list.append((p_nodule, p_mal, xyz, irc))
        return out_list
    def clusterSegmentationOutput(self, series_uid, ct, clean_a):
        """Group a cleaned segmentation mask into one NoduleInfoTuple per region.

        Args:
            series_uid: identifier copied into every returned tuple.
            ct: CT object providing ``hu_a``, ``origin_xyz``, ``vxSize_xyz``
                and ``direction_tup``.
            clean_a: mask to label with ``measure.label`` (presumably
                ``scipy.ndimage``, which treats nonzero voxels as foreground —
                confirm).

        Returns:
            List of ``NoduleInfoTuple(False, 0.0, series_uid, center_xyz)``,
            one per labelled region in label order; empty when there are none.
        """
        # Assign a distinct label (1..nodule_count) to each group of connected
        # voxels.
        noduleLabel_a, nodule_count = measure.label(clean_a)
        if nodule_count == 0:
            return []

        centerIrc_list = measure.center_of_mass(
            # Clip to [-1000, 1000] before the +1001 shift so every voxel weighs
            # at least 1; non-positive weights can yield NaN centers (zero sum)
            # or centers pulled outside the region.
            np.clip(ct.hu_a, -1000, 1000) + 1001,
            labels=noduleLabel_a,
            # Required: listing every label gives one center per region. For
            # scipy.ndimage, omitting it returns a single combined center.
            index=list(range(1, nodule_count + 1)),
        )

        noduleInfo_list = []
        for i, center_irc in enumerate(centerIrc_list):
            # Validate the raw center before converting it.
            assert np.all(np.isfinite(center_irc)), \
                repr(['irc', center_irc, i, nodule_count])
            center_xyz = irc2xyz(
                center_irc,
                ct.origin_xyz,
                ct.vxSize_xyz,
                ct.direction_tup,
            )
            # The message belongs to the assert, not to np.all (whose second
            # positional parameter is ``axis``).
            assert np.all(np.isfinite(center_xyz)), \
                repr(['xyz', center_xyz])
            noduleInfo_list.append(
                NoduleInfoTuple(False, 0.0, series_uid, center_xyz)
            )

        return noduleInfo_list
Example #5
0
    def main(self):
        """Segment each selected series and print the voxels scoring above 0.5.

        Series come from ``cli_args.series_uid`` when set; otherwise every
        distinct ``series_uid`` in ``getNoduleInfoList()``, sorted.  For each
        series the segmentation model's output for every sample is written
        into a float32 volume shaped like ``ct.ary``, at the index the data
        loader supplies for that sample.  The indices with output > 0.5 and
        their ``irc2xyz`` conversion are then printed.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        if self.cli_args.series_uid:
            series_list = [self.cli_args.series_uid]
        else:
            series_list = sorted(
                {noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()}
            )

        with torch.no_grad():
            series_iter = enumerateWithEstimate(series_list, "Series")
            for _series_ndx, series_uid in series_iter:
                seg_dl = self.initSegmentationDl(series_uid)
                ct = getCt(series_uid)

                output_ary = np.zeros_like(ct.ary, dtype=np.float32)

                batch_iter = enumerateWithEstimate(
                    seg_dl,
                    "Seg " + series_uid,
                    start_ndx=seg_dl.num_workers,
                )
                for _batch_ndx, batch_tup in batch_iter:
                    input_tensor, _label_tensor, _series_list, ndx_list = batch_tup

                    input_devtensor = input_tensor.to(self.device)
                    prediction_devtensor = self.seg_model(input_devtensor)
                    # One device->host copy per batch rather than one per sample.
                    prediction_ary = prediction_devtensor.detach().cpu().numpy()

                    for i, sample_ndx in enumerate(ndx_list):
                        output_ary[sample_ndx] = prediction_ary[i]

                # NOTE(review): nonzero() returns a tuple of index arrays, not a
                # single IRC triple — confirm irc2xyz accepts this form.
                irc = (output_ary > 0.5).nonzero()
                xyz = irc2xyz(irc, ct.origin_xyz, ct.vxSize_xyz,
                              ct.direction_tup)

                print(irc, xyz)
Example #6
0
    def groupSegmentationOutput(self, series_uid, ct, clean_a):
        """Return one CandidateInfoTuple per connected region of ``clean_a``.

        Regions are labelled with ``measurements.label``.  Each region's IRC
        center is the center of mass of ``ct.hu_a`` clipped to [-1000, 1000]
        and shifted by +1001.  The center is converted to XYZ and both
        coordinates are asserted finite.  Every tuple is
        ``CandidateInfoTuple(False, False, False, 0.0, series_uid, center_xyz)``,
        in label order.
        """
        label_a, group_count = measurements.label(clean_a)
        weight_a = ct.hu_a.clip(-1000, 1000) + 1001
        centerIrc_list = measurements.center_of_mass(
            weight_a,
            labels=label_a,
            index=np.arange(1, group_count + 1),
        )

        def toCandidate(group_ndx, irc):
            # Convert first, then validate both coordinates.
            xyz = irc2xyz(irc, ct.origin_xyz, ct.vxSize_xyz, ct.direction_a)
            assert np.all(np.isfinite(irc)), repr(
                ['irc', irc, group_ndx, group_count])
            assert np.all(np.isfinite(xyz)), repr(['xyz', xyz])
            return CandidateInfoTuple(False, False, False, 0.0, series_uid, xyz)

        return [
            toCandidate(group_ndx, irc)
            for group_ndx, irc in enumerate(centerIrc_list)
        ]
Example #7
0
    def main(self):
        """Evaluate segmentation + classification against annotated nodules.

        Series selection: ``cli_args.series_uid`` (comma-separated) when set,
        otherwise every series in ``getCandidateInfoList()``.  Validation
        series always run; training series run only with
        ``cli_args.include_train``.

        For each series: segment the CT, group the mask into candidates, and
        classify each candidate.  Then match every annotated positive
        (``isNodule_bool``) to its nearest candidate center:

        * A positive with a candidate closer than ``max(10, diameter_mm / 2)``
          is a true positive when that candidate's nodule probability is
          > 0.5, and a false negative otherwise.
        * A positive with no such candidate is counted as missed.
        * Candidates matched to no positive are false positives
          (probability > 0.5) or true negatives.

        Per-series and total counts are logged, followed by SHA-1 digests of
        both model files and the list of missed positives.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        val_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=True,
        )
        val_set = {cit.series_uid for cit in val_ds.candidateInfo_list}

        if self.cli_args.series_uid:
            series_set = set(self.cli_args.series_uid.split(','))
        else:
            series_set = {cit.series_uid for cit in getCandidateInfoList()}

        if self.cli_args.include_train:
            train_list = sorted(series_set - val_set)
        else:
            train_list = []
        val_list = sorted(series_set & val_set)

        total_tp = total_tn = total_fp = total_fn = 0
        total_missed_pos = 0
        missed_pos_dist_list = []
        missed_pos_cit_list = []
        candidateInfo_dict = getCandidateInfoDict()
        series_iter = enumerateWithEstimate(
            val_list + train_list,
            "Series",
        )
        for _series_ndx, series_uid in series_iter:
            ct, _output_g, _mask_g, clean_g = self.segmentCt(series_uid)

            seg_candidateInfo_list, _seg_centerIrc_list, _ = self.clusterSegmentationOutput(
                series_uid,
                ct,
                clean_g,
            )

            # (center_irc_t, nodule probability) for every segmented candidate.
            # No early `continue` when nothing was segmented: this series'
            # annotated positives must still be counted as missed below.
            results_list = []
            if seg_candidateInfo_list:
                cls_dl = self.initClassificationDl(seg_candidateInfo_list)
                for batch_ndx, batch_tup in enumerate(cls_dl):
                    input_t, _label_t, _index_t, series_list, center_t = batch_tup

                    input_g = input_t.to(self.device)
                    with torch.no_grad():
                        _logits_g, probability_g = self.cls_model(input_g)
                    probability_t = probability_g.to('cpu')

                    for i, _series_uid in enumerate(series_list):
                        assert series_uid == _series_uid, repr([
                            batch_ndx, i, series_uid, _series_uid,
                            seg_candidateInfo_list
                        ])
                        # Column 1 is taken as the nodule-class probability
                        # (column 0 is the not-nodule class).
                        # NOTE(review): confirm against the classifier's label order.
                        results_list.append((center_t[i], probability_t[i, 1].item()))

            # Match annotated positives with the segmentation results.
            tp = tn = fp = fn = 0
            missed_pos = 0
            # NOTE(review): re-fetches the CT although segmentCt already
            # returned one; kept in case they differ — confirm and drop.
            ct = getCt(series_uid)
            candidateInfo_list = [
                cit for cit in candidateInfo_dict[series_uid] if cit.isNodule_bool
            ]

            # Result centers do not depend on the annotation; convert them once.
            result_xyz_list = [
                irc2xyz(center_irc_t, ct.origin_xyz, ct.vxSize_xyz, ct.direction_a)
                for center_irc_t, _prob in results_list
            ]
            # The annotation (if any) each result was matched to.
            found_cit_list = [None] * len(results_list)

            for candidateInfo_tup in candidateInfo_list:
                annotation_xyz_t = torch.tensor(candidateInfo_tup.center_xyz)
                # (distance, result index) of the nearest result; the first
                # result wins ties, and the None index means "nothing found".
                min_dist = (999, None)
                for result_ndx, result_xyz in enumerate(result_xyz_list):
                    delta_xyz_t = torch.tensor(result_xyz) - annotation_xyz_t
                    distance = float((delta_xyz_t ** 2).sum().sqrt())
                    if distance < min_dist[0]:
                        min_dist = (distance, result_ndx)

                distance_cutoff = max(10, candidateInfo_tup.diameter_mm / 2)
                found_dist, result_ndx = min_dist
                if result_ndx is not None and found_dist < distance_cutoff:
                    nodule_probability = results_list[result_ndx][1]
                    if nodule_probability > 0.5:
                        tp += 1
                    else:
                        fn += 1
                    found_cit_list[result_ndx] = candidateInfo_tup
                else:
                    log.warning(
                        "!!! Missed positive {}; {} min dist !!!".format(
                            candidateInfo_tup, min_dist))
                    missed_pos += 1
                    missed_pos_dist_list.append(float(found_dist))
                    missed_pos_cit_list.append(candidateInfo_tup)

            # Results that matched no annotated positive.
            for found_cit, (_center_irc_t, nodule_probability) in zip(found_cit_list, results_list):
                if found_cit is None:
                    if nodule_probability > 0.5:
                        fp += 1
                    else:
                        tn += 1

            log.info("{}: {} missed pos, {} fn, {} fp, {} tp, {} tn".format(
                series_uid, missed_pos, fn, fp, tp, tn))
            total_tp += tp
            total_tn += tn
            total_fp += fp
            total_fn += fn
            total_missed_pos += missed_pos

        # Log the model files' digests so results can be tied to exact weights.
        for model_path in (self.cli_args.segmentation_path,
                           self.cli_args.classification_path):
            with open(model_path, 'rb') as f:
                log.info(model_path)
                log.info(hashlib.sha1(f.read()).hexdigest())
        log.info("{}: {} missed pos, {} fn, {} fp, {} tp, {} tn".format(
            'total', total_missed_pos, total_fn, total_fp, total_tp, total_tn))
        for cit, dist in zip(missed_pos_cit_list, missed_pos_dist_list):
            log.info("    Missed by {}: {}".format(dist, cit))