def validate(self, batch):
    with torch.set_grad_enabled(False):
        self.model.eval()

        images = batch['image']
        labels = batch['label']
        if self.use_cuda:
            images = images.cuda()
            labels = labels.cuda()

        logits = self.model(images)
        probas = F.softmax(logits, dim=1)

        # sharpen the output: boost foreground classes whose probability
        # already exceeds the threshold so that argmax prefers them
        for i in range(1, probas.shape[1]):
            probas[:, i, ...] += (probas[:, i, ...] >= self.threshold).float()

        # convert probabilities into one-hot: [B, C, ...]
        max_idx = torch.argmax(probas, 1, keepdim=True)
        one_hot = torch.zeros(probas.shape)
        if self.use_cuda:
            one_hot = one_hot.cuda()
        one_hot.scatter_(1, max_idx, 1)

        match, total = match_up(
            one_hot,
            labels,
            needs_softmax=False,
            batch_wise=True,
        )
        result = {'match': match, 'total': total}
        if self.save_output:
            result['output'] = one_hot

    return result
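# `match_up` is a project-specific helper whose implementation is not shown here.
# The sketch below is only an assumption of what a Dice-style match/total counter
# could look like, inferred from how `match` and `total` are consumed in this file
# (per-class overlap counts and voxel totals, later combined as 2 * match / total).
# The real helper also takes a `threshold` argument (passed as -1, i.e. disabled,
# in the calls below), which this sketch omits.
import torch
import torch.nn.functional as F


def match_up_sketch(prediction, label, needs_softmax=True, batch_wise=True):
    """Hypothetical stand-in for match_up; assumes [B, C, spatial...] inputs."""
    probas = F.softmax(prediction, dim=1) if needs_softmax else prediction
    n_classes = probas.shape[1]

    # hard prediction as one-hot [B, C, spatial...]
    pred = torch.argmax(probas, dim=1)
    pred_onehot = F.one_hot(pred, n_classes).movedim(-1, 1).float()

    # accept either one-hot or index-encoded labels
    if label.dim() == pred_onehot.dim():
        label_onehot = label.float()
    else:
        label_onehot = F.one_hot(label.long(), n_classes).movedim(-1, 1).float()

    # per-class overlap and voxel totals, keeping the batch axis if requested
    match = (pred_onehot * label_onehot).flatten(2).sum(-1)
    total = pred_onehot.flatten(2).sum(-1) + label_onehot.flatten(2).sum(-1)
    if not batch_wise:
        match, total = match.sum(0), total.sum(0)
    return match, total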
def infer(self, data, include_prediction=False, compute_match=False):
    # XXX: assume mode is adv
    mode = 'adv'

    with torch.set_grad_enabled(False):
        self.models['seg'].eval()
        data.update(self.models['seg'](data))

        # post-process the segmentation logits with softmax
        probas = torch.softmax(data['prediction'], dim=1)

        # the discriminator produces a confidence map
        self.models['dis'].eval()
        dis_inputs = {'label': probas}
        if self.dis_inlcude_image:
            dis_inputs.update({'image': data['image']})
        data.update(self.models['dis'](dis_inputs))

        # evaluate the performance of the segmentation network
        results = {}
        loss = None
        accu = None
        for key in self.training_rules[mode]:
            result = self.meters[key](data)
            if loss is None:
                loss = result.pop('loss')
            else:
                loss = loss + result.pop('loss')

            # handle duplicated accuracy entries: keep the first one as the
            # main 'accu' and log the rest under their own keys
            if 'accu' in result:
                if accu is None:
                    accu = result.pop('accu')
                else:
                    results.update(
                        {'%s_accu' % key: torch.mean(result.pop('accu'))})
            results.update(result)

        results.update({'loss': loss})

        # NOTE: improve me
        assert accu is not None
        results.update({'accu': accu})

        # evaluate the performance of the discriminator
        results.update(self._dis_run(data, training=False))

        if include_prediction:
            results.update({'prediction': probas})

        if compute_match:
            match, total = match_up(
                data['prediction'],
                data['label'],
                needs_softmax=True,
                batch_wise=True,
                threshold=-1,
            )
            results.update({'match': match, 'total': total})

    return results
def _compute_match(self, data, results):
    with torch.set_grad_enabled(False):
        match, total = match_up(
            data['prediction'],
            data['label'],
            needs_softmax=True,
            batch_wise=True,
            threshold=-1,
        )
        results.update({'match': match, 'total': total})
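# `_compute_match` only stores the raw counts; turning them into a score happens
# elsewhere in the project (e.g. via the Reverter used in `run`). A minimal sketch,
# assuming the counts follow the 2 * |A ∩ B| / (|A| + |B|) Dice convention implied
# by the match/total pairs above:
import numpy as np


def dice_from_counts(match, total, eps=1e-8):
    """Hypothetical helper: per-class Dice from accumulated match/total counts."""
    match = np.asarray(match, dtype=np.float64)
    total = np.asarray(total, dtype=np.float64)
    if match.ndim > 1:
        # sum over the batch axis before forming the ratio
        match, total = match.sum(axis=0), total.sum(axis=0)
    return 2.0 * match / (total + eps)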
def run(self, stage):
    stage_config = self.config['stage'][stage]

    # build the data flow from the given data generator
    # single data flow
    if isinstance(stage_config['generator'], str):
        data_gen = self.generators[stage_config['generator']]
        class_names = data_gen.struct['DL'].ROIs
        n_steps = len(data_gen)
        gen_tags = None

    # multiple data flows
    elif isinstance(stage_config['generator'], dict):
        gens = [
            self.generators[cfg]
            for cfg in stage_config['generator'].values()
        ]
        data_gen = zip(*gens)
        class_names = gens[0].struct['DL'].ROIs
        n_steps = min([len(g) for g in gens])
        gen_tags = list(stage_config['generator'].keys())

        # the forward config should match the multiple data flows
        assert isinstance(stage_config['forward'], dict)
        assert gen_tags == list(stage_config['forward'].keys())

    else:
        raise TypeError('generator of type %s is not supported.'
                        % type(stage_config['generator']))

    progress_bar = tqdm(
        data_gen,
        total=n_steps,
        ncols=get_tty_columns(),
        dynamic_ncols=True,
        desc='[%s] loss: %.5f, accu: %.5f' % (stage, 0.0, 0.0),
    )

    if stage not in self.step:
        self.step[stage] = 1

    # toggle the trainable parameters of each module
    need_backward = False
    for key, toggle in stage_config['toggle'].items():
        self.handlers[key].model.train(toggle)
        for param in self.handlers[key].model.parameters():
            param.requires_grad = toggle
        if toggle:
            need_backward = True

    result_list = []
    need_revert = 'revert' in stage_config and stage_config['revert']

    for batch in progress_bar:
        self.step[stage] += 1

        # single data flow
        if gen_tags is None:
            assert isinstance(batch, dict)

            # insert the batch into data, moving tensors onto the GPU if available
            data = dict()
            for key in batch:
                if torch.cuda.device_count() >= 1:
                    data[key] = batch[key].cuda()
                else:
                    data[key] = batch[key]

            # forward
            for key in stage_config['forward']:
                data.update(self.handlers[key].model(data))

        # multiple data flows
        else:
            assert isinstance(batch, tuple)

            data = dict()
            for (tag, tag_batch) in zip(gen_tags, batch):
                tag_data = dict()

                # insert the tagged batch, moving tensors onto the GPU if available
                for key in tag_batch:
                    if torch.cuda.device_count() >= 1:
                        tag_data[key] = tag_batch[key].cuda()
                    else:
                        tag_data[key] = tag_batch[key]

                # forward
                for key in stage_config['forward'][tag]:
                    tag_data.update(self.handlers[key].model(tag_data))

                # insert the tagged data back into the shared data dict
                data.update({
                    '%s_%s' % (key, tag): tag_data[key]
                    for key in tag_data
                })

        # compute loss and accuracy
        results = self.metrics[stage_config['metric']](data)

        # backpropagation
        if need_backward:
            results['loss'].backward()
            for key, toggle in stage_config['toggle'].items():
                if toggle:
                    self.optims[key].step()
                    self.optims[key].zero_grad()

        # compute the match counts for the per-case dice score after reversion
        if need_revert:
            assert 'prediction' in data, list(data.keys())
            assert 'label' in data, list(data.keys())
            with torch.set_grad_enabled(False):
                match, total = match_up(
                    data['prediction'],
                    data['label'],
                    needs_softmax=True,
                    batch_wise=True,
                    threshold=-1,
                )
                results.update({'match': match, 'total': total})

        # detach all results, move them to the CPU, and convert to numpy
        for key in results:
            results[key] = results[key].detach().cpu().numpy()

        # average the accuracy if it is multi-dimensional
        assert 'accu' in results
        if results['accu'].ndim == 0:
            step_accu = math.nan if np.isnan(results['accu']) \
                else float(results['accu'])
        else:
            assert results['accu'].ndim == 1
            step_accu = math.nan if np.all(np.isnan(results['accu'])) \
                else np.nanmean(results['accu'])

        assert 'loss' in results
        progress_bar.set_description(
            '[%s] loss: %.5f, accu: %.5f'
            % (stage, results['loss'], step_accu))

        if self.logger is not None:
            self.logger.add_scalar(
                '%s/step/loss' % stage,
                results['loss'],
                self.step[stage])
            self.logger.add_scalar(
                '%s/step/accu' % stage,
                -1 if math.isnan(step_accu) else step_accu,
                self.step[stage])

        result_list.append(results)

    summary = dict()

    if need_revert:
        reverter = Reverter(data_gen)
        result_collection_blacklist = reverter.revertible

        scores = dict()
        progress_bar = tqdm(
            reverter.on_batches(result_list),
            total=len(reverter.data_list),
            dynamic_ncols=True,
            ncols=get_tty_columns(),
            desc='[Data index]',
        )
        for reverted in progress_bar:
            data_idx = reverted['idx']
            scores[data_idx] = reverted['score']
            info = '[%s] mean score: %.3f' % (
                data_idx, np.mean(list(scores[data_idx].values())))
            progress_bar.set_description(info)

        # summarize the score of each class over data indices
        cls_scores = {
            cls: np.mean([scores[data_idx][cls] for data_idx in scores])
            for cls in class_names
        }
        cls_scores.update(
            {'mean': np.mean([cls_scores[cls] for cls in class_names])})

        summary['scores'] = scores
        summary['cls_scores'] = cls_scores

    else:
        result_collection_blacklist = []

    # collect the results that are not revertible, e.g. accu and loss
    summary.update({
        key: np.nanmean(
            np.vstack([result[key] for result in result_list]),
            axis=0)
        for key in result_list[0].keys()
        if key not in result_collection_blacklist
    })

    # convert a 1D accu array into a dictionary of per-class scores
    if len(summary['accu']) > 1:
        assert len(summary['accu']) == len(class_names), \
            (len(summary['accu']), len(class_names))
        summary['cls_accu'] = {
            cls: summary['accu'][i]
            for (i, cls) in enumerate(class_names)
        }
        summary['accu'] = summary['accu'].mean()

    # print the summary
    print('Average: ' + ', '.join([
        '%s: %.3f' % (key, val)
        for (key, val) in summary.items()
        if not isinstance(val, dict)
    ]))
    if 'cls_scores' in summary:
        print('Class score: ' + ', '.join([
            '%s: %.3f' % (key, val)
            for (key, val) in summary['cls_scores'].items()
        ]))

    return summary
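# `run` only logs step-level scalars; what happens to the returned summary is not
# shown here. A minimal sketch (assumed, not part of the original) of flattening
# the epoch-level summary into the same SummaryWriter-style logger, skipping
# non-scalar entries such as the nested 'scores' dictionary:
import numpy as np


def log_summary(logger, stage, summary, epoch):
    """Hypothetical helper: push a run() summary to a SummaryWriter-like logger."""

    def as_scalar(x):
        # accept python numbers, numpy scalars, and size-1 arrays; otherwise None
        if np.isscalar(x):
            return float(x)
        if hasattr(x, 'ndim') and np.size(x) == 1:
            return float(np.asarray(x).ravel()[0])
        return None

    for key, val in summary.items():
        if isinstance(val, dict):
            # flatten one level, e.g. '<stage>/epoch/cls_scores/<class name>'
            for sub_key, sub_val in val.items():
                scalar = as_scalar(sub_val)
                if scalar is not None:
                    logger.add_scalar('%s/epoch/%s/%s' % (stage, key, sub_key),
                                      scalar, epoch)
        else:
            scalar = as_scalar(val)
            if scalar is not None:
                logger.add_scalar('%s/epoch/%s' % (stage, key), scalar, epoch)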