-
Notifications
You must be signed in to change notification settings - Fork 0
/
selectors.py
137 lines (99 loc) · 4.38 KB
/
selectors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
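"""Select indices from a matplotlib collection using `LassoSelector`.

Selected indices are saved in the `ind` attribute. This tool highlights
the selection by recoloring the points: selected points are drawn opaque
yellow and all other points blue with alpha `alpha_other`. Because the
face colors are overwritten, any original colors (including alpha < 1)
on the collection are permanently replaced.

Note that this tool selects collection objects based on their *origins*
(i.e., `offsets`). The selection is also passed to the viewer so the
matching voxels are highlighted in its 2D and 3D VTK views.

Parameters
----------
viewer : object
    Application viewer that supplies the voxel data (`coordinates`,
    `meanx`, `doseplans`, `volumedata`) and the 2D/3D views to refresh.
ax : :class:`~matplotlib.axes.Axes`
    Axes to interact with.
collection : :class:`matplotlib.collections.Collection` subclass
    Collection you want to select from.
alpha_other : 0 <= float <= 1
    To highlight a selection, this tool sets all selected points to an
    alpha value of 1 and non-selected points to `alpha_other`.
"""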
import numpy as np
from matplotlib.widgets import LassoSelector
from matplotlib.path import Path
import vtk
class SelectFromCollection(object):
    """Lasso-select points of a matplotlib collection and mirror the
    selection into the viewer's 2D and 3D VTK views.

    Parameters
    ----------
    viewer : object
        Application viewer. This class reads ``coordinates``, ``meanx``,
        ``doseplans["p1"]``, ``volumedata``, ``ren_iso``, ``vol`` and
        ``volMapper`` from it and calls ``refresh_2d(image)`` and
        ``refresh_3d()``.
    ax : matplotlib.axes.Axes
        Axes to interact with.
    collection : matplotlib.collections.Collection
        Collection whose points (offsets) can be selected.
    alpha_other : float, optional
        Alpha (0..1) given to non-selected points. Default 0.3.

    Attributes
    ----------
    ind : numpy.ndarray or list
        Indices of the currently selected points (an empty list until the
        first lasso selection completes).
    """

    # RGB face colors used to display the selection in the scatter plot.
    _SELECTED_RGB = (1, 1, 0)      # yellow
    _UNSELECTED_RGB = (0, 0, 1)    # blue
    # Scalar written into the 2D highlight image at every selected voxel.
    _HIGHLIGHT_2D_VALUE = 60

    def __init__(self, viewer, ax, collection, alpha_other=0.3):
        self.canvas = ax.figure.canvas
        self.collection = collection
        self.viewer = viewer
        self.alpha_other = alpha_other

        self.xys = collection.get_offsets()
        self.Npts = len(self.xys)

        # Ensure there is one RGBA row per point so each can be recolored
        # individually.
        self.fc = collection.get_facecolors()
        if len(self.fc) == 0:
            raise ValueError('Collection must have a facecolor')
        elif len(self.fc) == 1:
            self.fc = np.tile(self.fc, self.Npts).reshape(self.Npts, -1)

        self.lasso = LassoSelector(ax, onselect=self.onselect)
        self.ind = []

    def onselect(self, verts):
        """Handle a finished lasso: update ``ind``, recolor the points and
        refresh the viewer's 2D and 3D highlights.

        Parameters
        ----------
        verts : sequence of (x, y)
            Vertices of the lasso polygon, as supplied by LassoSelector.
        """
        path = Path(verts)
        # Vectorized point-in-polygon test over all offsets at once.
        self.ind = np.nonzero(path.contains_points(self.xys))[0]

        voxels_xyz = self._selected_voxels()

        self._recolor()
        self.canvas.draw_idle()

        self.highlight_voxels_2D(voxels_xyz)
        self.highlight_voxels_3D(voxels_xyz)

    def _selected_voxels(self):
        """Return one ``[x, y, z, meanx]`` row per selected index.

        Each coordinate entry is copied so the viewer's ``coordinates``
        are left unchanged by repeated selections.
        """
        coordinates = self.viewer.coordinates
        meanx = self.viewer.meanx
        return [list(coordinates[i]) + [meanx[i]] for i in self.ind]

    def _recolor(self):
        """Paint selected points opaque yellow and all others blue with
        alpha ``alpha_other``, then push the colors to the collection."""
        self.fc[:, :3] = self._UNSELECTED_RGB
        self.fc[:, -1] = self.alpha_other
        self.fc[self.ind, :3] = self._SELECTED_RGB
        self.fc[self.ind, -1] = 1
        self.collection.set_facecolors(self.fc)

    @staticmethod
    def _blank_image_like(template):
        """Return a zero-filled, single-component double image with the
        same spacing, origin, dimensions and extent as ``template``."""
        image = vtk.vtkImageData()
        image.SetSpacing(template.GetSpacing())
        image.SetOrigin(template.GetOrigin())
        image.SetDimensions(template.GetDimensions())
        image.SetExtent(template.GetExtent())
        image.SetNumberOfScalarComponents(1)
        image.SetScalarTypeToDouble()
        image.AllocateScalars()
        # Freshly allocated VTK scalar memory is not guaranteed to be
        # zeroed; clear it so only the voxels written later are non-zero.
        image.GetPointData().GetScalars().FillComponent(0, 0.0)
        return image

    @staticmethod
    def _flip_y_filter(image):
        """Return an updated vtkImageFlip filter mirroring ``image`` along
        its Y axis (callers keep the filter and use ``GetOutput()``)."""
        flip = vtk.vtkImageFlip()
        flip.SetFilteredAxis(1)
        flip.SetInput(image)
        flip.Update()
        return flip

    def highlight_voxels_2D(self, coords):
        """Show the selected voxels in the viewer's 2D view.

        Writes ``_HIGHLIGHT_2D_VALUE`` at each position in an image shaped
        like ``viewer.doseplans["p1"]``, flips it along Y and passes the
        result to ``viewer.refresh_2d``.

        Parameters
        ----------
        coords : iterable of sequences
            Items whose first three entries are the voxel's x, y, z image
            indices (presumably integers -- confirm against the viewer).
        """
        image = self._blank_image_like(self.viewer.doseplans["p1"])
        for p in coords:
            image.SetScalarComponentFromDouble(
                p[0], p[1], p[2], 0, self._HIGHLIGHT_2D_VALUE)

        flip = self._flip_y_filter(image)
        self.viewer.refresh_2d(flip.GetOutput())

    def highlight_voxels_3D(self, coords):
        """Replace the viewer's 3D volume input with an image containing
        only the selected voxels, each set to its fourth coordinate value.

        Parameters
        ----------
        coords : iterable of sequences
            ``[x, y, z, value]`` items; ``value`` is written at (x, y, z).
        """
        self.viewer.ren_iso.RemoveVolume(self.viewer.vol)

        image = self._blank_image_like(self.viewer.volumedata)
        for p in coords:
            image.SetScalarComponentFromDouble(p[0], p[1], p[2], 0, p[3])

        # Cast the double image to unsigned char before it reaches the
        # volume mapper.
        shift_scale = vtk.vtkImageShiftScale()
        shift_scale.SetInput(image)
        shift_scale.SetOutputScalarTypeToUnsignedChar()
        shift_scale.Update()

        flip = self._flip_y_filter(shift_scale.GetOutput())
        self.viewer.volMapper.SetInput(flip.GetOutput())
        self.viewer.ren_iso.AddVolume(self.viewer.vol)
        self.viewer.refresh_3d()