def __infer(self):
    """Legacy inference loop: predict every case in ``self.test_loder`` and save it.

    Settings are read from ``self.config``:

    * ``testing``: ``device_name``, ``checkpoint_name``, ``output_dir``,
      ``save_probability``, ``label_source``, ``label_target``,
      ``mini_batch_size``, ``mini_patch_shape``
    * ``network``: ``class_num``
    * ``dataset``: ``modal_num`` (only when ``mini_patch_shape`` is set),
      ``root_dir``

    For each batch the images are cast to double and passed to
    ``volume_infer``. The invertible transforms in ``self.transform_list``
    are then undone on ``data`` in reverse order. The saved label map is the
    argmax over axis 0 of ``data['predict'][0]``, cast to uint8 and
    optionally remapped with ``convert_label``. When ``save_probability`` is
    set, a softmax map per index of axis 0 is also saved as
    ``<output_dir>/<stem>_prob_<c>.<ext>``.

    This method calls ``self.net.train()``. ``infer`` instead makes the mode
    configurable through ``testing.evaluation_mode``.

    NOTE(review): only ``names[0]`` / ``data['predict'][0]`` are saved per
    batch, so this presumably assumes a test batch size of 1 -- confirm.
    """
    testing_cfg = self.config['testing']
    device = torch.device(testing_cfg['device_name'])
    self.net.to(device)

    # Load the network parameters. map_location lets a checkpoint that was
    # saved from a GPU be restored on the configured device (e.g. a CPU).
    self.checkpoint = torch.load(testing_cfg['checkpoint_name'],
                                 map_location=device)
    self.net.load_state_dict(self.checkpoint['model_state_dict'])
    # NOTE(review): this calls train(), not eval(), so dropout stays active
    # and BatchNorm normalizes with batch statistics. Kept as-is; confirm
    # that this is intended.
    self.net.train()

    output_dir = testing_cfg['output_dir']
    save_probability = testing_cfg['save_probability']
    label_source = testing_cfg['label_source']
    label_target = testing_cfg['label_target']
    class_num = self.config['network']['class_num']
    mini_batch_size = testing_cfg['mini_batch_size']
    mini_patch_inshape = testing_cfg['mini_patch_shape']

    # Infer the spatial output shape of a mini patch with one forward pass on
    # random input. no_grad avoids building an autograd graph that is never used.
    mini_patch_outshape = None
    if mini_patch_inshape is not None:
        patch_inshape = ([1, self.config['dataset']['modal_num']]
                         + list(mini_patch_inshape))
        # np.random.random yields float64, matching the .double() images below.
        testx = torch.from_numpy(np.random.random(patch_inshape)).to(device)
        with torch.no_grad():
            testy = self.net(testx)
        if isinstance(testy, (tuple, list)):  # multi-output nets: use the first
            testy = testy[0]
        mini_patch_outshape = tuple(testy.shape[2:])
        print('mini patch in shape', mini_patch_inshape)
        print('mini patch out shape', mini_patch_outshape)

    root_dir = self.config['dataset']['root_dir']
    start_time = time.time()
    with torch.no_grad():
        for data in self.test_loder:
            images = data['image'].double()
            names = data['names']
            print(names[0])
            data['predict'] = volume_infer(images, self.net, device, class_num,
                                           mini_batch_size, mini_patch_inshape,
                                           mini_patch_outshape)
            # Undo the invertible transforms, last-applied first.
            for transform in reversed(self.transform_list):
                if transform.inverse:
                    data = transform.inverse_transform_for_prediction(data)

            output = np.asarray(np.argmax(data['predict'][0], axis=0), np.uint8)
            if label_source is not None and label_target is not None:
                output = convert_label(output, label_source, label_target)

            # Save the label map; the input image path is passed as the
            # reference for save_nd_array_as_image.
            reference_name = root_dir + '/' + names[0]
            file_name = names[0].split('/')[-1]
            save_name = "{0:}/{1:}".format(output_dir, file_name)
            save_nd_array_as_image(output, save_name, reference_name)

            if save_probability:
                # Split off the extension of the file name only (not the
                # directory), treating '.nii.gz' as a single extension.
                if file_name.endswith('.nii.gz'):
                    stem, ext = file_name[:-len('.nii.gz')], 'nii.gz'
                else:
                    stem, dot, ext = file_name.rpartition('.')
                    if not dot:  # no extension at all
                        stem, ext = file_name, ''
                suffix = '.' + ext if ext else ''
                prob = scipy.special.softmax(data['predict'][0], axis=0)
                # enumerate() instead of reassigning class_num, which is
                # still needed by volume_infer for the next batch.
                for c, class_prob in enumerate(prob):
                    prob_save_name = "{0:}/{1:}_prob_{2:}{3:}".format(
                        output_dir, stem, c, suffix)
                    save_nd_array_as_image(class_prob, prob_save_name,
                                           reference_name)

    # Average wall-clock time per batch of the test loader.
    num_batches = len(self.test_loder)
    if num_batches > 0:  # an empty test set would otherwise divide by zero
        avg_time = (time.time() - start_time) / num_batches
        print("average testing time {0:}".format(avg_time))
def infer(self):
    """Run inference over ``self.test_loder`` and save the predictions.

    Settings are read from ``self.config``:

    * ``testing``: ``device_name``, ``checkpoint_name``, ``evaluation_mode``,
      ``test_time_dropout`` (read only when ``evaluation_mode`` is truthy),
      ``output_dir``, ``save_probability``, ``label_source``,
      ``label_target``, ``mini_batch_size``, ``mini_patch_shape``,
      ``mini_patch_stride``, ``filename_replace_source``,
      ``filename_replace_target``
    * ``network``: ``class_num``
    * ``dataset``: ``modal_num`` (only when ``mini_patch_shape`` is set),
      ``root_dir``

    For each batch, ``volume_infer`` produces ``data['predict']``. The
    invertible transforms in ``self.transform_list`` are then undone in
    reverse order. The saved label map is the argmax over axis 0 of
    ``data['predict'][0]``, cast to uint8 and optionally remapped with
    ``convert_label``.

    The label map is written to ``output_dir`` under the input's base file
    name, after the optional ``filename_replace_*`` substitution. When
    ``save_probability`` is set, a softmax map per index of axis 0 is saved
    as ``<stem>_prob_<c>.<ext>``. 2D maps are scaled by 255 and stored as
    uint8.

    NOTE(review): only ``names[0]`` / ``data['predict'][0]`` are saved per
    batch, so this presumably assumes a test batch size of 1 -- confirm.
    """
    testing_cfg = self.config['testing']
    device = torch.device(testing_cfg['device_name'])
    self.net.to(device)

    # Load the network parameters onto the configured device.
    self.checkpoint = torch.load(testing_cfg['checkpoint_name'],
                                 map_location=device)
    self.net.load_state_dict(self.checkpoint['model_state_dict'])

    if testing_cfg['evaluation_mode']:
        self.net.eval()
        if testing_cfg['test_time_dropout']:
            # Re-enable dropout while the rest of the network stays in eval
            # mode. All standard dropout variants are matched; an exact
            # type() == nn.Dropout test would miss Dropout2d/Dropout3d.
            dropout_types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)

            def test_time_dropout(m):
                if isinstance(m, dropout_types):
                    print('dropout layer')
                    m.train()
            self.net.apply(test_time_dropout)

    output_dir = testing_cfg['output_dir']
    save_probability = testing_cfg['save_probability']
    label_source = testing_cfg['label_source']
    label_target = testing_cfg['label_target']
    class_num = self.config['network']['class_num']
    mini_batch_size = testing_cfg['mini_batch_size']
    mini_patch_inshape = testing_cfg['mini_patch_shape']
    mini_patch_stride = testing_cfg['mini_patch_stride']
    filename_replace_source = testing_cfg['filename_replace_source']
    filename_replace_target = testing_cfg['filename_replace_target']

    # Infer the spatial output shape of a mini patch with one forward pass on
    # random input. The probe is cast with .float() so it matches the images
    # fed to the network below; the previous float64 probe could not run on
    # the same network as float32 images. no_grad avoids an unused graph.
    mini_patch_outshape = None
    if mini_patch_inshape is not None:
        patch_inshape = ([1, self.config['dataset']['modal_num']]
                         + list(mini_patch_inshape))
        testx = torch.from_numpy(np.random.random(patch_inshape)).float()
        testx = testx.to(device)
        with torch.no_grad():
            testy = self.net(testx)
        if isinstance(testy, (tuple, list)):  # multi-output nets: use the first
            testy = testy[0]
        mini_patch_outshape = tuple(testy.shape[2:])
        print('mini patch in shape', mini_patch_inshape)
        print('mini patch out shape', mini_patch_outshape)

    root_dir = self.config['dataset']['root_dir']
    start_time = time.time()
    with torch.no_grad():
        for data in self.test_loder:
            # NOTE(review): the result of self.convert_tensor_type(...) used
            # to be computed here and then overwritten by .float(). If the
            # configured tensor type should govern inference, switch both
            # this line and the probe input above to it.
            images = data['image'].float()
            names = data['names']
            print(names[0])
            data['predict'] = volume_infer(images, self.net, device, class_num,
                                           mini_batch_size, mini_patch_inshape,
                                           mini_patch_outshape, mini_patch_stride)
            # Undo the invertible transforms, last-applied first.
            for transform in reversed(self.transform_list):
                if transform.inverse:
                    data = transform.inverse_transform_for_prediction(data)

            output = np.asarray(np.argmax(data['predict'][0], axis=0), np.uint8)
            if label_source is not None and label_target is not None:
                output = convert_label(output, label_source, label_target)

            # Save the label map; the input image path is passed as the
            # reference for save_nd_array_as_image.
            reference_name = root_dir + '/' + names[0]
            file_name = names[0].split('/')[-1]
            if (filename_replace_source is not None
                    and filename_replace_target is not None):
                file_name = file_name.replace(filename_replace_source,
                                              filename_replace_target)
            save_name = "{0:}/{1:}".format(output_dir, file_name)
            save_nd_array_as_image(output, save_name, reference_name)

            if save_probability:
                # Split off the extension of the file name only (not the
                # directory), treating '.nii.gz' as a single extension.
                if file_name.endswith('.nii.gz'):
                    stem, ext = file_name[:-len('.nii.gz')], 'nii.gz'
                else:
                    stem, dot, ext = file_name.rpartition('.')
                    if not dot:  # no extension at all
                        stem, ext = file_name, ''
                suffix = '.' + ext if ext else ''
                prob = scipy.special.softmax(data['predict'][0], axis=0)
                # enumerate() instead of reassigning class_num, which is
                # still needed by volume_infer for the next batch.
                for c, class_prob in enumerate(prob):
                    prob_save_name = "{0:}/{1:}_prob_{2:}{3:}".format(
                        output_dir, stem, c, suffix)
                    if class_prob.ndim == 2:
                        # 2D maps are stored as 8-bit images in [0, 255].
                        class_prob = np.asarray(class_prob * 255, np.uint8)
                    save_nd_array_as_image(class_prob, prob_save_name,
                                           reference_name)

    # Average wall-clock time per batch of the test loader.
    num_batches = len(self.test_loder)
    if num_batches > 0:  # an empty test set would otherwise divide by zero
        avg_time = (time.time() - start_time) / num_batches
        print("average testing time {0:}".format(avg_time))