def sample_to_sample_plus(self, samples, cfg, datamode):

        new_samples = []
        # surface_point_count = 100
        for sample in samples:

            x = sample.x
            y = sample.y

            y = (y > 0).long()

            center = tuple([d // 2 for d in x.shape])
            x = crop(x, cfg.patch_shape, center)
            y = crop(y, cfg.patch_shape, center)

            # shape = torch.tensor(y.shape)[None].float()
            # y_outer = sample_outer_surface_in_voxel(y)

            # x_super_res = torch.tensor(1)
            # y_super_res = torch.tensor(1)
            new_samples += [
                SamplePlus(x.cpu(), y.cpu(), sample.name, sample.spharm_coeffs)
            ]

        return new_samples
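
Note: every example in this listing relies on a crop(volume, target_shape, center) helper that is not shown. The sketch below illustrates the assumed semantics (a centre crop that zero-pads wherever the window extends past the volume); it is an assumption for illustration, not the repository's implementation.

import numpy as np

def crop_sketch(vol, target_shape, center):
    # Hypothetical stand-in for the undefined crop() helper: return a
    # zero-padded window of size target_shape centred on 'center'.
    out = np.zeros(target_shape, dtype=vol.dtype)
    src, dst = [], []
    for size, t, c in zip(vol.shape, target_shape, center):
        lo = c - t // 2                        # window start, may be negative
        s0, s1 = max(lo, 0), min(lo + t, size)
        d0 = max(-lo, 0)
        src.append(slice(s0, s1))
        dst.append(slice(d0, d0 + (s1 - s0)))
    out[tuple(dst)] = vol[tuple(src)]
    return out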
Example #2
    def register_atlas(self, atlas, pose, size_out):

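        # pose[:, :3] perturbs a canonical direction that, once normalized,
        # parameterizes the rotation; pose[:, 3:6] is the translation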
        direction = torch.Tensor([0, -1, 0]).cuda()[None]
        q = direction + pose[:, :3]
        q = F.normalize(q, dim=1)
        theta_q = transforms.rot_matrix_from_quaternion(q)

        theta_shift = torch.eye(4, device=pose.device)[None].repeat(
            pose.shape[0], 1, 1).cuda().float()
        theta_shift[:, :3, 3] = pose[:, 3:6]

        # theta_scale = torch.eye(4, device=pose.device)[None].repeat(pose.shape[0], 1, 1).float() * axis[:, 6]
        # theta_scale[:, 3, 3] = 1
        # theta = theta_scale @ theta_q @ theta_shift

        theta = theta_q @ theta_shift

        theta = theta[:, :3]
        grid = affine_3d_grid_generator.affine_grid(theta,
                                                    atlas.size()).double()
        rotated = F.grid_sample(atlas.double(),
                                grid,
                                mode='bilinear',
                                padding_mode='zeros').float().detach()

        _, _, D_out, H_out, W_out = size_out
        N_in, C_in, D_in, H_in, W_in = rotated.shape
        center = (D_in // 2, H_in // 2, W_in // 2)
        rotated = crop(rotated, (N_in, C_in, D_out, H_out, W_out),
                       (0, self.config.config.prior_channel_count // 2) + center)

        return rotated
Example #3
    def read_sample(self, data_root, sample, out_shape, pad_shape):
        mri_image = nib.load('{}/{}'.format(data_root, sample))
        mri_image_data = mri_image.get_fdata()

        D, H, W = mri_image_data.shape
        center_z, center_y, center_x = D // 2, H // 2, W // 2
        D, H, W = pad_shape
        x = crop(mri_image_data, (D, H, W), (center_z, center_y, center_x))
        x = torch.from_numpy(x).float()

        x = F.interpolate(x[None, None], out_shape, mode='trilinear',
                          align_corners=False)[0, 0]

        return x
Example #4
def sample_to_sample_plus(samples, cfg, datamode):

    new_samples = []
    # surface_point_count = 100
    for sample in samples: 
         
        x = sample.x
        y = sample.y 

        y = (y > 0).long()

        center = tuple([d // 2 for d in x.shape]) 
        x = crop(x, cfg.patch_shape, center) 
        y = crop(y, cfg.patch_shape, center)   

        shape = torch.tensor(y.shape)[None].float()
        y_outer = sample_outer_surface_in_voxel(y) 

        new_samples += [SamplePlus(x.cpu(), y.cpu(), y_outer.cpu(), shape=shape)]

    return new_samples
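
The sample_outer_surface_in_voxel helper used above is likewise not shown. A plausible sketch of its behaviour (the shell of voxels bordering the foreground mask, obtained via a 3x3x3 max-pool dilation; an assumption, not the repository's code):

import torch
import torch.nn.functional as F

def sample_outer_surface_in_voxel_sketch(y):
    # Hypothetical: mark the voxels that touch the foreground but lie
    # just outside it (dilation minus mask).
    y = y.float()[None, None]                 # add batch and channel dims
    dilated = F.max_pool3d(y, kernel_size=3, stride=1, padding=1)
    return ((dilated - y) > 0).long()[0, 0]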
Example #5
    def read_sample(self, data_root, sample, out_shape, pad_shape):
        x = np.load('{}/imagesTr/{}'.format(data_root, sample))
        y = np.load('{}/labelsTr/{}'.format(data_root, sample))

        D, H, W = x.shape
        center_z, center_y, center_x = D // 2, H // 2, W // 2
        D, H, W = pad_shape
        x = crop(x, (D, H, W), (center_z, center_y, center_x))
        y = crop(y, (D, H, W), (center_z, center_y, center_x))

        x = torch.from_numpy(x).float()
        y = torch.from_numpy(y).float()

        x = F.interpolate(x[None, None],
                          out_shape,
                          mode='trilinear',
                          align_corners=False)[0, 0]
        y = F.interpolate(y[None, None].float(), out_shape,
                          mode='nearest')[0, 0].long()

        return x, y
Example #6
    def __getitem__(self, idx):
        item = self.data[idx]
        x = torch.from_numpy(item.x).cuda()[None]
        y = torch.from_numpy(item.y).cuda().long()

        orientation = torch.from_numpy(item.orientation).float()

        if self.mode == DataModes.TRAINING:  # if training do augmentation
            new_orientation = (torch.rand(3) -
                               0.5) * 2 * self.cfg.augmentation_shift_range
            new_orientation = F.normalize(new_orientation, dim=0)
            q = orientation + new_orientation
            q = F.normalize(q, dim=0)
            theta_rotate = transforms.rot_matrix_from_quaternion(q[None])

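            # divide the voxel-space shift by half the volume size to express
            # it in the normalized [-1, 1] coordinates of the affine grid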
            shift = torch.tensor([
                d / (D // 2) for d, D in zip(
                    2 * (torch.rand(3) - 0.5) *
                    self.cfg.augmentation_shift_range, y.shape)
            ])
            theta_shift = transforms.shift(shift)
            theta = theta_rotate @ theta_shift

            x, y = transforms.transform(theta, x, y)
            orientation = new_orientation

            pose = torch.cat((orientation, shift)).cuda()
        else:
            pose = torch.zeros(6).cuda()

        C, D, H, W = x.shape
        center = (D // 2, H // 2, W // 2)

        x = crop(x, (C, ) + self.cfg.patch_shape, (0, ) + center)
        y = crop(y, self.cfg.patch_shape, center)

        if self.cfg.model_name == 'panet':
            y = [y, pose]

        return x, y
Example #7
    def pre_process_dataset(self, cfg, trial_id):
        '''
        Resample the raw CHAOS CT volumes to 1 mm voxels, normalize, crop,
        and pickle the pre-processed training/testing splits to disk.
        '''
 
        data_root = cfg.dataset_path
        samples = [dir for dir in os.listdir(data_root)]
 
        pad_shape = (384, 384, 384)
        inputs = []
        labels = []

        print('Data pre-processing')
        for sample in samples:
            if 'pickle' not in sample:
                print('.', end='', flush=True)
                x = [] 
                images_path = [dir for dir in os.listdir('{}/{}/DICOM_anon'.format(data_root, sample)) if 'dcm' in dir]
                for image_path in images_path:
                    file = pydicom.dcmread('{}/{}/DICOM_anon/{}'.format(data_root, sample, image_path))
                    x += [file.pixel_array] 

                d_resolution = file.SliceThickness
                h_resolution, w_resolution = file.PixelSpacing 
                x = np.float32(np.array(x))

 
                D, H, W = x.shape
                D = int(D * d_resolution)
                H = int(H * h_resolution)
                W = int(W * w_resolution)
                # we resample so that one voxel spans 1 mm along x, y and z
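                # base_grid holds normalized coordinates in [-1, 1], the
                # convention F.grid_sample expects; each per-axis linspace
                # below broadcasts across the remaining two spatial axes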
                base_grid = torch.zeros((1, D, H, W, 3))
                w_points = (torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1]))
                h_points = (torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1])).unsqueeze(-1)
                d_points = (torch.linspace(-1, 1, D) if D > 1 else torch.Tensor([-1])).unsqueeze(-1).unsqueeze(-1)
                base_grid[:, :, :, :, 0] = w_points
                base_grid[:, :, :, :, 1] = h_points
                base_grid[:, :, :, :, 2] = d_points
                
                grid = base_grid.cuda()
                 
                
                x = torch.from_numpy(x).cuda()
                x = F.grid_sample(x[None, None], grid, mode='bilinear', padding_mode='border')[0, 0]
                x = x.data.cpu().numpy() 
                #----
                 
                x = np.float32(x) 
                mean_x = np.mean(x)
                std_x = np.std(x)

                D, H, W = x.shape
                center_z, center_y, center_x = D // 2, H // 2, W // 2
                D, H, W = pad_shape
                x = crop(x, (D, H, W), (center_z, center_y, center_x))  
  
                # normalize x
                x = (x - mean_x)/std_x
                x = torch.from_numpy(x)
                inputs += [x]
                 
                #----
 
                y = [] 
                images_path = [dir for dir in os.listdir('{}/{}/Ground'.format(data_root, sample)) if 'png' in dir]
                for image_path in images_path:
                    file = io.imread('{}/{}/Ground/{}'.format(data_root, sample, image_path))
                    y += [file]  
                 
                y = np.array(y) 
                y = np.int64(y) 

                y = torch.from_numpy(y).cuda()
                y = F.grid_sample(y[None, None].float(), grid, mode='nearest', padding_mode='border')[0, 0]
                y = y.data.cpu().numpy()

                 
               
                y = np.int64(y)
                y = crop(y, (D, H, W), (center_z, center_y, center_x))  
                  
                 
                y = torch.from_numpy(y/255) 
                  
                labels += [y]

        print('\nSaving pre-processed data to disk')
        np.random.seed(0)
        perm = np.random.permutation(len(inputs))
        tr_length = cfg.training_set_size
        counts = [perm[:tr_length], perm[len(inputs)//2:]]
        # counts = [perm[:tr_length], perm[16:]]


        data = {}
        down_sample_shape = cfg.patch_shape

        input_shape = x.shape
        scale_factor = (np.max(down_sample_shape)/np.max(input_shape))
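        # isotropic scale: largest patch dimension over largest input
        # dimension (note: input_shape comes from the last volume processed
        # in the loop above)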

        for i, datamode in enumerate([DataModes.TRAINING, DataModes.TESTING]):

            samples = []
 

            for j in counts[i]: 
                print('.',end='', flush=True)
                x = inputs[j]
                y = labels[j]

                x = F.interpolate(x[None, None], scale_factor=scale_factor, mode='trilinear')[0, 0]
                y = F.interpolate(y[None, None].float(), scale_factor=scale_factor, mode='nearest')[0, 0].long()

                samples.append(Sample(x, y)) 

            with open(data_root + '/pre_loaded_data_{}_{}_v3.pickle'.format(datamode, "_".join(map(str, down_sample_shape))), 'wb') as handle:
                pickle.dump(samples, handle, protocol=pickle.HIGHEST_PROTOCOL)

            data[datamode] = ChaosDataset(samples, cfg, datamode)
        
        print('\n***************************************\n\
            Pre-processing complete. Now comment out function load_data in main.py and uncomment function quick_load_data.\n\
            ***************************************')
        sys.exit()
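        # unreachable by design: sys.exit() above stops the run once
        # pre-processing has finished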
        return data
Example #8
    def load_data(self, cfg):
        '''
        # Change this to load your training data.

        # pre-synaptic neuron   :   1
        # synapse               :   2
        # post-synaptic neuron  :   3
        # background            :   0
        '''

        data_root = '/cvlabdata1/cvlab/datasets_udaranga/datasets/3d/CortexEPFL/'
        num_classes = 4
        path_images = data_root + 'imagestack.tif'
        path_synapse = data_root + 'labels/labels_synapses.tif'
        path_pre_post = data_root + 'labels/labels_pre_post.tif'
        ''' Label information '''
        path_idx = data_root + 'labels/info.npy'
        idx = np.load(path_idx)
        ''' Load data '''
        x = io.imread(path_images)[:200]
        y_synapses = io.imread(path_synapse)
        y_pre_post = io.imread(path_pre_post)

        x = np.float32(x) / 255
        y_synapses = np.int64(y_synapses)

        # Syn at bottom
        temp = np.int64(y_pre_post)
        y_pre_post = np.copy(y_synapses)
        y_pre_post[temp > 0] = temp[temp > 0]

        # method 1: split neurons
        counts = [[0, 12], [12, 24], [24, 36]]

        data = {}

        patch_shape_extended = tuple([
            int(np.sqrt(2) * i) + 2 * cfg.augmentation_shift_range
            for i in cfg.patch_shape
        ])  # to allow augmentation and crop

        for i, data_mode in enumerate(
            [DataModes.TRAINING, DataModes.VALIDATION, DataModes.TESTING]):

            samples = []
            for j in range(counts[i][0], counts[i][1]):

                points = np.where(y_synapses == idx[j][1])
                centre = tuple(np.mean(points, axis=1, dtype=np.int64))

                # extract the object of interest
                y = np.zeros_like(y_pre_post)
                for k, label_id in enumerate(idx[j][:3]):
                    y[y_pre_post == label_id] = k + 1

                patch_y = crop(y, patch_shape_extended, centre)
                patch_x = crop(x, patch_shape_extended, centre)

                # Compute orientation

                # First, find the synapse axis via PCA
                syn = patch_y == 2
                coords = np.array(np.where(syn)).transpose()
                syn_center = np.flip(np.mean(coords, axis=0))
                pca = PCA(n_components=3)
                pca.fit(coords)
                u = -np.flip(pca.components_)[0]

                # Now make sure the axis points towards the pre-synaptic region
                pre = patch_y == 1
                coords = np.array(np.where(pre)).transpose()
                pre_center = np.flip(np.mean(coords, axis=0))

                w = pre_center - syn_center
                angle = np.arccos(
                    np.dot(u, w) / norm(u) / norm(w)) * 180 / np.pi
                if angle > 90:
                    u = -u

                orientation = u
                scale = 1

                samples.append(
                    Sample(patch_x, patch_y, orientation, centre, scale))

            data[data_mode] = CortexVoxelDataset(samples, cfg, data_mode)

        with open(data_root + 'labels/pre_computed.pickle', 'wb') as handle:
            pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

        return data
Example #9
    def load_data(self, cfg, trial_id):
        '''
        Load the CHAOS CT training set, resample to 1 mm voxels, and build
        the training/validation/testing datasets.
        '''

        data_root = '/cvlabsrc1' + volume_suffix + '/cvlab/datasets_udaranga/datasets/3d/chaos/Train_Sets/CT'
        samples = [dir for dir in os.listdir(data_root)]

        pad_shape = (384, 384, 384)
        inputs = []
        labels = []

        for sample in samples:
            if 'pickle' not in sample:
                print(sample)
                x = []
                images_path = [
                    dir for dir in os.listdir('{}/{}/DICOM_anon'.format(
                        data_root, sample)) if 'dcm' in dir
                ]
                for image_path in images_path:
                    file = pydicom.dcmread('{}/{}/DICOM_anon/{}'.format(
                        data_root, sample, image_path))
                    x += [file.pixel_array]

                d_resolution = file.SliceThickness
                h_resolution, w_resolution = file.PixelSpacing
                x = np.float32(np.array(x))

                # clip: x
                # CHAOS CHALLENGE: MedianCHAOS
                # Vladimir Groza from Median Technologies: CHAOS 1st place solution overview.
                # embed()
                # x[x<(1000-160)] = 1000-160
                # x[x>(1000+240)] = 1000+240
                # x = (x - x.min())/(x.max()-x.min())

                # io.imsave('/cvlabdata2/cvlab/datasets_udaranga/check1006.tif', np.uint8(x * 255))
                # x = io.imread('{}/{}/DICOM_anon/volume.tif'.format(data_root, sample))
                # x = np.float32(x)/2500
                # x[x>1] = 1
                #
                D, H, W = x.shape
                D = int(D * d_resolution)
                H = int(H * h_resolution)
                W = int(W * w_resolution)
                # we resample so that one voxel spans 1 mm along x, y and z
                base_grid = torch.zeros((1, D, H, W, 3))
                w_points = (torch.linspace(-1, 1, W)
                            if W > 1 else torch.Tensor([-1]))
                h_points = (torch.linspace(-1, 1, H)
                            if H > 1 else torch.Tensor([-1])).unsqueeze(-1)
                d_points = (torch.linspace(-1, 1, D) if D > 1 else
                            torch.Tensor([-1])).unsqueeze(-1).unsqueeze(-1)
                base_grid[:, :, :, :, 0] = w_points
                base_grid[:, :, :, :, 1] = h_points
                base_grid[:, :, :, :, 2] = d_points

                grid = base_grid.cuda()

                x = torch.from_numpy(x).cuda()
                x = F.grid_sample(x[None, None],
                                  grid,
                                  mode='bilinear',
                                  padding_mode='border')[0, 0]
                x = x.data.cpu().numpy()
                #----

                x = np.float32(x)
                mean_x = np.mean(x)
                std_x = np.std(x)

                D, H, W = x.shape
                center_z, center_y, center_x = D // 2, H // 2, W // 2
                D, H, W = pad_shape
                x = crop(x, (D, H, W), (center_z, center_y, center_x))

                # io.imsave('{}/{}/DICOM_anon/volume_resampled_2.tif'.format(data_root, sample), np.uint16(x))

                x = (x - mean_x) / std_x
                x = torch.from_numpy(x)
                inputs += [x]

                #----

                y = []
                images_path = [
                    dir for dir in os.listdir('{}/{}/Ground'.format(
                        data_root, sample)) if 'png' in dir
                ]
                for image_path in images_path:
                    file = io.imread('{}/{}/Ground/{}'.format(
                        data_root, sample, image_path))
                    y += [file]

                y = np.array(y)
                y = np.int64(y)

                y = torch.from_numpy(y).cuda()
                y = F.grid_sample(y[None, None].float(),
                                  grid,
                                  mode='nearest',
                                  padding_mode='border')[0, 0]
                y = y.data.cpu().numpy()

                y = np.int64(y)
                y = crop(y, (D, H, W), (center_z, center_y, center_x))

                # io.imsave('{}/{}/Ground/labels_resampled_2.tif'.format(data_root, sample), np.uint8(y))

                y = torch.from_numpy(y / 255)

                # y = np.uint8(y.data.cpu().numpy())
                # y = np.sum(y, axis=1)
                # y = np.sum(y, axis=1)
                # se = np.where(y>0)
                # embed()
                # print('{} {} {}'.format(sample, y.shape[0], se[0][-1]-se[0][0]))
                # print('{} {} {} {} {} {}'.format(sample, y.shape[0], y.shape[1], y.shape[2], se[0][0], se[0][-1]))
                labels += [y]

        # raise Exception()

        # inputs = []
        # labels = []
        # for sample in samples:

        #     if 'pickle' not in sample:
        #         print(sample)
        #         x = io.imread('{}/{}/DICOM_anon/volume_resampled_2.tif'.format(data_root, sample))

        #         inputs += [x]
        #         # print(sample)
        #         # print(x.shape)

        #         y = io.imread('{}/{}/Ground/labels_resampled_2.tif'.format(data_root, sample))
        #         y = np.int64(y/255)
        #         y = crop(y, (D, H, W), (center_z, center_y, center_x))
        #         y = torch.from_numpy(y)
        #         labels += [y]
        # raise Exception()
        # print('loaded')
        # fixed seed so the train/test shuffle is reproducible
        np.random.seed(1)
        perm = np.random.permutation(len(inputs))
        tr_length = cfg.training_set_size
        counts = [perm[:tr_length], perm[len(inputs) // 2:]]
        # counts = [perm[:tr_length], perm[16:]]

        data = {}
        down_sample_shape = cfg.hint_patch_shape

        input_shape = x.shape
        scale_factor = (np.max(down_sample_shape) / np.max(input_shape))

        for i, datamode in enumerate([DataModes.TRAINING, DataModes.TESTING]):

            samples = []
            print(i)
            print('--')

            for j in counts[i]:
                print(j)
                x = inputs[j]
                y = labels[j]

                x = F.interpolate(x[None, None],
                                  scale_factor=scale_factor,
                                  mode='trilinear')[0, 0]
                y = F.interpolate(y[None, None].float(),
                                  scale_factor=scale_factor,
                                  mode='nearest')[0, 0].long()

                new_samples = sample_to_sample_plus([Sample(x, y)], cfg,
                                                    datamode)
                samples.append(new_samples[0])

            with open(
                    data_root + '/pre_loaded_data_{}_{}.pickle'.format(
                        datamode, "_".join(map(str, down_sample_shape))),
                    'wb') as handle:
                pickle.dump(samples, handle, protocol=pickle.HIGHEST_PROTOCOL)

            data[datamode] = ChaosDataset(samples, cfg, datamode)
        print('-end-')
        data[DataModes.TRAINING_EXTENDED] = ChaosDataset(
            data[DataModes.TRAINING].data, cfg, DataModes.TRAINING_EXTENDED)
        data[DataModes.VALIDATION] = data[DataModes.TESTING]
        # raise Exception()
        return data
Example #10
    def load_data(self, cfg):
        data_root = cfg.data_root
        samples = [dir for dir in os.listdir(data_root)]

        prepare_samples = []

        for sample in samples:
            if 'pickle' not in sample:
                print(sample)

                x = []
                images_path = [dir for dir in os.listdir(
                    '{}/{}/DICOM_anon'.format(data_root, sample)) if 'dcm' in dir]
                for image_path in images_path:
                    file = pydicom.dcmread(
                        '{}/{}/DICOM_anon/{}'.format(data_root, sample, image_path))
                    x += [file.pixel_array]

                d_resolution = file.SliceThickness
                h_resolution, w_resolution = file.PixelSpacing
                x = np.float32(np.array(x))

                # clip: x
                # CHAOS CHALLENGE: MedianCHAOS
                # Vladimir Groza from Median Technologies: CHAOS 1st place solution overview.
                # embed()
                # x[x<(1000-160)] = 1000-160
                # x[x>(1000+240)] = 1000+240
                # x = (x - x.min())/(x.max()-x.min())

                # io.imsave('/cvlabdata2/cvlab/datasets_udaranga/check1006.tif', np.uint8(x * 255))
                # x = io.imread('{}/{}/DICOM_anon/volume.tif'.format(data_root, sample))
                # x = np.float32(x)/2500
                # x[x>1] = 1
                #
                D, H, W = x.shape
                D = int(D * d_resolution)
                H = int(H * h_resolution)
                W = int(W * w_resolution)
                # we resample so that one voxel spans 1 mm along x, y and z
                base_grid = torch.zeros((1, D, H, W, 3))
                w_points = (torch.linspace(-1, 1, W) if W >
                            1 else torch.Tensor([-1]))
                h_points = (torch.linspace(-1, 1, H) if H >
                            1 else torch.Tensor([-1])).unsqueeze(-1)
                d_points = (torch.linspace(-1, 1, D) if D >
                            1 else torch.Tensor([-1])).unsqueeze(-1).unsqueeze(-1)
                base_grid[:, :, :, :, 0] = w_points
                base_grid[:, :, :, :, 1] = h_points
                base_grid[:, :, :, :, 2] = d_points

                grid = base_grid  # .cuda() TODO

                x = torch.from_numpy(x)  # .cuda() TODO
                x = F.grid_sample(
                    x[None, None], grid, mode='bilinear', padding_mode='border')[0, 0]
                x = x.data.cpu().numpy()
                # ----

                x = np.float32(x)
                mean_x = np.mean(x)
                std_x = np.std(x)

                D, H, W = x.shape
                center_z, center_y, center_x = D // 2, H // 2, W // 2
                D, H, W = cfg.pad_shape
                x = crop(x, (D, H, W), (center_z, center_y, center_x))

                # io.imsave('{}/{}/DICOM_anon/volume_resampled_2.tif'.format(data_root, sample), np.uint16(x))

                x = (x - mean_x)/std_x
                x = torch.from_numpy(x)

                # ----

                y = []
                images_path = [dir for dir in os.listdir(
                    '{}/{}/Ground'.format(data_root, sample)) if 'png' in dir]
                for image_path in images_path:
                    file = io.imread(
                        '{}/{}/Ground/{}'.format(data_root, sample, image_path))
                    y += [file]

                y = np.array(y)
                y = np.int64(y)

                y = torch.from_numpy(y)  # .cuda() TODO
                y = F.grid_sample(y[None, None].float(), grid,
                                  mode='nearest', padding_mode='border')[0, 0]
                y = y.data.cpu().numpy()

                y = np.int64(y)
                y = crop(y, (D, H, W), (center_z, center_y, center_x))

                y = torch.from_numpy(y/255)

                prepare_samples.append(PrepareSample(x, y, sample))

        if not os.path.exists(cfg.data_root):
            os.makedirs(cfg.data_root)

        with open(cfg.loaded_data_path, 'wb') as handle:
            pickle.dump(prepare_samples, handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
Example #11
def sample_to_sample_plus(samples, cfg, datamode):

    new_samples = []
    # surface_point_count = 100
    for sample in samples:
        if cfg.low_resolution is not None:
            x_super_res = sample.x.cuda().float()
            y_super_res = sample.y.cuda().long()

            high_res, _, _ = x_super_res.shape
            D = high_res // cfg.low_resolution[0]
            K = torch.zeros(1, 1, D, D, D).cuda().float()
            K[0, 0, D // 2 - 1:D // 2 + 1, D // 2 - 1:D // 2 + 1,
              D // 2 - 1:D // 2 + 1] = 1
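            # K sums the central 2x2x2 voxels of each DxDxD block; with
            # stride=D the conv3d below is an unnormalized box-filter
            # downsample, and "> 4" on the label sum keeps a voxel only when
            # at least 5 of the 8 pooled voxels are foreground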

            x = F.conv3d(x_super_res[None, None], K, bias=None, stride=D)[0, 0]
            y = (F.conv3d(
                y_super_res[None, None].float(), K, bias=None, stride=D)[0, 0]
                 > 4).long()
            # x_       = F.interpolate(x_super_res[None, None], cfg.low_resolution, mode='trilinear')[0, 0]
            # y_       = F.interpolate(y_super_res[None, None].float(), cfg.low_resolution, mode='nearest').long()[0, 0]

            # embed()

            y_outer = sample_outer_surface_in_voxel(y_super_res)
            shape = torch.tensor(y_super_res.shape)[None].float()
            y_outer = torch.nonzero(y_outer)
            y_outer = torch.flip(y_outer, dims=[1])  # x,y,z
            y_outer = normalize_vertices(y_outer, shape)

            # io.imsave('/cvlabdata2/cvlab/datasets_udaranga/check1006.tif', np.uint8(x_super_res.data.cpu().numpy() * 255))
            # vertices_ = torch.floor(y_outer * 32 + 32).long()
            # # vertices_ = torch.floor(y_outer.float()/6).long()
            # y_outer_ = torch.zeros_like(y)
            # y_outer_[vertices_[:,2], vertices_[:,1], vertices_[:,0]] = 1
            # y_outer_ = y_outer_ + 3*y
            # y_outer_ = y_outer_.data.cpu().numpy()
            # io.imsave('/cvlabdata2/cvlab/datasets_udaranga/check1006.tif', np.uint8(y_outer_ * 63))

            high_res, _, _ = x_super_res.shape
            D = high_res // 64
            K = torch.zeros(1, 1, D, D, D).cuda().float()
            K[0, 0, D // 2 - 1:D // 2 + 1, D // 2 - 1:D // 2 + 1,
              D // 2 - 1:D // 2 + 1] = 1

            x_super_res = F.conv3d(x_super_res[None, None],
                                   K,
                                   bias=None,
                                   stride=D)[0, 0]
            y_super_res = (F.conv3d(
                y_super_res[None, None].float(), K, bias=None, stride=D)[0, 0]
                           > 4).long()
            raise NotImplementedError('low_resolution pre-processing path is unfinished')
            # y_super_res = y_super_res.long()

        else:
            x = sample.x
            y = sample.y

            y = (y > 0).long()

            # y_outer = sample_outer_surface_in_voxel(y)
            # y_outer = torch.nonzero(y_outer)

            center = tuple([d // 2 for d in x.shape])
            x = crop(x, cfg.hint_patch_shape, center)
            y = crop(y, cfg.hint_patch_shape, center)

            shape = torch.tensor(y.shape)[None].float()
            y_outer = sample_outer_surface_in_voxel(y)

            # point_count = 100
            # # print(point_count)
            # idxs = torch.nonzero(border)
            # y_outer = torch.zeros_like(y)
            # perm = torch.randperm(len(idxs))
            # idxs = idxs[perm[:point_count]]
            # y_outer[idxs[:,0], idxs[:,1], idxs[:,2]] = 1

            # if datamode == DataModes.TRAINING:
            #     D,H,W = y.shape

            #     start = None
            #     end = None
            #     for k in range(D):
            #         if start is None and torch.sum(y[k]) > 0:
            #             start = k

            #         if start is not None and end is None and torch.sum(y[k]) == 0:
            #             end = k

            #     slc = random.randint(start,end)

            #     y_outer = torch.zeros_like(y)
            #     y_outer[slc] = 1

            #     temp = torch.zeros_like(y)
            #     temp[slc] = y[slc]
            #     y = temp

            # else:
            #     y_outer = y

            # io.imsave('/cvlabdata1/cvlab/datasets_udaranga/y.tif', 255*np.uint8(data['y_voxels'].data.cpu().numpy()))
            # io.imsave('/cvlabdata1/cvlab/datasets_udaranga/y.tif', 255*np.uint8(y))
            # io.imsave('/cvlabdata1/cvlab/datasets_udaranga/y_outer.tif', 255*np.uint8(y_outer))
            # io.imsave('/cvlabdata1/cvlab/datasets_udaranga/check.tif', 255*np.uint8(y_outer))

            # y_outer = sample_outer_surface_in_voxel(y)
            # y_outer = torch.nonzero(y_outer)
            # y_outer = torch.flip(y_outer, dims=[1]) # x,y,z
            # y_outer = normalize_vertices(y_outer, shape)

            # perm = torch.randperm(len(y_outer))
            # point_count = 500
            # y_outer = y_outer[perm[:np.min([len(perm), point_count])]]  # randomly pick 3000 points

            x_super_res = torch.tensor(1)
            y_super_res = torch.tensor(1)

        # w = torch.zeros_like(y)
        w = torch.tensor(1)

        # y_dst = ndimage.distance_transform_edt(1-y_outer.data.cpu().numpy()) #/ center[0]
        # y_dst = torch.from_numpy(y_dst).float()[None].cuda()

        new_samples += [
            SamplePlus(x.cpu(),
                       y.cpu(),
                       y_outer=y_outer.cpu(),
                       w=w.cpu(),
                       x_super_res=x_super_res.cpu(),
                       y_super_res=y_super_res.cpu(),
                       shape=shape)
        ]

    return new_samples
Example #12
    def __getitem__(self, idx):
        item = self.data[idx]
        while True:
            x = torch.from_numpy(item.x).cuda()[None]
            y = torch.from_numpy(item.y).cuda().long()
            # y[y == 2] = 0 ## now y==2 means inside points
            y[y == 3] = 0
            # y[y==3] = 1
            if self.base_sparse_plane is not None:
                base_plane = torch.from_numpy(self.base_sparse_plane[idx]).cuda().float()
            else:
                base_plane = torch.ones_like(y).float()
            # breakpoint()
            C, D, H, W = x.shape
            center = (D//2, H//2, W//2)
            y = y.long()

            if self.mode == DataModes.TRAINING_EXTENDED: # if training do augmentation

                orientation = torch.tensor([0, -1, 0]).float()
                new_orientation = (torch.rand(3) - 0.5) * 2 * np.pi
                # new_orientation[2] = new_orientation[2] * 0 # no rotation outside x-y plane
                new_orientation = F.normalize(new_orientation, dim=0)
                q = orientation + new_orientation
                q = F.normalize(q, dim=0)
                theta_rotate = stns.stn_quaternion_rotations(q)

                shift = torch.tensor([d / (D // 2) for d, D in zip(2 * (torch.rand(3) - 0.5) * self.cfg.augmentation_shift_range, y.shape)])
                theta_shift = stns.shift(shift)

                f = 0.1
                scale = 1.0 - 2 * f *(torch.rand(1) - 0.5)
                theta_scale = stns.scale(scale)

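                # compose rotation, shift and scale into a single 4x4 affine
                # so the volume is resampled only once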
                theta = theta_rotate @ theta_shift @ theta_scale

                x, y, base_plane = stns.transform(theta, x, y, w=base_plane)
            else:
                pose = torch.zeros(6).cuda()
                # w = torch.zeros_like(y)
                # base_plane = torch.ones_like(y)
                theta = torch.eye(4).cuda()

            x_super_res = torch.tensor(1)
            y_super_res = torch.tensor(1)

            x = crop(x, (C,) + self.cfg.patch_shape, (0,) + center)
            y = crop(y, self.cfg.patch_shape, center)
            base_plane = crop(base_plane, self.cfg.patch_shape, center)


            ## change for model_id = 4
            if self.point_model is not None:
                surface_points = torch.nonzero((y == 1))
                y_outer = torch.zeros_like(y)
                y_outer[surface_points[:, 0], surface_points[:, 1], surface_points[:, 2]] = 1
                y[y == 2] = 1

            surface_points_normalized_all = []
            vertices_mc_all = []
            faces_mc_all = []

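            # per foreground class: marching-cubes ground-truth mesh (only
            # outside TRAINING_EXTENDED) plus a normalized surface point cloud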
            for i in range(1, self.cfg.num_classes):
                shape = torch.tensor(y.shape)[None].float()
                if self.mode != DataModes.TRAINING_EXTENDED:
                    gap = 1
                    y_ = clean_border_pixels((y==i).long(), gap=gap)
                    vertices_mc, faces_mc = voxel2mesh(y_, gap, shape)
                    vertices_mc_all += [vertices_mc]
                    faces_mc_all += [faces_mc]


                sphere_vertices = self.cfg.sphere_vertices
                atlas_faces = self.cfg.sphere_faces
                # self.sphere_vertices = sphere_ssvertices.repeat(self.config.config.batch_size,1,1).float()

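                # polar (p) and azimuthal (t) angles of the atlas sphere
                # vertices, kept differentiable (presumably so gradients can
                # reach the spherical parameterization)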
                p = torch.acos(sphere_vertices[:,2]).cuda()
                t = torch.atan2(sphere_vertices[:,1], sphere_vertices[:,0]).cuda()
                p = p.clone().detach().requires_grad_(True)
                t = t.clone().detach().requires_grad_(True)

                ## change for model_id = 4
                if self.point_model is None:

                    y_outer = sample_outer_surface_in_voxel((y==i).long())
                    surface_points = torch.nonzero(y_outer)

                surface_points = torch.flip(surface_points, dims=[1]).float()  # convert z,y,x -> x, y, z
                surface_points_normalized = normalize_vertices(surface_points, shape)
                # surface_points_normalized = y_outer

                # perm = torch.randperm(len(surface_points_normalized))
                N = len(surface_points_normalized)

                surface_points_normalized_all += [surface_points_normalized.cuda()]
            if N > 0:
                break
            else:
                print("re-applying deformation coz N=0")

        # print('in')
        # breakpoint()
        if self.mode == DataModes.TRAINING_EXTENDED:
            return {   'x': x,
                       'faces_atlas': atlas_faces,
                       'y_voxels': y,
                       'surface_points': surface_points_normalized_all,
                       'p':p,
                       't':t,
                       'unpool':self.cfg.unpool_indices,
                       'w': y_outer,
                       'theta': theta.inverse()[:3],
                       'base_plane' : base_plane
                    }
        else:
            return {   'x': x,
                       'x_super_res': x_super_res,
                       'faces_atlas': atlas_faces,
                       'y_voxels': y,
                       'y_voxels_super_res': y_super_res,
                       'vertices_mc': vertices_mc_all,
                       'faces_mc': faces_mc_all,
                       'surface_points': surface_points_normalized_all,
                       'p':p,
                       't':t,
                       'unpool':[0, 1, 0, 1, 1],
                       'theta': theta.inverse()[:3],
                       'base_plane': base_plane
                    }