def save_ouputs(self, data):
    """Save a batch of segmentation predictions and, optionally, probabilities.

    Args:
        data: dict with ``'names'`` (one relative image path per sample) and
            ``'predict'`` (raw network outputs for the batch). Softmax and
            argmax are taken along axis 1, so axis 1 is treated as the class
            axis.

    For each sample the argmax label map is written to
    ``config['testing']['output_dir']``. The original image at
    ``root_dir + '/' + name`` is passed as the reference. When
    ``save_probability`` is set, one probability map per class is written as
    ``<prefix>_prob_<c>.<ext>``.
    """
    testing_cfg = self.config['testing']
    output_dir = testing_cfg['output_dir']
    ignore_dir = testing_cfg.get('filename_ignore_dir', True)
    save_prob = testing_cfg.get('save_probability', False)
    label_source = testing_cfg.get('label_source', None)
    label_target = testing_cfg.get('label_target', None)
    filename_replace_source = testing_cfg.get('filename_replace_source', None)
    filename_replace_target = testing_cfg.get('filename_replace_target', None)
    # makedirs also creates missing parent directories and is a no-op when
    # the directory already exists.
    os.makedirs(output_dir, exist_ok=True)

    names, pred = data['names'], data['predict']
    prob = scipy.special.softmax(pred, axis=1)
    output = np.asarray(np.argmax(prob, axis=1), np.uint8)
    if label_source is not None and label_target is not None:
        output = convert_label(output, label_source, label_target)

    # Save the output and (optionally) the probability predictions.
    root_dir = self.config['dataset']['root_dir']
    do_replace = (filename_replace_source is not None and
                  filename_replace_target is not None)
    for i, name in enumerate(names):
        # Either keep only the file name, or flatten the relative path into it.
        file_name = name.split('/')[-1] if ignore_dir else name.replace('/', '_')
        if do_replace:
            file_name = file_name.replace(filename_replace_source,
                                          filename_replace_target)
        print(file_name)
        save_name = "{0:}/{1:}".format(output_dir, file_name)
        reference_name = root_dir + '/' + name
        save_nd_array_as_image(output[i], save_name, reference_name)
        if not save_prob:
            continue

        # Split the extension off the file name only (not the full path), so
        # dots in output_dir cannot corrupt the prefix. '.nii.gz' is treated
        # as one extension.
        if file_name.endswith('.nii.gz'):
            stem, save_format = file_name[:-len('.nii.gz')], 'nii.gz'
        else:
            stem, _, save_format = file_name.rpartition('.')
        save_prefix = "{0:}/{1:}".format(output_dir, stem)

        for c in range(prob.shape[1]):
            temp_prob = prob[i][c]
            prob_save_name = "{0:}_prob_{1:}.{2:}".format(
                save_prefix, c, save_format)
            if len(temp_prob.shape) == 2:
                # 2D probability maps are scaled to [0, 255] and stored as uint8.
                temp_prob = np.asarray(temp_prob * 255, np.uint8)
            save_nd_array_as_image(temp_prob, prob_save_name, reference_name)
def __infer(self):
    """Run the legacy inference loop over ``self.test_loder`` and save results.

    For every test case the prediction is computed with ``volume_infer`` on
    double-precision images. Inverse transforms are then applied in reverse
    order, and the argmax label map is written to
    ``config['testing']['output_dir']``. When ``save_probability`` is set, one
    softmax probability map per class is written too. The average wall time
    per case is printed at the end.
    """
    testing_cfg = self.config['testing']
    device = torch.device(testing_cfg['device_name'])
    self.net.to(device)
    # map_location restores the checkpoint onto the configured device, so a
    # checkpoint saved on GPU can also be loaded on a CPU-only machine.
    self.checkpoint = torch.load(testing_cfg['checkpoint_name'],
                                 map_location=device)
    self.net.load_state_dict(self.checkpoint['model_state_dict'])
    # NOTE(review): the network is kept in *training* mode here, which affects
    # layers such as dropout/BatchNorm if the model has them. Left unchanged to
    # preserve results; confirm whether eval() was intended.
    self.net.train()

    output_dir = testing_cfg['output_dir']
    save_probability = testing_cfg.get('save_probability', False)
    label_source = testing_cfg.get('label_source', None)
    label_target = testing_cfg.get('label_target', None)
    class_num = self.config['network']['class_num']
    mini_batch_size = testing_cfg['mini_batch_size']
    mini_patch_inshape = testing_cfg['mini_patch_shape']
    os.makedirs(output_dir, exist_ok=True)
    root_dir = self.config['dataset']['root_dir']

    with torch.no_grad():
        # Automatically infer the mini-patch output shape with one probe
        # forward pass. The random input is float64, like the `.double()`
        # images below.
        mini_patch_outshape = None
        if mini_patch_inshape is not None:
            patch_inshape = [1, self.config['dataset']['modal_num']
                             ] + mini_patch_inshape
            testx = torch.from_numpy(np.random.random(patch_inshape)).to(device)
            testy = self.net(testx)
            mini_patch_outshape = tuple(testy.shape[2:])
            print('mini patch in shape', mini_patch_inshape)
            print('mini patch out shape', mini_patch_outshape)

        start_time = time.time()
        for data in self.test_loder:
            images = data['image'].double()
            names = data['names']
            print(names[0])
            data['predict'] = volume_infer(images, self.net, device, class_num,
                                           mini_batch_size, mini_patch_inshape,
                                           mini_patch_outshape)
            # Undo invertible preprocessing transforms, last one first.
            for transform in reversed(self.transform_list):
                if transform.inverse:
                    data = transform.inverse_transform_for_prediction(data)

            output = np.asarray(np.argmax(data['predict'][0], axis=0), np.uint8)
            if label_source is not None and label_target is not None:
                output = convert_label(output, label_source, label_target)

            # Save the output and (optionally) the probability predictions.
            file_name = names[0].split('/')[-1]
            save_name = "{0:}/{1:}".format(output_dir, file_name)
            reference_name = root_dir + '/' + names[0]
            save_nd_array_as_image(output, save_name, reference_name)
            if save_probability:
                # Split the extension off the file name only; '.nii.gz' is
                # treated as one extension.
                if file_name.endswith('.nii.gz'):
                    stem, save_format = file_name[:-len('.nii.gz')], 'nii.gz'
                else:
                    stem, _, save_format = file_name.rpartition('.')
                save_prefix = "{0:}/{1:}".format(output_dir, stem)
                # A separate name keeps `class_num` (passed to volume_infer)
                # from being overwritten for later cases.
                prob = scipy.special.softmax(data['predict'][0], axis=0)
                for c in range(prob.shape[0]):
                    prob_save_name = "{0:}_prob_{1:}.{2:}".format(
                        save_prefix, c, save_format)
                    save_nd_array_as_image(prob[c], prob_save_name,
                                           reference_name)
        elapsed = time.time() - start_time

    case_num = len(self.test_loder)
    if case_num > 0:  # avoid ZeroDivisionError on an empty loader
        avg_time = elapsed / case_num
        print("average testing time {0:}".format(avg_time))
def infer(self):
    """Run patch-based inference over ``self.test_loder`` and save results.

    Supports networks returning several predictions (``output_num``). With
    ``multi_pred_avg`` the per-prediction softmax maps are averaged, and
    their variance can be saved via ``save_multi_pred_var``. For every case
    the argmax label map is saved, plus optional per-class probability maps.
    The mean and standard deviation of per-case inference time are printed
    at the end.

    Raises:
        ValueError: if ``save_multi_pred_var`` is set without
            ``multi_pred_avg``, or if the number of predictions does not match
            ``output_num`` when averaging.

    NOTE(review): this file defines ``infer`` more than once. If both
    definitions live in the same class, only the last one is used.
    """
    testing_cfg = self.config['testing']
    device = torch.device(testing_cfg['device_name'])
    self.net.to(device)
    # Load network parameters onto the configured device.
    self.checkpoint = torch.load(testing_cfg['checkpoint_name'],
                                 map_location=device)
    self.net.load_state_dict(self.checkpoint['model_state_dict'])

    if testing_cfg['evaluation_mode'] == True:
        self.net.eval()
        if testing_cfg['test_time_dropout'] == True:
            # Keep dropout layers active while the rest of the net is in eval mode.
            def test_time_dropout(m):
                if isinstance(m, nn.Dropout):
                    print('dropout layer')
                    m.train()
            self.net.apply(test_time_dropout)

    output_dir = testing_cfg['output_dir']
    class_num = self.config['network']['class_num']
    mini_batch_size = testing_cfg['mini_batch_size']
    mini_patch_inshape = testing_cfg['mini_patch_input_shape']
    mini_patch_outshape = testing_cfg['mini_patch_output_shape']
    mini_patch_stride = testing_cfg['mini_patch_stride']
    output_num = testing_cfg.get('output_num', 1)
    multi_pred_avg = testing_cfg.get('multi_pred_avg', False)
    save_probability = testing_cfg.get('save_probability', False)
    save_var = testing_cfg.get('save_multi_pred_var', False)
    label_source = testing_cfg.get('label_source', None)
    label_target = testing_cfg.get('label_target', None)
    filename_replace_source = testing_cfg.get('filename_replace_source', None)
    filename_replace_target = testing_cfg.get('filename_replace_target', None)
    if save_var and not multi_pred_avg:
        # The variance map is only computed when predictions are averaged.
        raise ValueError(
            "save_multi_pred_var requires multi_pred_avg to be True")
    os.makedirs(output_dir, exist_ok=True)
    root_dir = self.config['dataset']['root_dir']
    do_replace = (filename_replace_source is not None and
                  filename_replace_target is not None)

    infer_time_list = []
    with torch.no_grad():
        for data in self.test_loder:
            images = self.convert_tensor_type(data['image'])
            names = data['names']
            print(names[0])
            start_time = time.time()
            data['predict'] = volume_infer(images, self.net, device, class_num,
                                           mini_batch_size, mini_patch_inshape,
                                           mini_patch_outshape,
                                           mini_patch_stride, output_num)
            # Undo invertible preprocessing transforms, last one first.
            for transform in reversed(self.transform_list):
                if transform.inverse:
                    data = transform.inverse_transform_for_prediction(data)
            predict = data['predict']
            predict_list = predict if isinstance(predict, (tuple, list)) \
                else [predict]
            infer_time_list.append(time.time() - start_time)

            prob_list = [
                scipy.special.softmax(p[0], axis=0) for p in predict_list
            ]
            var = None
            if multi_pred_avg:
                if output_num == 1:
                    raise ValueError(
                        "multiple predictions expected, but output_num was set to 1"
                    )
                if output_num != len(prob_list):
                    raise ValueError(
                        "expected output_num was set to {0:}, but {1:} outputs obtained"
                        .format(output_num, len(prob_list)))
                prob_stack = np.asarray(prob_list, np.float32)
                prob = np.mean(prob_stack, axis=0)
                var = np.var(prob_stack, axis=0)
            else:
                prob = prob_list[0]
            output = np.asarray(np.argmax(prob, axis=0), np.uint8)
            if label_source is not None and label_target is not None:
                output = convert_label(output, label_source, label_target)

            # Save the output and (optionally) the probability predictions.
            file_name = names[0].split('/')[-1]
            if do_replace:
                file_name = file_name.replace(filename_replace_source,
                                              filename_replace_target)
            save_name = "{0:}/{1:}".format(output_dir, file_name)
            reference_name = root_dir + '/' + names[0]
            save_nd_array_as_image(output, save_name, reference_name)

            # Split the extension off the file name only; '.nii.gz' is
            # treated as one extension.
            if file_name.endswith('.nii.gz'):
                stem, save_format = file_name[:-len('.nii.gz')], 'nii.gz'
            else:
                stem, _, save_format = file_name.rpartition('.')
            save_prefix = "{0:}/{1:}".format(output_dir, stem)

            if save_probability:
                for c in range(prob.shape[0]):
                    temp_prob = prob[c]
                    prob_save_name = "{0:}_prob_{1:}.{2:}".format(
                        save_prefix, c, save_format)
                    if len(temp_prob.shape) == 2:
                        # 2D maps are scaled to [0, 255] and stored as uint8.
                        temp_prob = np.asarray(temp_prob * 255, np.uint8)
                    save_nd_array_as_image(temp_prob, prob_save_name,
                                           reference_name)
            if save_var:
                # NOTE(review): only the variance of channel 1 is saved —
                # presumably the foreground class; confirm.
                var_save_name = "{0:}_var.{1:}".format(save_prefix, save_format)
                save_nd_array_as_image(var[1], var_save_name, reference_name)

    if infer_time_list:  # mean/std of an empty array would be nan
        infer_times = np.asarray(infer_time_list)
        time_avg = infer_times.mean()
        time_std = infer_times.std()
        print("testing time {0:} +/- {1:}".format(time_avg, time_std))
def infer(self):
    """Run patch-based inference over ``self.test_loder`` and save results.

    If ``mini_patch_shape`` is configured, the mini-patch output shape is
    first found with one probe forward pass. Each test case is then cast to
    float32, predicted with ``volume_infer``, and inverse-transformed. Its
    argmax label map is saved to ``config['testing']['output_dir']``, plus
    optional per-class softmax maps. The average wall time per case is
    printed at the end.
    """
    testing_cfg = self.config['testing']
    device = torch.device(testing_cfg['device_name'])
    self.net.to(device)
    # Load network parameters onto the configured device.
    self.checkpoint = torch.load(testing_cfg['checkpoint_name'],
                                 map_location=device)
    self.net.load_state_dict(self.checkpoint['model_state_dict'])

    if testing_cfg['evaluation_mode'] == True:
        self.net.eval()
        if testing_cfg['test_time_dropout'] == True:
            # Keep dropout layers active while the rest of the net is in eval mode.
            def test_time_dropout(m):
                if isinstance(m, nn.Dropout):
                    print('dropout layer')
                    m.train()
            self.net.apply(test_time_dropout)

    output_dir = testing_cfg['output_dir']
    # Optional keys fall back to defaults instead of raising KeyError.
    save_probability = testing_cfg.get('save_probability', False)
    label_source = testing_cfg.get('label_source', None)
    label_target = testing_cfg.get('label_target', None)
    class_num = self.config['network']['class_num']
    mini_batch_size = testing_cfg['mini_batch_size']
    mini_patch_inshape = testing_cfg['mini_patch_shape']
    mini_patch_stride = testing_cfg['mini_patch_stride']
    filename_replace_source = testing_cfg.get('filename_replace_source', None)
    filename_replace_target = testing_cfg.get('filename_replace_target', None)
    os.makedirs(output_dir, exist_ok=True)
    root_dir = self.config['dataset']['root_dir']
    do_replace = (filename_replace_source is not None and
                  filename_replace_target is not None)

    with torch.no_grad():
        # Automatically infer the mini-patch output shape. The random probe is
        # cast to float32 to match the `.float()` test images below.
        mini_patch_outshape = None
        if mini_patch_inshape is not None:
            patch_inshape = [1, self.config['dataset']['modal_num']
                             ] + mini_patch_inshape
            testx = torch.from_numpy(np.random.random(patch_inshape))
            testx = testx.float().to(device)
            testy = self.net(testx)
            if isinstance(testy, (tuple, list)):
                testy = testy[0]
            mini_patch_outshape = tuple(testy.shape[2:])
            print('mini patch in shape', mini_patch_inshape)
            print('mini patch out shape', mini_patch_outshape)

        start_time = time.time()
        for data in self.test_loder:
            images = data['image'].float()
            names = data['names']
            print(names[0])
            data['predict'] = volume_infer(images, self.net, device, class_num,
                                           mini_batch_size, mini_patch_inshape,
                                           mini_patch_outshape,
                                           mini_patch_stride)
            # Undo invertible preprocessing transforms, last one first.
            for transform in reversed(self.transform_list):
                if transform.inverse:
                    data = transform.inverse_transform_for_prediction(data)

            output = np.asarray(np.argmax(data['predict'][0], axis=0), np.uint8)
            if label_source is not None and label_target is not None:
                output = convert_label(output, label_source, label_target)

            # Save the output and (optionally) the probability predictions.
            file_name = names[0].split('/')[-1]
            if do_replace:
                file_name = file_name.replace(filename_replace_source,
                                              filename_replace_target)
            save_name = "{0:}/{1:}".format(output_dir, file_name)
            reference_name = root_dir + '/' + names[0]
            save_nd_array_as_image(output, save_name, reference_name)
            if save_probability:
                # Split the extension off the file name only; '.nii.gz' is
                # treated as one extension.
                if file_name.endswith('.nii.gz'):
                    stem, save_format = file_name[:-len('.nii.gz')], 'nii.gz'
                else:
                    stem, _, save_format = file_name.rpartition('.')
                save_prefix = "{0:}/{1:}".format(output_dir, stem)
                prob = scipy.special.softmax(data['predict'][0], axis=0)
                for c in range(prob.shape[0]):
                    temp_prob = prob[c]
                    prob_save_name = "{0:}_prob_{1:}.{2:}".format(
                        save_prefix, c, save_format)
                    if len(temp_prob.shape) == 2:
                        # 2D maps are scaled to [0, 255] and stored as uint8.
                        temp_prob = np.asarray(temp_prob * 255, np.uint8)
                    save_nd_array_as_image(temp_prob, prob_save_name,
                                           reference_name)
        elapsed = time.time() - start_time

    case_num = len(self.test_loder)
    if case_num > 0:  # avoid ZeroDivisionError on an empty loader
        avg_time = elapsed / case_num
        print("average testing time {0:}".format(avg_time))