Ejemplo n.º 1
0
    def __infer(self):
        """Run inference over ``self.test_loder`` and write predictions to disk.

        Steps:
        1. Load the checkpoint named by ``config['testing']['checkpoint_name']``
           and put the network in evaluation mode.
        2. If ``mini_patch_shape`` is configured, push one random patch through
           the network to discover the matching output patch shape.
        3. For every batch:
           - run ``volume_infer``;
           - apply ``inverse_transform_for_prediction`` of every invertible
             transform in ``self.transform_list``, last to first;
           - take the arg-max label map of the first case in the batch and
             optionally remap labels;
           - save it under ``config['testing']['output_dir']``.
        4. If ``save_probability`` is set, also save one softmax probability map
           per class.

        Images (and the shape probe) are fed to the network as float64
        (``double``) tensors.
        """
        device = torch.device(self.config['testing']['device_name'])
        self.net.to(device)
        # Load network parameters and set the network to evaluation mode.
        # map_location allows a checkpoint saved on another device (e.g. GPU)
        # to be loaded on the configured device.
        self.checkpoint = torch.load(self.config['testing']['checkpoint_name'],
                                     map_location=device)
        self.net.load_state_dict(self.checkpoint['model_state_dict'])
        # eval(), not train(): in train mode dropout stays active and batch
        # normalisation uses per-batch statistics, corrupting predictions.
        self.net.eval()

        output_dir = self.config['testing']['output_dir']
        save_probability = self.config['testing']['save_probability']
        label_source = self.config['testing']['label_source']
        label_target = self.config['testing']['label_target']
        class_num = self.config['network']['class_num']
        mini_batch_size = self.config['testing']['mini_batch_size']
        mini_patch_inshape = self.config['testing']['mini_patch_shape']
        mini_patch_outshape = None
        # Automatically infer the output shape by running one random patch
        # of shape [1, modal_num, *mini_patch_inshape] through the network.
        if mini_patch_inshape is not None:
            patch_inshape = [1, self.config['dataset']['modal_num']
                             ] + list(mini_patch_inshape)
            # np.random.random yields float64, matching the .double() images
            # used in the loop below.
            testx = torch.from_numpy(np.random.random(patch_inshape))
            testx = testx.to(device)
            with torch.no_grad():  # shape probe only; no autograd graph needed
                testy = self.net(testx)
            if isinstance(testy, (tuple, list)):
                testy = testy[0]
            mini_patch_outshape = testy.detach().cpu().numpy().shape[2:]
            print('mini patch in shape', mini_patch_inshape)
            print('mini patch out shape', mini_patch_outshape)

        root_dir = self.config['dataset']['root_dir']
        start_time = time.time()
        with torch.no_grad():
            for data in self.test_loder:
                images = data['image'].double()
                names = data['names']
                print(names[0])
                data['predict'] = volume_infer(images, self.net, device,
                                               class_num, mini_batch_size,
                                               mini_patch_inshape,
                                               mini_patch_outshape)

                # Undo invertible transforms, last-applied first.
                for transform in reversed(self.transform_list):
                    if transform.inverse:
                        data = transform.inverse_transform_for_prediction(data)
                output = np.argmax(data['predict'][0], axis=0)
                output = np.asarray(output, np.uint8)

                if (label_source is not None) and (label_target is not None):
                    output = convert_label(output, label_source, label_target)

                # Save the output and (optionally) probability predictions,
                # using the input image as the reference for geometry.
                reference_name = root_dir + '/' + names[0]
                save_name = "{0:}/{1:}".format(output_dir,
                                               names[0].split('/')[-1])
                save_nd_array_as_image(output, save_name, reference_name)

                if save_probability:
                    # Split off the (possibly double) extension so the class
                    # index can be inserted before it.
                    if save_name.endswith('.nii.gz'):
                        save_prefix = save_name[:-len('.nii.gz')]
                        save_format = 'nii.gz'
                    else:
                        save_prefix, _, save_format = save_name.rpartition('.')
                    prob = scipy.special.softmax(data['predict'][0], axis=0)
                    # Iterate channels directly instead of reassigning
                    # `class_num`, which volume_infer uses on the next case.
                    for c, temp_prob in enumerate(prob):
                        prob_save_name = "{0:}_prob_{1:}.{2:}".format(
                            save_prefix, c, save_format)
                        save_nd_array_as_image(temp_prob, prob_save_name,
                                               reference_name)

        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        avg_time = (time.time() - start_time) / max(len(self.test_loder), 1)
        print("average testing time {0:}".format(avg_time))
Ejemplo n.º 2
0
    def infer(self):
        """Run inference over ``self.test_loder`` and write predictions to disk.

        Steps:
        1. Load the checkpoint named by ``config['testing']['checkpoint_name']``.
        2. If ``evaluation_mode`` is True, put the network in eval mode.
           Additionally, if ``test_time_dropout`` is True, re-enable dropout
           layers (``nn.Dropout`` instances).
        3. If ``mini_patch_shape`` is configured, push one random patch through
           the network to discover the matching output patch shape.
        4. For every batch:
           - run ``volume_infer``;
           - apply ``inverse_transform_for_prediction`` of every invertible
             transform in ``self.transform_list``, last to first;
           - take the arg-max label map of the first case in the batch and
             optionally remap labels;
           - save it under ``config['testing']['output_dir']``.
        5. If ``save_probability`` is set, also save one softmax probability map
           per class. 2-D maps are scaled to [0, 255] uint8.

        Optional ``filename_replace_source`` / ``filename_replace_target``
        rewrite part of the output file name.
        """
        device = torch.device(self.config['testing']['device_name'])
        self.net.to(device)
        # Load network parameters; map_location lets a checkpoint saved on
        # another device be loaded on the configured one.
        self.checkpoint = torch.load(self.config['testing']['checkpoint_name'],
                                     map_location=device)
        self.net.load_state_dict(self.checkpoint['model_state_dict'])

        if self.config['testing']['evaluation_mode'] == True:
            self.net.eval()
            if self.config['testing']['test_time_dropout'] == True:

                def test_time_dropout(m):
                    # Keep dropout stochastic at test time (MC dropout)
                    # while the rest of the network stays in eval mode.
                    if isinstance(m, nn.Dropout):
                        print('dropout layer')
                        m.train()

                self.net.apply(test_time_dropout)
        output_dir = self.config['testing']['output_dir']
        save_probability = self.config['testing']['save_probability']
        label_source = self.config['testing']['label_source']
        label_target = self.config['testing']['label_target']
        class_num = self.config['network']['class_num']
        mini_batch_size = self.config['testing']['mini_batch_size']
        mini_patch_inshape = self.config['testing']['mini_patch_shape']
        mini_patch_stride = self.config['testing']['mini_patch_stride']
        filename_replace_source = self.config['testing'][
            'filename_replace_source']
        filename_replace_target = self.config['testing'][
            'filename_replace_target']
        mini_patch_outshape = None
        # Automatically infer the output shape by running one random patch
        # of shape [1, modal_num, *mini_patch_inshape] through the network.
        if mini_patch_inshape is not None:
            patch_inshape = [1, self.config['dataset']['modal_num']
                             ] + list(mini_patch_inshape)
            testx = torch.from_numpy(np.random.random(patch_inshape))
            # Use the same conversion as the images below so the probe is fed
            # the dtype the network actually receives (presumably the
            # configured tensor type — TODO confirm convert_tensor_type).
            testx = self.convert_tensor_type(testx).to(device)
            with torch.no_grad():  # shape probe only; no autograd graph needed
                testy = self.net(testx)
            if isinstance(testy, (tuple, list)):
                testy = testy[0]
            mini_patch_outshape = testy.detach().cpu().numpy().shape[2:]
            print('mini patch in shape', mini_patch_inshape)
            print('mini patch out shape', mini_patch_outshape)

        root_dir = self.config['dataset']['root_dir']
        start_time = time.time()
        with torch.no_grad():
            for data in self.test_loder:
                # Previously this result was overwritten by a hard-coded
                # .float(), ignoring the configured tensor type.
                images = self.convert_tensor_type(data['image'])
                names = data['names']
                print(names[0])
                data['predict'] = volume_infer(images, self.net, device,
                                               class_num, mini_batch_size,
                                               mini_patch_inshape,
                                               mini_patch_outshape,
                                               mini_patch_stride)

                # Undo invertible transforms, last-applied first.
                for transform in reversed(self.transform_list):
                    if transform.inverse:
                        data = transform.inverse_transform_for_prediction(data)
                output = np.argmax(data['predict'][0], axis=0)
                output = np.asarray(output, np.uint8)

                if (label_source is not None) and (label_target is not None):
                    output = convert_label(output, label_source, label_target)

                # Save the output and (optionally) probability predictions,
                # using the input image as the reference for geometry.
                reference_name = root_dir + '/' + names[0]
                save_name = names[0].split('/')[-1]
                if ((filename_replace_source is not None)
                        and (filename_replace_target is not None)):
                    save_name = save_name.replace(filename_replace_source,
                                                  filename_replace_target)
                save_name = "{0:}/{1:}".format(output_dir, save_name)
                save_nd_array_as_image(output, save_name, reference_name)

                if save_probability:
                    # Split off the (possibly double) extension so the class
                    # index can be inserted before it.
                    if save_name.endswith('.nii.gz'):
                        save_prefix = save_name[:-len('.nii.gz')]
                        save_format = 'nii.gz'
                    else:
                        save_prefix, _, save_format = save_name.rpartition('.')
                    prob = scipy.special.softmax(data['predict'][0], axis=0)
                    # Iterate channels directly instead of reassigning
                    # `class_num`, which volume_infer uses on the next case.
                    for c, temp_prob in enumerate(prob):
                        prob_save_name = "{0:}_prob_{1:}.{2:}".format(
                            save_prefix, c, save_format)
                        if len(temp_prob.shape) == 2:
                            # 2-D maps are stored as 8-bit images.
                            temp_prob = np.asarray(temp_prob * 255, np.uint8)
                        save_nd_array_as_image(temp_prob, prob_save_name,
                                               reference_name)

        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        avg_time = (time.time() - start_time) / max(len(self.test_loder), 1)
        print("average testing time {0:}".format(avg_time))