Example 1
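These snippets rely on PyTorch, NumPy, and a handful of project-local helpers (stns, DataModes, crop, clean_border_pixels, voxel2mesh, sample_outer_surface_in_voxel, normalize_vertices). The block below sketches the imports they need; the project-local module paths are placeholders, not confirmed by the snippets, and must be adapted to the actual package layout.

import numpy as np
import torch
import torch.nn.functional as F

# Project-local helpers (paths below are placeholders):
# from utils import stns
# from utils.utils_common import DataModes, crop
# from utils.utils_mesh import clean_border_pixels, voxel2mesh, normalize_vertices, sample_outer_surface_in_voxel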
def get_item(item, mode, config):
    """Prepare one sample: move tensors to the GPU, augment during training, and
    extract per-class surface points plus voxel2mesh meshes (vertices_mc / faces_mc)
    for evaluation."""

    x = item.x.cuda()[None]
    y = item.y.cuda()
    y_outer = item.y_outer.cuda()
    shape = item.shape

    # Data augmentation is applied only during training.
    if mode == DataModes.TRAINING:
        if torch.rand(1)[0] > 0.5:
            x = x.permute([0, 1, 3, 2])
            y = y.permute([0, 2, 1])
            y_outer = y_outer.permute([0, 2, 1])

        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[1])
            y = torch.flip(y, dims=[0])
            y_outer = torch.flip(y_outer, dims=[0])

        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[2])
            y = torch.flip(y, dims=[1])
            y_outer = torch.flip(y_outer, dims=[1])

        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[3])
            y = torch.flip(y, dims=[2])
            y_outer = torch.flip(y_outer, dims=[2])

        orientation = torch.tensor([0, -1, 0]).float()
        new_orientation = (torch.rand(3) - 0.5) * 2 * np.pi
        new_orientation = F.normalize(new_orientation, dim=0)
        q = orientation + new_orientation
        q = F.normalize(q, dim=0)
        theta_rotate = stns.stn_quaternion_rotations(q)

        shift = torch.tensor([
            d / (D // 2) for d, D in zip(
                2 * (torch.rand(3) - 0.5) *
                config.augmentation_shift_range, y.shape)
        ])
        theta_shift = stns.shift(shift)

        f = 0.1
        scale = 1.0 - 2 * f * (torch.rand(1) - 0.5)
        theta_scale = stns.scale(scale)

        theta = theta_rotate @ theta_shift @ theta_scale

        x, y, y_outer = stns.transform(theta, x, y, y_outer)

    surface_points_normalized_all = []
    vertices_mc_all = []
    faces_mc_all = []
    for i in range(1, config.num_classes):
        shape = torch.tensor(y.shape)[None].float()
        if mode != DataModes.TRAINING:
            gap = 1
            y_ = clean_border_pixels((y == i).long(), gap=gap)
            vertices_mc, faces_mc = voxel2mesh(y_, gap, shape)
            vertices_mc_all += [vertices_mc]
            faces_mc_all += [faces_mc]

        # The outer surface is recomputed per class from the (possibly augmented) labels,
        # so the augmented y_outer from above is not reused here.
        y_outer = sample_outer_surface_in_voxel((y == i).long())
        # Convert surface voxel indices (z, y, x) to point coordinates (x, y, z) and normalize.
        surface_points = torch.nonzero(y_outer)
        surface_points = torch.flip(surface_points, dims=[1]).float()
        surface_points_normalized = normalize_vertices(surface_points, shape)

        # Randomly pick at most point_count surface points per class.
        perm = torch.randperm(len(surface_points_normalized))
        point_count = 3000
        surface_points_normalized_all += [
            surface_points_normalized[perm[:min(len(perm), point_count)]].cuda()
        ]

    if mode == DataModes.TRAINING:
        return {
            'x': x,
            'y_voxels': y,
            'surface_points': surface_points_normalized_all,
            'unpool': [0, 1, 0, 1, 0]
        }
    else:
        return {
            'x': x,
            'y_voxels': y,
            'vertices_mc': vertices_mc_all,
            'faces_mc': faces_mc_all,
            'surface_points': surface_points_normalized_all,
            'unpool': [0, 1, 1, 1, 1]
        }
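Example 1 composes its augmentation as theta = theta_rotate @ theta_shift @ theta_scale and hands it to stns.transform together with the image and label volumes. The helper below is a minimal sketch of how such a 4x4 affine can be applied to a 3D volume with F.affine_grid / F.grid_sample; apply_affine_3d is a hypothetical name, and the project's stns.transform may differ in its details (interpolation modes, padding, label handling). Note that affine_grid applies theta to the output sampling grid, so the volume content effectively moves by the inverse map, which is also why Example 2 multiplies its surface points by the inverse of theta.

def apply_affine_3d(volume, theta, mode='bilinear'):
    # volume: (C, D, H, W) tensor; theta: (4, 4) affine in normalized [-1, 1] coordinates.
    theta_3x4 = theta[:3].unsqueeze(0).to(volume.device)              # (1, 3, 4)
    grid = F.affine_grid(theta_3x4, torch.Size((1, *volume.shape)), align_corners=True)
    warped = F.grid_sample(volume[None].float(), grid, mode=mode, align_corners=True)
    return warped[0]

# Usage with the shapes used above (x is (1, D, H, W), y is (D, H, W)):
# x_aug = apply_affine_3d(x, theta)
# y_aug = apply_affine_3d(y[None].float(), theta, mode='nearest')[0].long()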
Example 2
def get_item_(item, mode, config):
    """Variant of get_item: item.y_outer already holds surface-point coordinates, and
    the sphere-atlas angles (p, t) are returned as trainable leaf tensors."""
    x = item.x.cuda()[None]
    y = item.y.cuda()
    y_outer = item.y_outer.cuda()
    w = item.w.cuda()
    x_super_res = item.x_super_res[None]
    y_super_res = item.y_super_res
    shape = item.shape

    # In this variant, item.y_outer already holds surface-point coordinates rather than a voxel mask.
    surface_points = y_outer

    # Data augmentation is applied only during training.
    if mode == DataModes.TRAINING_EXTENDED:
        # NOTE: a threshold of 0.0 means this transpose is effectively always applied;
        # the analogous branch in the other examples uses 0.5.
        if torch.rand(1)[0] > 0.0:
            x = x.permute([0, 1, 3, 2])
            y = y.permute([0, 2, 1])
            surface_points = surface_points[:, [1, 0, 2]]

        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[1])
            y = torch.flip(y, dims=[0])
            surface_points[:, 2] = -surface_points[:, 2]

        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[2])
            y = torch.flip(y, dims=[1])
            surface_points[:, 1] = -surface_points[:, 1]

        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[3])
            y = torch.flip(y, dims=[2])
            surface_points[:, 0] = -surface_points[:, 0]

        orientation = torch.tensor([0, -1, 0]).float()
        new_orientation = (torch.rand(3) - 0.5) * 2 * np.pi
        new_orientation = F.normalize(new_orientation, dim=0)
        q = orientation + new_orientation
        q = F.normalize(q, dim=0)
        theta_rotate = stns.stn_quaternion_rotations(q)

        shift = torch.tensor([
            d / (D // 2) for d, D in zip(
                2 * (torch.rand(3) - 0.5) *
                config.augmentation_shift_range, y.shape)
        ])
        theta_shift = stns.shift(shift)

        f = 0.1
        scale = 1.0 - 2 * f * (torch.rand(1) - 0.5)
        theta_scale = stns.scale(scale)

        theta = theta_rotate @ theta_shift @ theta_scale

        x, y, w = stns.transform(theta, x, y, w)

        # not necessary during training
        x_super_res = None
        y_super_res = None

        # Map the surface points with the inverse affine (homogeneous coordinates),
        # so that they stay consistent with the resampled volumes.
        theta_inv = theta_scale.inverse() @ theta_shift.inverse() @ theta_rotate.inverse()
        theta_inv = theta_inv[:3]
        surface_points = torch.cat(
            [surface_points, torch.ones(len(surface_points), 1).cuda()], dim=1)
        surface_points = (theta_inv.cuda() @ surface_points.float().permute(1, 0)).permute(1, 0)

    # Keep only points that fall strictly inside the normalized volume (-1, 1)^3.
    inside = (torch.all(surface_points > -1, dim=1)
              & torch.all(surface_points < 1, dim=1))
    surface_points = surface_points[inside]

    # Extract a reference mesh from the voxel labels.
    gap = 1
    y_ = clean_border_pixels(y, gap=gap)
    vertices_mc, faces_mc = voxel2mesh(y_, gap, torch.tensor(y.shape)[None].float())

    sphere_vertices = config.sphere_vertices
    sphere_faces = config.sphere_faces

    # Spherical coordinates of the sphere-atlas vertices, kept as leaf tensors that require grad
    # (equivalent to torch.tensor(p, requires_grad=True) but without the copy-construct warning).
    p = torch.acos(sphere_vertices[:, 2])
    t = torch.atan2(sphere_vertices[:, 1], sphere_vertices[:, 0])
    p = p.clone().detach().requires_grad_(True)
    t = t.clone().detach().requires_grad_(True)

    if mode == DataModes.TRAINING_EXTENDED:
        return {
            'x': x,
            'faces_atlas': sphere_faces,
            'y_voxels': y,
            'surface_points': surface_points,
            'p': p,
            't': t,
            'unpool': config.unpool_indices
        }
    else:
        return {
            'x': x,
            'x_super_res': x_super_res,
            'faces_atlas': sphere_faces,
            'y_voxels': y,
            'y_voxels_super_res': y_super_res,
            'vertices_mc': vertices_mc,
            'faces_mc': faces_mc,
            'surface_points': surface_points,
            'p': p,
            't': t,
            'unpool': [1, 1, 1, 0, 0]
        }
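Example 2 (and the later variants) parameterize the sphere-atlas vertices by the polar angle p = acos(z) and the azimuth t = atan2(y, x), kept as trainable leaf tensors. The inverse mapping back to Cartesian points is the usual spherical-coordinate reconstruction; below is a small self-contained round-trip check of that relation (sphere_from_angles is a hypothetical helper, and unit-norm vertices are assumed, which is what a sphere atlas provides).

def sphere_from_angles(p, t):
    # Rebuild unit-sphere points from polar angle p and azimuth t.
    x_ = torch.sin(p) * torch.cos(t)
    y_ = torch.sin(p) * torch.sin(t)
    z_ = torch.cos(p)
    return torch.stack([x_, y_, z_], dim=1)

# Round trip on random unit vectors:
v = F.normalize(torch.randn(100, 3, dtype=torch.float64), dim=1)
p = torch.acos(v[:, 2].clamp(-1, 1))
t = torch.atan2(v[:, 1], v[:, 0])
assert torch.allclose(sphere_from_angles(p, t), v, atol=1e-6)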
Example 3
def get_item__(item, mode, config):
    """Variant of get_item_ that treats item.y_outer as a voxel mask and re-extracts
    and subsamples the surface points after augmentation."""

    x = item.x.cuda()[None]
    y = item.y.cuda()
    y_outer = item.y_outer.cuda()
    w = item.w.cuda()
    x_super_res = item.x_super_res[None]
    y_super_res = item.y_super_res
    shape = item.shape

    # Data augmentation is applied only during training.
    if mode == DataModes.TRAINING_EXTENDED:
        if torch.rand(1)[0] > 0.5:
            x = x.permute([0, 1, 3, 2])
            y = y.permute([0, 2, 1])
            y_outer = y_outer.permute([0, 2, 1])

        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[1])
            y = torch.flip(y, dims=[0])
            y_outer = torch.flip(y_outer, dims=[0])

        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[2])
            y = torch.flip(y, dims=[1])
            y_outer = torch.flip(y_outer, dims=[1])

        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[3])
            y = torch.flip(y, dims=[2])
            y_outer = torch.flip(y_outer, dims=[2])

        orientation = torch.tensor([0, -1, 0]).float()
        new_orientation = (torch.rand(3) - 0.5) * 2 * np.pi
        # new_orientation[2] = new_orientation[2] * 0 # no rotation outside x-y plane
        new_orientation = F.normalize(new_orientation, dim=0)
        q = orientation + new_orientation
        q = F.normalize(q, dim=0)
        theta_rotate = stns.stn_quaternion_rotations(q)

        shift = torch.tensor([
            d / (D // 2) for d, D in zip(
                2 * (torch.rand(3) - 0.5) *
                config.augmentation_shift_range, y.shape)
        ])
        theta_shift = stns.shift(shift)

        f = 0.1
        scale = 1.0 - 2 * f * (torch.rand(1) - 0.5)
        theta_scale = stns.scale(scale)

        theta = theta_rotate @ theta_shift @ theta_scale

        x, y, y_outer = stns.transform(theta, x, y, y_outer)

        # not necessary during training
        x_super_res = None
        y_super_res = None

    if mode != DataModes.TRAINING_EXTENDED:
        gap = 1
        y_ = clean_border_pixels(y, gap=gap)
        vertices_mc, faces_mc = voxel2mesh(y_, gap, shape)

    sphere_vertices = config.sphere_vertices
    atlas_faces = config.sphere_faces

    # Spherical coordinates of the sphere-atlas vertices, kept as leaf tensors that require grad.
    p = torch.acos(sphere_vertices[:, 2]).cuda()
    t = torch.atan2(sphere_vertices[:, 1], sphere_vertices[:, 0]).cuda()
    p = p.clone().detach().requires_grad_(True)
    t = t.clone().detach().requires_grad_(True)

    # Convert surface voxel indices (z, y, x) to point coordinates (x, y, z) and normalize.
    surface_points = torch.nonzero(y_outer)
    surface_points = torch.flip(surface_points, dims=[1]).float()
    surface_points_normalized = normalize_vertices(surface_points, shape)

    # Randomly pick at most point_count surface points.
    perm = torch.randperm(len(surface_points_normalized))
    point_count = 3000
    surface_points_normalized = surface_points_normalized[
        perm[:min(len(perm), point_count)]]

    if mode == DataModes.TRAINING_EXTENDED:
        return {
            'x': x,
            'faces_atlas': atlas_faces,
            'y_voxels': y,
            'surface_points': surface_points_normalized,
            'p': p,
            't': t,
            'unpool': config.unpool_indices,
            'w': y_outer
        }
    else:
        return {
            'x': x,
            'x_super_res': x_super_res,
            'faces_atlas': atlas_faces,
            'y_voxels': y,
            'y_voxels_super_res': y_super_res,
            'vertices_mc': vertices_mc,
            'faces_mc': faces_mc,
            'surface_points': surface_points_normalized,
            'p': p,
            't': t,
            'unpool': [0, 1, 0, 1, 0]
        }
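Examples 1 and 4 build y_outer on the fly with sample_outer_surface_in_voxel, and Example 3 consumes a pre-computed mask of the same kind. A common way to obtain such a one-voxel outer shell from a binary volume is to dilate it with 3-D max pooling and subtract the original mask; the sketch below (outer_shell is a hypothetical name) illustrates that idea and is not necessarily the project's exact implementation, which may use a different neighbourhood.

def outer_shell(volume):
    # volume: (D, H, W) binary {0, 1} tensor.
    v = volume[None, None].float()                                 # (1, 1, D, H, W)
    dilated = F.max_pool3d(v, kernel_size=3, stride=1, padding=1)  # one-voxel dilation
    return (dilated - v).clamp(min=0)[0, 0].long()                 # background voxels touching the object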
Example 4
    def __getitem__(self, idx):
        """Dataset-style variant: crops a patch around the volume centre, augments
        during training, and retries the augmentation if no surface points remain."""
        item = self.data[idx]
        while True:
            x = torch.from_numpy(item.x).cuda()[None]
            y = torch.from_numpy(item.y).cuda().long()
            # Label 2 marks inside points (kept here); label 3 is discarded.
            y[y == 3] = 0
            if self.base_sparse_plane is not None:
                base_plane = torch.from_numpy(self.base_sparse_plane[idx]).cuda().float()
            else:
                base_plane = torch.ones_like(y).float()
            C, D, H, W = x.shape
            center = (D // 2, H // 2, W // 2)

            # Data augmentation is applied only during training.
            if self.mode == DataModes.TRAINING_EXTENDED:

                orientation = torch.tensor([0, -1, 0]).float()
                new_orientation = (torch.rand(3) - 0.5) * 2 * np.pi
                # new_orientation[2] = new_orientation[2] * 0 # no rotation outside x-y plane
                new_orientation = F.normalize(new_orientation, dim=0)
                q = orientation + new_orientation
                q = F.normalize(q, dim=0)
                theta_rotate = stns.stn_quaternion_rotations(q)

                shift = torch.tensor([d / (D // 2) for d, D in zip(2 * (torch.rand(3) - 0.5) * self.cfg.augmentation_shift_range, y.shape)])
                theta_shift = stns.shift(shift)

                f = 0.1
                scale = 1.0 - 2 * f *(torch.rand(1) - 0.5)
                theta_scale = stns.scale(scale)

                theta = theta_rotate @ theta_shift @ theta_scale

                x, y, base_plane = stns.transform(theta, x, y, w=base_plane)
            else:
                # No augmentation: keep the identity transform.
                theta = torch.eye(4).cuda()

            x_super_res = torch.tensor(1)
            y_super_res = torch.tensor(1)

            x = crop(x, (C,) + self.cfg.patch_shape, (0,) + center)
            y = crop(y, self.cfg.patch_shape, center)
            base_plane = crop(base_plane, self.cfg.patch_shape, center)


            ## change for model_id = 4
            if self.point_model is not None:
                surface_points = torch.nonzero((y == 1))
                y_outer = torch.zeros_like(y)
                y_outer[surface_points[:, 0], surface_points[:, 1], surface_points[:, 2]] = 1
                y[y == 2] = 1

            surface_points_normalized_all = []
            vertices_mc_all = []
            faces_mc_all = []

            for i in range(1, self.cfg.num_classes):
                shape = torch.tensor(y.shape)[None].float()
                if self.mode != DataModes.TRAINING_EXTENDED:
                    gap = 1
                    y_ = clean_border_pixels((y==i).long(), gap=gap)
                    vertices_mc, faces_mc = voxel2mesh(y_, gap, shape)
                    vertices_mc_all += [vertices_mc]
                    faces_mc_all += [faces_mc]


                sphere_vertices = self.cfg.sphere_vertices
                atlas_faces = self.cfg.sphere_faces

                # Spherical coordinates of the sphere-atlas vertices, kept as leaf tensors that require grad.
                p = torch.acos(sphere_vertices[:, 2]).cuda()
                t = torch.atan2(sphere_vertices[:, 1], sphere_vertices[:, 0]).cuda()
                p = p.clone().detach().requires_grad_(True)
                t = t.clone().detach().requires_grad_(True)

                ## change for model_id = 4
                if self.point_model is None:

                    y_outer = sample_outer_surface_in_voxel((y==i).long())
                    surface_points = torch.nonzero(y_outer)

                # Convert surface voxel indices (z, y, x) to point coordinates (x, y, z) and normalize.
                surface_points = torch.flip(surface_points, dims=[1]).float()
                surface_points_normalized = normalize_vertices(surface_points, shape)

                N = len(surface_points_normalized)
                surface_points_normalized_all += [surface_points_normalized.cuda()]
            # Retry the augmentation if the last class ended up with no surface points.
            if N > 0:
                break
            else:
                print("re-applying deformation because N == 0")

        if self.mode == DataModes.TRAINING_EXTENDED:
            return {
                'x': x,
                'faces_atlas': atlas_faces,
                'y_voxels': y,
                'surface_points': surface_points_normalized_all,
                'p': p,
                't': t,
                'unpool': self.cfg.unpool_indices,
                'w': y_outer,
                'theta': theta.inverse()[:3],
                'base_plane': base_plane
            }
        else:
            return {
                'x': x,
                'x_super_res': x_super_res,
                'faces_atlas': atlas_faces,
                'y_voxels': y,
                'y_voxels_super_res': y_super_res,
                'vertices_mc': vertices_mc_all,
                'faces_mc': faces_mc_all,
                'surface_points': surface_points_normalized_all,
                'p': p,
                't': t,
                'unpool': [0, 1, 0, 1, 1],
                'theta': theta.inverse()[:3],
                'base_plane': base_plane
            }
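Example 4 is a Dataset-style __getitem__, so in practice it lives inside a torch.utils.data.Dataset subclass and is consumed through a DataLoader. The wrapper below is hypothetical: the class name MeshSamples and its constructor arguments merely mirror the attributes the snippet reads (data, cfg, mode, base_sparse_plane, point_model). batch_size=1 and num_workers=0 are used because each item is built directly on the GPU and contains variable-length lists of tensors.

from torch.utils.data import DataLoader, Dataset

class MeshSamples(Dataset):  # hypothetical wrapper around the __getitem__ above
    def __init__(self, data, cfg, mode, base_sparse_plane=None, point_model=None):
        self.data = data
        self.cfg = cfg
        self.mode = mode
        self.base_sparse_plane = base_sparse_plane
        self.point_model = point_model

    def __len__(self):
        return len(self.data)

    # __getitem__ as defined in Example 4 goes here

# loader = DataLoader(MeshSamples(data, cfg, DataModes.TRAINING_EXTENDED),
#                     batch_size=1, shuffle=True, num_workers=0)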