Пример #1
0
    def doTraining(self, epoch_ndx, train_dl):
        """Run one training epoch and return the per-sample metrics tensor."""
        self.model.train()
        metrics_t = torch.zeros(METRICS_SIZE, len(train_dl.dataset))

        progress_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in progress_iter:
            self.optimizer.zero_grad()

            # Both loss routines share a signature; pick per the CLI flag.
            loss_fn = (self.computeSegmentationLoss if self.cli_args.segmentation
                       else self.computeClassificationLoss)
            loss_var = loss_fn(batch_ndx, batch_tup, train_dl.batch_size, metrics_t)

            if loss_var is not None:
                loss_var.backward()
                self.optimizer.step()
            del loss_var

        self.totalTrainingSamples_count += metrics_t.size(1)

        return metrics_t
Пример #2
0
    def doTraining(self, epoch_ndx, train_dl):
        """Train for one epoch; per-sample metrics come back moved to the CPU."""
        self.model.train()
        # Reshuffle so each epoch sees the samples in a fresh order.
        train_dl.dataset.shuffleSamples()
        epoch_metrics_g = torch.zeros(
            METRICS_SIZE, len(train_dl.dataset), device=self.device)

        progress_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in progress_iter:
            self.optimizer.zero_grad()
            loss_var = self.computeBatchLoss(
                batch_ndx, batch_tup, train_dl.batch_size, epoch_metrics_g,
                augment=True)
            loss_var.backward()
            self.optimizer.step()

        self.totalTrainingSamples_count += len(train_dl.dataset)

        return epoch_metrics_g.to('cpu')
Пример #3
0
    def main(self):
        """Segment and classify the requested CT series, printing nodule
        candidates and a per-series plus total confusion matrix.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        val_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=True,
        )
        val_set = set(candidateInfo_tup.series_uid
                      for candidateInfo_tup in val_ds.candidateInfo_list)
        # NOTE(review): positive_set is currently unused below; kept for parity
        # with sibling versions of this routine — confirm before removing.
        positive_set = set(candidateInfo_tup.series_uid
                           for candidateInfo_tup in getCandidateInfoList()
                           if candidateInfo_tup.isNodule_bool)

        # Explicit series on the command line win; otherwise process everything.
        if self.cli_args.series_uid:
            series_set = set(self.cli_args.series_uid.split(','))
        else:
            series_set = set(candidateInfo_tup.series_uid
                             for candidateInfo_tup in getCandidateInfoList())

        if self.cli_args.include_train:
            train_list = sorted(series_set - val_set)
        else:
            train_list = []
        val_list = sorted(series_set & val_set)

        candidateInfo_dict = getCandidateInfoDict()
        series_iter = enumerateWithEstimate(
            val_list + train_list,
            "Series",
        )
        # BUG FIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin int is the supported dtype spelling.
        all_confusion = np.zeros((3, 4), dtype=int)
        for _, series_uid in series_iter:
            ct = getCt(series_uid)
            mask_a = self.segmentCt(ct, series_uid)

            candidateInfo_list = self.groupSegmentationOutput(
                series_uid, ct, mask_a)
            classifications_list = self.classifyCandidates(
                ct, candidateInfo_list)

            if not self.cli_args.run_validation:
                print(f"found nodule candidates in {series_uid}:")
                for prob, prob_mal, center_xyz, center_irc in classifications_list:
                    if prob > 0.5:
                        s = f"nodule prob {prob:.3f}, "
                        if self.malignancy_model:
                            s += f"malignancy prob {prob_mal:.3f}, "
                        s += f"center xyz {center_xyz}"
                        print(s)

            if series_uid in candidateInfo_dict:
                one_confusion = match_and_score(classifications_list,
                                                candidateInfo_dict[series_uid])
                all_confusion += one_confusion
                print_confusion(series_uid, one_confusion,
                                self.malignancy_model is not None)

        print_confusion("Total", all_confusion, self.malignancy_model
                        is not None)
Пример #4
0
    def main(self):
        """Run the segmentation model over each CT series and print the voxel
        indices (and their xyz positions) whose prediction exceeds 0.5.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        if self.cli_args.series_uid:
            series_list = [self.cli_args.series_uid]
        else:
            series_list = sorted(
                set(noduleInfo_tup.series_uid
                    for noduleInfo_tup in getNoduleInfoList()))

        with torch.no_grad():
            series_iter = enumerateWithEstimate(
                series_list,
                "Series",
            )
            for series_ndx, series_uid in series_iter:
                seg_dl = self.initSegmentationDl(series_uid)
                ct = getCt(series_uid)

                # Full-volume buffer the per-slice predictions are written into.
                output_ary = np.zeros_like(ct.ary, dtype=np.float32)

                batch_iter = enumerateWithEstimate(
                    seg_dl,
                    "Seg " + series_uid,
                    start_ndx=seg_dl.num_workers,
                )
                for batch_ndx, batch_tup in batch_iter:
                    input_tensor, label_tensor, _series_list, ndx_list = batch_tup

                    input_devtensor = input_tensor.to(self.device)

                    prediction_devtensor = self.seg_model(input_devtensor)

                    # BUG FIX: ".detatch()" was a typo — the Tensor method is
                    # .detach(), so the original raised AttributeError here.
                    for i, sample_ndx in enumerate(ndx_list):
                        output_ary[sample_ndx] = \
                            prediction_devtensor[i].detach().cpu().numpy()

                irc = (output_ary > 0.5).nonzero()
                xyz = irc2xyz(irc, ct.origin_xyz, ct.vxSize_xyz,
                              ct.direction_tup)

                print(irc, xyz)
Пример #5
0
    def main(self):
        """Iterate the whole dataset once so the disk cache gets populated."""
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.prep_dl = DataLoader(
            LunaDataset(sortby_str='series_uid'),
            batch_size=self.cli_args.batch_size,
            num_workers=self.cli_args.num_workers,
        )

        progress_iter = enumerateWithEstimate(
            self.prep_dl,
            "Stuffing cache",
            start_ndx=self.prep_dl.num_workers,
        )
        # The loop body is intentionally empty: loading each batch is the work.
        for _ in progress_iter:
            pass
    def doValidation(self, epoch_ndx, val_dl):
        """Score the validation set without gradient tracking; metrics on CPU."""
        with torch.no_grad():
            self.segmentation_model.eval()
            val_metrics_g = torch.zeros(
                METRICS_SIZE,
                len(val_dl.dataset),
                device=self.device,
            )

            progress_iter = enumerateWithEstimate(
                val_dl,
                "E{} Validation ".format(epoch_ndx),
                start_ndx=val_dl.num_workers,
            )
            for batch_ndx, batch_tup in progress_iter:
                # Loss value is discarded; the call fills val_metrics_g in place.
                self.computeBatchLoss(
                    batch_ndx, batch_tup, val_dl.batch_size, val_metrics_g)

        return val_metrics_g.to('cpu')
Пример #7
0
    def main(self):
        """Spin through the prep dataset purely for its caching side effect."""
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.prep_dl = DataLoader(
            LunaPrepcacheDataset(),
            batch_size=self.cli_args.batch_size,
            num_workers=self.cli_args.num_workers,
        )

        progress_iter = enumerateWithEstimate(
            self.prep_dl,
            "Stuffing cache",
            start_ndx=self.prep_dl.num_workers,
        )
        # Nothing to do per batch — materializing it warms the cache.
        for _batch_ndx, _batch_tup in progress_iter:
            pass
Пример #8
0
    def doTesting(self, epoch_ndx, test_dl):
        """Evaluate one epoch on the test set and return the metrics tensor."""
        with torch.no_grad():
            self.model.eval()
            metrics_t = torch.zeros(METRICS_SIZE, len(test_dl.dataset))

            progress_iter = enumerateWithEstimate(
                test_dl,
                "E{} Testing ".format(epoch_ndx),
                start_ndx=test_dl.num_workers,
            )
            for batch_ndx, batch_tup in progress_iter:
                # Both flavors fill metrics_t in place; only the loss differs.
                loss_fn = (self.computeSegmentationLoss
                           if self.cli_args.segmentation
                           else self.computeClassificationLoss)
                loss_fn(batch_ndx, batch_tup, test_dl.batch_size, metrics_t)

        return metrics_t
Пример #9
0
    def doTesting(self, epoch_ndx, test_dl):
        """Evaluate one epoch without gradients; metrics come back on the CPU."""
        with torch.no_grad():
            self.model.eval()
            metrics_devt = torch.zeros(
                METRICS_SIZE, len(test_dl.dataset)).to(self.device)

            progress_iter = enumerateWithEstimate(
                test_dl,
                "E{} Testing ".format(epoch_ndx),
                start_ndx=test_dl.num_workers,
            )
            for batch_ndx, batch_tup in progress_iter:
                self.computeBatchLoss(
                    batch_ndx, batch_tup, test_dl.batch_size, metrics_devt)

        return metrics_devt.to('cpu')
Пример #10
0
    def main(self):
        """Warm the on-disk cache by touching every sample (and CT) once."""
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.prep_dl = DataLoader(
            LunaClassificationDataset(sortby_str='series_uid', ),
            batch_size=self.cli_args.batch_size,
            num_workers=self.cli_args.num_workers,
        )

        progress_iter = enumerateWithEstimate(
            self.prep_dl,
            "Stuffing cache",
            start_ndx=self.prep_dl.num_workers,
        )
        for _batch_ndx, batch_tup in progress_iter:
            _nodule_tensor, _malignant_tensor, series_list, _center_list = batch_tup
            # Query each distinct CT so its size ends up cached too.
            for series_uid in sorted(set(series_list)):
                getCtSize(series_uid)
    def doTest(self, epoch_ndx, test_dl):
        """Score the test set without gradients and return the metrics on CPU.

        FIX: the progress label previously read "E{} Validation " even though
        this is the test loop; renamed to "E{} Testing " to match the doTesting
        variants elsewhere in this file.
        """
        with torch.no_grad():
            testMetrics_g = torch.zeros(
                METRICS_SIZE, len(test_dl.dataset), device=self.device
            )
            self.segmentation_model.eval()

            batch_iter = enumerateWithEstimate(
                test_dl,
                "E{} Testing ".format(epoch_ndx),
                start_ndx=test_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                # computeBatchLoss fills testMetrics_g in place.
                self.computeBatchLoss(
                    batch_ndx, batch_tup, test_dl.batch_size, testMetrics_g
                )

        return testMetrics_g.to("cpu")
Пример #12
0
    def main(self):
        """Screen every CT, collect its ratio per series, and print a histogram."""
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.prep_dl = DataLoader(
            LunaScreenCtDataset(),
            batch_size=self.cli_args.batch_size,
            num_workers=self.cli_args.num_workers,
        )

        series2ratio_dict = {}

        progress_iter = enumerateWithEstimate(
            self.prep_dl,
            "Screening CTs",
            start_ndx=self.prep_dl.num_workers,
        )
        for _batch_ndx, batch_tup in progress_iter:
            series_list, ratio_list = batch_tup
            # Pairwise merge: later occurrences of a uid overwrite earlier ones,
            # exactly as the one-at-a-time assignment would.
            series2ratio_dict.update(zip(series_list, ratio_list))

        prhist(list(series2ratio_dict.values()))
Пример #13
0
    def doTraining(self, epoch_ndx, train_dl):
        """Train one epoch; returns the per-sample metrics moved to the CPU."""
        self.model.train()
        metrics_devt = torch.zeros(
            METRICS_SIZE, len(train_dl.dataset)).to(self.device)

        progress_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in progress_iter:
            self.optimizer.zero_grad()
            loss_var = self.computeBatchLoss(
                batch_ndx, batch_tup, train_dl.batch_size, metrics_devt)
            loss_var.backward()
            self.optimizer.step()
            del loss_var

        self.totalTrainingSamples_count += metrics_devt.size(1)

        return metrics_devt.to('cpu')
Пример #14
0
    def doTraining(self, epoch_ndx, train_dl):
        """Run one training epoch and hand back the metrics tensor on the CPU."""
        self.model.train()
        epoch_metrics_g = torch.zeros(
            METRICS_SIZE,
            len(train_dl.dataset),
            device=self.device,
        )

        progress_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in progress_iter:
            self.optimizer.zero_grad()
            loss_var = self.computeBatchLoss(
                batch_ndx, batch_tup, train_dl.batch_size, epoch_metrics_g)
            loss_var.backward()
            self.optimizer.step()

        self.totalTrainingSamples_count += len(train_dl.dataset)

        return epoch_metrics_g.to('cpu')
Пример #15
0
    def main(self):
        """Full training driver: set up device/model/optimizer and data
        loaders, then run train + test loops for every epoch, logging metrics.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        self.totalTrainingSamples_count = 0

        self.model = LunaModel()
        if self.use_cuda:
            # Spread batches across every visible GPU when there is more than one.
            if torch.cuda.device_count() > 1:
                self.model = nn.DataParallel(self.model)

            self.model = self.model.to(self.device)

        self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)

        # Batch size is scaled by GPU count so each device sees the CLI value.
        train_dl = DataLoader(
            LunaDataset(
                test_stride=10,
                isTestSet_bool=False,
                ratio_int=int(self.cli_args.balanced),
            ),
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        test_dl = DataLoader(
            LunaDataset(
                test_stride=10,
                isTestSet_bool=True,
            ),
            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        for epoch_ndx in range(1, self.cli_args.epochs + 1):

            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(test_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            # Training loop, very similar to below
            self.model.train()
            trainingMetrics_tensor = torch.zeros(3, len(train_dl.dataset), 1)
            train_dl.dataset.shuffleSamples()
            batch_iter = enumerateWithEstimate(
                train_dl,
                "E{} Training".format(epoch_ndx),
                start_ndx=train_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.optimizer.zero_grad()
                loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
                loss_var.backward()
                self.optimizer.step()
                # Drop the graph reference promptly to free memory each batch.
                del loss_var

            # Testing loop, very similar to above, but simplified
            with torch.no_grad():
                self.model.eval()
                testingMetrics_tensor = torch.zeros(3, len(test_dl.dataset), 1)
                batch_iter = enumerateWithEstimate(
                    test_dl,
                    "E{} Testing ".format(epoch_ndx),
                    start_ndx=test_dl.num_workers,
                )
                for batch_ndx, batch_tup in batch_iter:
                    self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)

            self.logMetrics(epoch_ndx, trainingMetrics_tensor, testingMetrics_tensor)

        # Writers only exist if logMetrics created them; close both if so.
        if hasattr(self, 'trn_writer'):
            self.trn_writer.close()
            self.tst_writer.close()
    def main(self):
        """End-to-end inference: segment each series, cluster the output into
        nodule candidates, classify them, then keep the highest-probability
        diagnosis per series and log results for train and validation splits.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        val_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=True
        )

        val_set = set(
            noduleInfo_tup.series_uid
            for noduleInfo_tup in val_ds.noduleInfo_list
        )
        # Ground-truth malignant series, used later when logging results.
        malignant_set = set(
            noduleInfo_tup.series_uid
            for noduleInfo_tup in getNoduleInfoList()
            if noduleInfo_tup.isMalignant_bool
        )

        # Explicit series on the command line win; otherwise process everything.
        if self.cli_args.series_uid:
            series_set = set(self.cli_args.series_uid.split(','))
        else:
            series_set = set(
                noduleInfo_tup.series_uid
                for noduleInfo_tup in getNoduleInfoList()
            )

        train_list = sorted(series_set - val_set) if \
            self.cli_args.include_train else []
        val_list = sorted(series_set & val_set)

        noduleInfo_list = []
        series_iter = enumerateWithEstimate(
            val_list + train_list,
            'Series'
        )
        for _, series_uid, in series_iter:
            ct, _, _, clean_a = self.segmentCt(series_uid)

            noduleInfo_list += self.clusterSegmentationOutput(
                series_uid,
                ct,
                clean_a
            )

        cls_dl = self.initClassificationDl(noduleInfo_list)

        series2diagnosis_dict = {}
        batch_iter = enumerateWithEstimate(
            cls_dl,
            "Cls all",
            start_ndx=cls_dl.num_workers
        )
        for batch_ndx, batch_tup in batch_iter:
            input_t, _, series_list, center_list = batch_tup

            input_g = input_t.to(self.device)
            with torch.no_grad():
                _, probability_g = self.cls_model(input_g)

            # column 1 is assumed to be the positive-class probability —
            # TODO confirm against cls_model's output layout.
            classification_list = zip(
                series_list,
                center_list,
                probability_g[:, 1].to('cpu')
            )
            for cls_tup in classification_list:
                series_uid, center_irc, probability_t = cls_tup
                probability_float = probability_t.item()

                this_tup = (probability_float, tuple(center_irc))
                current_tup = series2diagnosis_dict.get(series_uid,
                                                        this_tup)
                try:
                    assert np.all(np.isfinite(tuple(center_irc)))
                    if this_tup > current_tup:
                        log.debug([series_uid, this_tup])
                    # This part is to cover the eventuality that
                    # the same series_uid is repeated multiple
                    # times: keep only the highest-probability tuple.
                    series2diagnosis_dict[series_uid] = \
                        max(this_tup, current_tup)
                # NOTE(review): bare except — it does re-raise, so nothing is
                # swallowed, but `except Exception:` would be the idiomatic form.
                except:
                    log.debug([(type(x), x) for x in this_tup] +
                              [(type(x), x) for x in current_tup])
                    raise

        log.info('Training set:')
        self.logResults('Training', train_list, series2diagnosis_dict,
                        malignant_set)

        log.info('Validation set:')
        self.logResults('Validation', val_list, series2diagnosis_dict,
                        malignant_set)
Пример #17
0
    def main(self):
        """Segment, cluster, and classify every series, recording the single
        best (highest-probability) diagnosis per series, then log results for
        the training and testing splits.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        test_ds = LunaDataset(
            test_stride=10,
            isTestSet_bool=True,
        )
        test_set = set(
            noduleInfo_tup.series_uid
            for noduleInfo_tup in test_ds.noduleInfo_list
        )
        # Ground-truth malignant series, consumed by logResults at the end.
        malignant_set = set(
            noduleInfo_tup.series_uid
            for noduleInfo_tup in getNoduleInfoList()
            if noduleInfo_tup.isMalignant_bool
        )

        if self.cli_args.series_uid:
            series_set = set(self.cli_args.series_uid.split(','))
        else:
            series_set = set(
                noduleInfo_tup.series_uid
                for noduleInfo_tup in getNoduleInfoList()
            )

        train_list = sorted(series_set - test_set) if self.cli_args.include_train else []
        test_list = sorted(series_set & test_set)


        noduleInfo_list = []
        series_iter = enumerateWithEstimate(
            test_list + train_list,
            "Series",
        )
        for _series_ndx, series_uid in series_iter:
            ct, output_ary, _mask_ary, clean_ary = self.segmentCt(series_uid)

            noduleInfo_list += self.clusterSegmentationOutput(
                series_uid,
                ct,
                clean_ary,
            )

            # if _series_ndx > 10:
            #     break


        cls_dl = self.initClassificationDl(noduleInfo_list)

        series2diagnosis_dict = {}
        batch_iter = enumerateWithEstimate(
            cls_dl,
            "Cls all",
            start_ndx=cls_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            input_tensor, _, series_list, center_list = batch_tup

            input_devtensor = input_tensor.to(self.device)
            with torch.no_grad():
                _logits_devtensor, probability_devtensor = self.cls_model(input_devtensor)

            # Column 1 is assumed to be the positive-class probability —
            # TODO confirm against cls_model's output layout.
            classifications_list = zip(
                series_list,
                center_list,
                probability_devtensor[:,1].to('cpu'),
            )

            # NOTE(review): "probablity" below is a misspelling of
            # "probability" (local names only, so behavior is unaffected).
            for cls_tup in classifications_list:
                series_uid, center_irc, probablity_tensor = cls_tup
                probablity_float = probablity_tensor.item()

                this_tup = (probablity_float, tuple(center_irc))
                current_tup = series2diagnosis_dict.get(series_uid, this_tup)
                try:
                    assert np.all(np.isfinite(tuple(center_irc)))
                    if this_tup > current_tup:
                        log.debug([series_uid, this_tup])
                    # Keep only the highest-probability tuple per series.
                    series2diagnosis_dict[series_uid] = max(this_tup, current_tup)
                # NOTE(review): bare except — it re-raises, so nothing is
                # swallowed, but `except Exception:` would be the idiomatic form.
                except:
                    log.debug([(type(x), x) for x in this_tup] + [(type(x), x) for x in current_tup])
                    raise

                # self.logResults(
                #     'Testing' if isTest_bool else 'Training',
                #     [(series_uid, series2diagnosis_dict[series_uid])],
                #     malignant_set,
                # )

        log.info('Training set:')
        self.logResults('Training', train_list, series2diagnosis_dict, malignant_set)

        log.info('Testing set:')
        self.logResults('Testing', test_list, series2diagnosis_dict, malignant_set)
Пример #18
0
    def main(self):
        """Training driver (CUDA-only variant): build data loaders and a
        DataParallel model, then run train + test loops for each epoch.

        NOTE(review): unconditionally calls .cuda() and scales batch size by
        torch.cuda.device_count() — this variant assumes a GPU is present.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
        self.train_dl = DataLoader(
            LunaDataset(
                test_stride=10,
                isTestSet_bool=False,
            ),
            batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
            num_workers=self.cli_args.num_workers,
            pin_memory=True,
        )
        self.test_dl = DataLoader(
            LunaDataset(
                test_stride=10,
                isTestSet_bool=True,
            ),
            batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
            num_workers=self.cli_args.num_workers,
            pin_memory=True,
        )

        self.model = LunaModel(self.cli_args.layers, 1, self.cli_args.channels)
        self.model = nn.DataParallel(self.model)
        self.model = self.model.cuda()

        self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)

        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(self.train_dl),
                len(self.test_dl),
                self.cli_args.batch_size,
                torch.cuda.device_count(),
            ))

            # Training loop, very similar to below
            self.model.train()
            batch_iter = enumerateWithEstimate(
                self.train_dl,
                "E{} Training".format(epoch_ndx),
                start_ndx=self.train_dl.num_workers,
            )
            # Metrics are accumulated into a plain numpy array in this variant.
            trainingMetrics_ary = np.zeros((3, len(self.train_dl.dataset)),
                                           dtype=np.float32)
            for batch_ndx, batch_tup in batch_iter:
                self.optimizer.zero_grad()
                loss_var = self.computeBatchLoss(batch_ndx, batch_tup,
                                                 self.train_dl.batch_size,
                                                 trainingMetrics_ary)
                loss_var.backward()
                self.optimizer.step()
                # Drop the graph reference promptly to free memory each batch.
                del loss_var

            # Testing loop, very similar to above, but simplified
            # NOTE(review): no torch.no_grad() around this eval loop —
            # presumably computeBatchLoss tolerates that; confirm.
            self.model.eval()
            batch_iter = enumerateWithEstimate(
                self.test_dl,
                "E{} Testing ".format(epoch_ndx),
                start_ndx=self.test_dl.num_workers,
            )
            testingMetrics_ary = np.zeros((3, len(self.test_dl.dataset)),
                                          dtype=np.float32)
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(batch_ndx, batch_tup,
                                      self.test_dl.batch_size,
                                      testingMetrics_ary)

            self.logMetrics(epoch_ndx, trainingMetrics_ary, testingMetrics_ary)
Пример #19
0
    def main(self):
        """Full-pipeline evaluation: segment each series, classify the
        segmented candidates, then match classifications against annotated
        nodules to tally per-series and total tp/tn/fp/fn and missed positives.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        val_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=True,
        )
        val_set = set(candidateInfo_tup.series_uid
                      for candidateInfo_tup in val_ds.candidateInfo_list)
        # NOTE(review): positive_set is computed but never used below.
        positive_set = set(candidateInfo_tup.series_uid
                           for candidateInfo_tup in getCandidateInfoList()
                           if candidateInfo_tup.isNodule_bool)

        if self.cli_args.series_uid:
            series_set = set(self.cli_args.series_uid.split(','))
        else:
            series_set = set(candidateInfo_tup.series_uid
                             for candidateInfo_tup in getCandidateInfoList())

        train_list = sorted(series_set -
                            val_set) if self.cli_args.include_train else []
        val_list = sorted(series_set & val_set)

        # Running totals across all series.
        total_tp = total_tn = total_fp = total_fn = 0
        total_missed_pos = 0
        missed_pos_dist_list = []
        missed_pos_cit_list = []
        candidateInfo_dict = getCandidateInfoDict()
        # series2results_dict = {}
        # seg_candidateInfo_list = []
        series_iter = enumerateWithEstimate(
            val_list + train_list,
            "Series",
        )
        for _series_ndx, series_uid in series_iter:
            ct, _output_g, _mask_g, clean_g = self.segmentCt(series_uid)

            seg_candidateInfo_list, _seg_centerIrc_list, _ = self.clusterSegmentationOutput(
                series_uid,
                ct,
                clean_g,
            )
            # Segmentation found nothing for this series; skip classification.
            if not seg_candidateInfo_list:
                continue

            cls_dl = self.initClassificationDl(seg_candidateInfo_list)
            results_list = []

            # batch_iter = enumerateWithEstimate(
            #     cls_dl,
            #     "Cls all",
            #     start_ndx=cls_dl.num_workers,
            # )
            # for batch_ndx, batch_tup in batch_iter:
            for batch_ndx, batch_tup in enumerate(cls_dl):
                input_t, label_t, index_t, series_list, center_t = batch_tup

                input_g = input_t.to(self.device)
                with torch.no_grad():
                    _logits_g, probability_g = self.cls_model(input_g)
                probability_t = probability_g.to('cpu')
                # probability_t = torch.tensor([[0, 1]] * input_t.shape[0], dtype=torch.float32)

                # Every sample in this loader should come from the current
                # series; the assert guards against loader mix-ups.
                for i, _series_uid in enumerate(series_list):
                    assert series_uid == _series_uid, repr([
                        batch_ndx, i, series_uid, _series_uid,
                        seg_candidateInfo_list
                    ])
                    results_list.append((center_t[i], probability_t[i,
                                                                    0].item()))

            # This part is all about matching up annotations with our segmentation results
            tp = tn = fp = fn = 0
            missed_pos = 0
            ct = getCt(series_uid)
            candidateInfo_list = candidateInfo_dict[series_uid]
            candidateInfo_list = [
                cit for cit in candidateInfo_list if cit.isNodule_bool
            ]

            # Parallel to results_list: which annotation (if any) each
            # classification result was matched with.
            found_cit_list = [None] * len(results_list)

            for candidateInfo_tup in candidateInfo_list:
                # (distance, result index) pair; 999 acts as +infinity sentinel.
                min_dist = (999, None)

                for result_ndx, (
                        result_center_irc_t,
                        nodule_probability_t) in enumerate(results_list):
                    result_center_xyz = irc2xyz(result_center_irc_t,
                                                ct.origin_xyz, ct.vxSize_xyz,
                                                ct.direction_a)
                    delta_xyz_t = torch.tensor(
                        result_center_xyz) - torch.tensor(
                            candidateInfo_tup.center_xyz)
                    distance_t = (delta_xyz_t**2).sum().sqrt()

                    min_dist = min(min_dist, (distance_t, result_ndx))

                # Larger nodules get a proportionally larger matching radius.
                distance_cutoff = max(10, candidateInfo_tup.diameter_mm / 2)
                if min_dist[0] < distance_cutoff:
                    found_dist, result_ndx = min_dist
                    nodule_probability_t = results_list[result_ndx][1]

                    assert candidateInfo_tup.isNodule_bool

                    if nodule_probability_t > 0.5:
                        tp += 1
                    else:
                        fn += 1

                    found_cit_list[result_ndx] = candidateInfo_tup

                else:
                    log.warning(
                        "!!! Missed positive {}; {} min dist !!!".format(
                            candidateInfo_tup, min_dist))
                    missed_pos += 1
                    missed_pos_dist_list.append(float(min_dist[0]))
                    missed_pos_cit_list.append(candidateInfo_tup)

            # # TODO remove
            # acceptable_set = {
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.100225287222365663678666836860',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.102681962408431413578140925249',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.195557219224169985110295082004',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.216252660192313507027754194207',
            #     # '1.3.6.1.4.1.14519.5.2.1.6279.6001.229096941293122177107846044795',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.229096941293122177107846044795',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.299806338046301317870803017534',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.395623571499047043765181005112',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.487745546557477250336016826588',
            #     '1.3.6.1.4.1.14519.5.2.1.6279.6001.970428941353693253759289796610',
            # }
            # if missed_pos > 0 and series_uid not in acceptable_set:
            #     log.info("Unacceptable series_uid: " + series_uid)
            #     break
            #
            # if total_missed_pos > 10:
            #     break
            #
            #
            # for result_ndx, (result_center_irc_t, nodule_probability_t) in enumerate(results_list):
            #     if found_cit_list[result_ndx] is None:
            #         if nodule_probability_t > 0.5:
            #             fp += 1
            #         else:
            #             tn += 1

            # NOTE(review): fp/tn stay 0 — the loop that counted them is
            # commented out above, so only tp/fn/missed_pos are meaningful.
            log.info("{}: {} missed pos, {} fn, {} fp, {} tp, {} tn".format(
                series_uid, missed_pos, fn, fp, tp, tn))
            total_tp += tp
            total_tn += tn
            total_fp += fp
            total_fn += fn
            total_missed_pos += missed_pos

        # Log SHA1 digests of the model files so results can be tied to
        # the exact weights that produced them.
        with open(self.cli_args.segmentation_path, 'rb') as f:
            log.info(self.cli_args.segmentation_path)
            log.info(hashlib.sha1(f.read()).hexdigest())
        with open(self.cli_args.classification_path, 'rb') as f:
            log.info(self.cli_args.classification_path)
            log.info(hashlib.sha1(f.read()).hexdigest())
        log.info("{}: {} missed pos, {} fn, {} fp, {} tp, {} tn".format(
            'total', total_missed_pos, total_fn, total_fp, total_tp, total_tn))
        # missed_pos_dist_list.sort()
        # log.info("missed_pos_dist_list {}".format(missed_pos_dist_list))
        for cit, dist in zip(missed_pos_cit_list, missed_pos_dist_list):
            log.info("    Missed by {}: {}".format(dist, cit))
Пример #20
0
    def main(self):
        """Train and evaluate the Luna nodule-classification model.

        Builds train/test DataLoaders over LunaDataset, wraps the model in
        nn.DataParallel across all visible GPUs, then runs
        ``cli_args.epochs`` rounds of training followed by evaluation,
        logging per-sample metrics for each epoch via ``logMetrics`` and
        the two TensorBoard writers.
        """
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
        self.train_dl = DataLoader(
            LunaDataset(
                # test_stride=10 presumably holds out every 10th sample as
                # the test split -- TODO confirm against LunaDataset.
                test_stride=10,
                isTestSet_bool=False,
                balanced_bool=self.cli_args.balanced,
                # Augmentation implies scaling, so either flag enables it.
                scaled_bool=self.cli_args.scaled or self.cli_args.augmented,
                augmented_bool=self.cli_args.augmented,
            ),
            # Scale the batch so each GPU under nn.DataParallel (below)
            # processes roughly cli_args.batch_size samples.
            batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
            num_workers=self.cli_args.num_workers,
            pin_memory=True,  # speeds up host-to-GPU copies
        )
        # Test loader: held-out stride; no balancing or augmentation.
        self.test_dl = DataLoader(
            LunaDataset(
                test_stride=10,
                isTestSet_bool=True,
                scaled_bool=self.cli_args.scaled or self.cli_args.augmented,
                # augmented_bool=self.cli_args.augmented,
            ),
            batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
            num_workers=self.cli_args.num_workers,
            pin_memory=True,
        )

        self.model = LunaModel(self.cli_args.layers, 1, self.cli_args.channels)
        self.model = nn.DataParallel(self.model)  # split each batch across GPUs
        self.model = self.model.cuda()

        self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)

        # One SummaryWriter per phase so train/test curves appear as
        # separate runs under the same timestamped directory.
        time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
        log_dir = os.path.join('runs', self.cli_args.tb_prefix, time_str)
        self.trn_writer = SummaryWriter(log_dir=log_dir + '_train')
        self.tst_writer = SummaryWriter(log_dir=log_dir + '_test')

        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(self.train_dl),
                len(self.test_dl),
                self.cli_args.batch_size,
                torch.cuda.device_count(),
            ))

            # Training loop, very similar to below
            self.model.train()
            self.train_dl.dataset.shuffleSamples()
            batch_iter = enumerateWithEstimate(
                self.train_dl,
                "E{} Training".format(epoch_ndx),
                # presumably skips the first num_workers batches when
                # estimating throughput (worker warm-up) -- TODO confirm
                # against enumerateWithEstimate.
                start_ndx=self.train_dl.num_workers,
            )
            # Per-sample metrics buffer filled in place by computeBatchLoss:
            # 3 rows, one column per training sample.
            trainingMetrics_ary = np.zeros((3, len(self.train_dl.dataset)),
                                           dtype=np.float32)
            for batch_ndx, batch_tup in batch_iter:
                self.optimizer.zero_grad()
                loss_var = self.computeBatchLoss(batch_ndx, batch_tup,
                                                 self.train_dl.batch_size,
                                                 trainingMetrics_ary)
                loss_var.backward()
                self.optimizer.step()
                # Drop the reference so the autograd graph can be freed
                # before the next batch loads.
                del loss_var

            # Testing loop, very similar to above, but simplified
            # ...
            self.model.eval()
            # NOTE(review): the *test* set is re-shuffled each epoch; verify
            # this is intentional (per-sample metrics are indexed by
            # batch position).
            self.test_dl.dataset.shuffleSamples()
            batch_iter = enumerateWithEstimate(
                self.test_dl,
                "E{} Testing ".format(epoch_ndx),
                start_ndx=self.test_dl.num_workers,
            )
            testingMetrics_ary = np.zeros((3, len(self.test_dl.dataset)),
                                          dtype=np.float32)
            # Evaluation needs no gradients: torch.no_grad() skips building
            # the autograd graph, cutting GPU memory use and compute.
            with torch.no_grad():
                for batch_ndx, batch_tup in batch_iter:
                    # Called only for its side effect of filling
                    # testingMetrics_ary; no backward/step in this phase.
                    self.computeBatchLoss(batch_ndx, batch_tup,
                                          self.test_dl.batch_size,
                                          testingMetrics_ary)

            self.logMetrics(epoch_ndx, trainingMetrics_ary, testingMetrics_ary)

        self.trn_writer.close()
        self.tst_writer.close()