Example 1
    def validate(self, batch):
        with torch.set_grad_enabled(False):
            self.model.eval()

            images = batch['image']
            labels = batch['label']
            if self.use_cuda:
                images = images.cuda()
                labels = labels.cuda()
            logits = self.model(images)
            probas = F.softmax(logits, dim=1)

            # boost each foreground class whose probability reaches the
            # threshold, so that it wins the argmax below
            for i in range(1, probas.shape[1]):
                probas[:, i, ...] += (probas[:, i, ...] >= self.threshold).float()

            # convert probabilities into one-hot: [B, C, ...]
            max_idx = torch.argmax(probas, 1, keepdim=True)
            one_hot = torch.zeros_like(probas)  # same shape, dtype, and device
            one_hot.scatter_(1, max_idx, 1)

            match, total = match_up(one_hot,
                                    labels,
                                    needs_softmax=False,
                                    batch_wise=True)
            result = {'match': match, 'total': total}
            if self.save_output:
                result['output'] = one_hot
            return result
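
The argmax-plus-scatter_ pattern above is a standard way to turn per-class probabilities into a one-hot map in PyTorch. A minimal, self-contained sketch with made-up shapes:

import torch

# fake class probabilities: batch of 2, 3 classes, a 4x4 map
probas = torch.rand(2, 3, 4, 4)

# channel index of the most probable class, keeping the channel dim
max_idx = torch.argmax(probas, dim=1, keepdim=True)  # [2, 1, 4, 4]

# write a 1 into each winning channel
one_hot = torch.zeros_like(probas)
one_hot.scatter_(1, max_idx, 1)

# exactly one active class per position
assert bool((one_hot.sum(dim=1) == 1).all())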
Example 2
    def infer(self, data, include_prediction=False, compute_match=False):
        # XXX: assume mode is adv
        mode = 'adv'

        with torch.set_grad_enabled(False):
            self.models['seg'].eval()
            data.update(self.models['seg'](data))

            # post-process with softmax
            probas = torch.softmax(data['prediction'], dim=1)

            # the discriminator produces a confidence map
            self.models['dis'].eval()
            dis_inputs = {'label': probas}
            if self.dis_include_image:
                dis_inputs.update({'image': data['image']})
            data.update(self.models['dis'](dis_inputs))

            # evaluate the performance of seg
            results = {}
            loss = None
            accu = None
            for key in self.training_rules[mode]:
                result = self.meters[key](data)
                if loss is None:
                    loss = result.pop('loss')
                else:
                    loss = loss + result.pop('loss')

                # handle duplicate accuracy entries: keep the first 'accu'
                # as-is and store later ones under a per-meter key
                if 'accu' in result:
                    if accu is None:
                        accu = result.pop('accu')
                    else:
                        results.update(
                            {'%s_accu' % key: torch.mean(result.pop('accu'))})
                results.update(result)
            results.update({'loss': loss})

            # NOTE: improve me
            assert accu is not None
            results.update({'accu': accu})

            # evaluate the performance of dis
            results.update(self._dis_run(data, training=False))

            if include_prediction:
                results.update({'prediction': probas})

            if compute_match:
                match, total = match_up(
                    data['prediction'],
                    data['label'],
                    needs_softmax=True,
                    batch_wise=True,
                    threshold=-1,
                )
                results.update({'match': match, 'total': total})

        return results
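
The meter loop above sums every meter's 'loss' into a single scalar while the remaining metrics are collected side by side. A stripped-down sketch of that accumulation pattern, with hypothetical meter callables standing in for self.meters:

import torch

# hypothetical meters: each maps a data dict to a result dict with a 'loss'
meters = {
    'dice': lambda data: {'loss': torch.tensor(0.3),
                          'accu': torch.tensor([0.8, 0.9])},
    'adv': lambda data: {'loss': torch.tensor(0.1)},
}

data, results, loss = {}, {}, None
for key, meter in meters.items():
    result = meter(data)
    # pop the loss so it is summed across meters instead of overwritten
    loss = result.pop('loss') if loss is None else loss + result.pop('loss')
    results.update(result)
results['loss'] = loss  # tensor(0.4000)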
Example 3
    def _compute_match(self, data, results):
        with torch.set_grad_enabled(False):
            match, total = match_up(
                data['prediction'],
                data['label'],
                needs_softmax=True,
                batch_wise=True,
                threshold=-1,
            )
            results.update({'match': match, 'total': total})
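
match_up itself is not shown in these examples, but its call sites (needs_softmax, batch_wise, threshold=-1) suggest it returns per-class intersection counts and set sizes from which a Dice score follows. A purely speculative sketch of such a helper, not the original implementation:

import torch

def match_up_sketch(prediction, label, needs_softmax=True):
    # hypothetical stand-in for match_up; assumes one-hot labels [B, C, ...]
    if needs_softmax:
        prediction = torch.softmax(prediction, dim=1)
    one_hot = torch.zeros_like(prediction)
    one_hot.scatter_(1, prediction.argmax(dim=1, keepdim=True), 1)
    spatial = tuple(range(2, prediction.ndim))
    match = (one_hot * label).sum(dim=spatial)                 # [B, C]
    total = one_hot.sum(dim=spatial) + label.sum(dim=spatial)  # [B, C]
    return match, total

Per-class Dice would then be 2 * match / total, typically with match and total summed over the batch first.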
Example 4
    def run(self, stage):
        stage_config = self.config['stage'][stage]

        # build data flow from the given data generator
        # single data flow
        if isinstance(stage_config['generator'], str):
            data_gen = self.generators[stage_config['generator']]
            class_names = data_gen.struct['DL'].ROIs
            n_steps = len(data_gen)
            gen_tags = None

        # multiple data flows
        elif isinstance(stage_config['generator'], dict):
            gens = [
                self.generators[cfg]
                for cfg in stage_config['generator'].values()
            ]
            data_gen = zip(*gens)
            class_names = gens[0].struct['DL'].ROIs
            n_steps = min([len(g) for g in gens])
            gen_tags = list(stage_config['generator'].keys())

            # the forward config should match the multiple data flows
            assert isinstance(stage_config['forward'], dict)
            assert gen_tags == list(stage_config['forward'].keys())

        else:
            raise TypeError('generator of type %s is not supported.' %
                            type(stage_config['generator']))

        progress_bar = tqdm(data_gen,
                            total=n_steps,
                            ncols=get_tty_columns(),
                            dynamic_ncols=True,
                            desc='[%s] loss: %.5f, accu: %.5f' %
                            (stage, 0.0, 0.0))

        if stage not in self.step:
            self.step[stage] = 1

        # toggle trainable parameters of each module
        need_backward = False
        for key, toggle in stage_config['toggle'].items():
            self.handlers[key].model.train(toggle)
            for param in self.handlers[key].model.parameters():
                param.requires_grad = toggle
            if toggle:
                need_backward = True

        result_list = []
        need_revert = 'revert' in stage_config and stage_config['revert']
        for batch in progress_bar:

            self.step[stage] += 1

            # single data flow
            if gen_tags is None:
                assert isinstance(batch, dict)

                # insert batch to data
                data = dict()
                for key in batch:
                    if torch.cuda.is_available():
                        data[key] = batch[key].cuda()
                    else:
                        data[key] = batch[key]

                # forward
                for key in stage_config['forward']:
                    data.update(self.handlers[key].model(data))

            # multiple data flows
            else:
                assert isinstance(batch, tuple)
                data = dict()
                for (tag, tag_batch) in zip(gen_tags, batch):
                    tag_data = dict()

                    # insert batch to data
                    for key in tag_batch:
                        if torch.cuda.is_available():
                            tag_data[key] = tag_batch[key].cuda()
                        else:
                            tag_data[key] = tag_batch[key]

                    # forward
                    for key in stage_config['forward'][tag]:
                        tag_data.update(self.handlers[key].model(tag_data))

                    # insert tag data back to the data
                    data.update({
                        '%s_%s' % (key, tag): tag_data[key]
                        for key in tag_data
                    })

            # compute loss and accuracy
            results = self.metrics[stage_config['metric']](data)

            # backpropagation
            if need_backward:
                results['loss'].backward()
                for key, toggle in stage_config['toggle'].items():
                    if toggle:
                        self.optims[key].step()
                        self.optims[key].zero_grad()

            # compute match counts for the per-case Dice score after reversion
            if need_revert:
                assert 'prediction' in data, list(data.keys())
                assert 'label' in data, list(data.keys())
                with torch.set_grad_enabled(False):
                    match, total = match_up(
                        data['prediction'],
                        data['label'],
                        needs_softmax=True,
                        batch_wise=True,
                        threshold=-1,
                    )
                    results.update({'match': match, 'total': total})

            # detach all results, move to CPU, and convert to numpy
            for key in results:
                results[key] = results[key].detach().cpu().numpy()

            # average accuracy if multi-dim
            assert 'accu' in results
            if results['accu'].ndim == 0:
                # NaN != NaN, so test with np.isnan rather than ==
                step_accu = math.nan if np.isnan(
                    results['accu']) else float(results['accu'])
            else:
                assert results['accu'].ndim == 1
                empty = bool(np.all(np.isnan(results['accu'])))
                step_accu = math.nan if empty else np.nanmean(results['accu'])

            assert 'loss' in results
            progress_bar.set_description('[%s] loss: %.5f, accu: %.5f' %
                                         (stage, results['loss'], step_accu))

            if self.logger is not None:
                self.logger.add_scalar('%s/step/loss' % stage, results['loss'],
                                       self.step[stage])
                self.logger.add_scalar(
                    '%s/step/accu' % stage,
                    -1 if math.isnan(step_accu) else step_accu,
                    self.step[stage])

            result_list.append(results)

        summary = dict()
        if need_revert:
            reverter = Reverter(data_gen)
            result_collection_blacklist = reverter.revertible

            scores = dict()
            progress_bar = tqdm(reverter.on_batches(result_list),
                                total=len(reverter.data_list),
                                dynamic_ncols=True,
                                ncols=get_tty_columns(),
                                desc='[Data index]')
            for reverted in progress_bar:
                data_idx = reverted['idx']
                scores[data_idx] = reverted['score']
                info = '[%s] mean score: %.3f' % (
                    data_idx, np.mean(list(scores[data_idx].values())))
                progress_bar.set_description(info)

            # summarize the score of each class over data indices
            cls_scores = {
                cls: np.mean([scores[data_idx][cls] for data_idx in scores])
                for cls in class_names
            }
            cls_scores.update(
                {'mean': np.mean([cls_scores[cls] for cls in class_names])})

            summary['scores'] = scores
            summary['cls_scores'] = cls_scores

        else:
            result_collection_blacklist = []

        # collect the remaining, non-revertible results, e.g. accu and loss
        summary.update({
            key: np.nanmean(np.vstack([result[key] for result in result_list]),
                            axis=0)
            for key in result_list[0].keys()
            if key not in result_collection_blacklist
        })

        # expand a 1-D accu array into a per-class score dictionary
        if len(summary['accu']) > 1:
            assert len(summary['accu']) == len(class_names), (len(
                summary['accu']), len(class_names))
            summary['cls_accu'] = {
                cls: summary['accu'][i]
                for (i, cls) in enumerate(class_names)
            }
            summary['accu'] = summary['accu'].mean()

        # print summary info
        print('Average: ' + ', '.join([
            '%s: %.3f' % (key, val)
            for (key, val) in summary.items() if not isinstance(val, dict)
        ]))

        if 'cls_scores' in summary:
            print('Class score: ' + ', '.join([
                '%s: %.3f' % (key, val)
                for (key, val) in summary['cls_scores'].items()
            ]))

        return summary
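
Once epoch-level match and total arrays are available in the summary, per-class Dice scores follow directly. A short sketch, assuming hypothetical [C]-shaped arrays already summed over the epoch:

import numpy as np

# hypothetical per-class intersection counts and combined set sizes
match = np.array([1500.0, 320.0, 80.0])
total = np.array([3200.0, 700.0, 200.0])

# Dice = 2 * |A ∩ B| / (|A| + |B|); empty classes become NaN
with np.errstate(divide='ignore', invalid='ignore'):
    dice = np.where(total > 0, 2.0 * match / total, np.nan)

print({'class_%d' % i: round(float(d), 3) for i, d in enumerate(dice)})
# {'class_0': 0.938, 'class_1': 0.914, 'class_2': 0.8}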