def reset_dataviews(self):
    """Replace all three dataviews with newly constructed ones.

    The display mode is put back to 'normal', then dv_3d, dv_mat and
    dv_circ are rebuilt (in that order) from DVMayavi, DVMatrix and
    DVCircle, each given this dataset. chg_scalar_colorbar() is called
    last, once the new views exist.
    """
    #in principle it might be useful to do some more cleanup here
    self.display_mode = 'normal'

    #each view is assigned right after it is built, same order as before
    view_factories = (
        ('dv_3d', DVMayavi),
        ('dv_mat', DVMatrix),
        ('dv_circ', DVCircle),
    )
    for view_attr, view_factory in view_factories:
        setattr(self, view_attr, view_factory(self))

    self.chg_scalar_colorbar()
def __init__(self, name, lab_pos, labnam, srf, labv,
             gui=None, adj=None, soft_max_edges=20000, **kwargs):
    """Build a dataset from parcellation data and, optionally, a matrix.

    name, gui -- stored on the instance unchanged.
    lab_pos, labnam, srf, labv -- parcellation data, stored unchanged;
        nr_labels is taken from len(labnam).
    adj -- adjacency matrix, or None. When None no edge data is generated
        here; per the original notes it will be supplied later by the user.
    soft_max_edges -- stored (and used by the edge helpers) only when adj
        is given.
    kwargs -- forwarded to the parent class constructor.
    """
    super(Dataset, self).__init__(**kwargs)

    self.gui, self.name = gui, name
    self.opts = DisplayOptions(self)
    self.scalar_display_settings = ScalarDisplaySettings(self)

    #this is effectively load_parc. load_parc itself is not called here
    #because it redundantly sets the current display.
    self.lab_pos = lab_pos
    self.labnam = labnam
    self.srf = srf
    self.labv = labv
    self.nr_labels = len(labnam)

    #this is load_adj, except without initializing nonexistent dataviews.
    #if adj is None, it is guaranteed to be supplied later by the user.
    adj_supplied = adj is not None
    if adj_supplied:
        self.adj = adj
        self.soft_max_edges = soft_max_edges
        self.pos_helper_gen()
        #flip adj ord should already be done to the preprocessed adj
        self.adj_helper_gen()

    self.color_legend = ColorLegend()
    self.node_colors_gen()

    #build the three views in the order 3D, matrix, circle
    for view_attr, view_factory in (('dv_3d', DVMayavi),
                                    ('dv_mat', DVMatrix),
                                    ('dv_circ', DVCircle)):
        setattr(self, view_attr, view_factory(self))

    self.chg_scalar_colorbar()
def reset_dataviews(self):
    """Discard the existing dataviews and construct a fresh set.

    Sets display_mode to 'normal', rebuilds dv_3d (DVMayavi), dv_mat
    (DVMatrix) and dv_circ (DVCircle) for this dataset, and finally calls
    chg_scalar_colorbar().
    """
    #in principle it might be useful to do some more cleanup here
    self.display_mode = 'normal'

    #build and attach one view at a time, keeping the original order
    mayavi_view = DVMayavi(self)
    self.dv_3d = mayavi_view

    matrix_view = DVMatrix(self)
    self.dv_mat = matrix_view

    circle_view = DVCircle(self)
    self.dv_circ = circle_view

    self.chg_scalar_colorbar()
def __init__(self, name, lab_pos, labnam, srf, labv,
             gui=None, adj=None, soft_max_edges=20000, **kwargs):
    """Initialise the dataset's state, helpers and dataviews.

    The parcellation arguments (lab_pos, labnam, srf, labv) are stored as
    given and nr_labels is set to len(labnam). If adj is provided it is
    stored together with soft_max_edges and the edge helper arrays are
    generated; if adj is None that step is skipped (the matrix is expected
    to be supplied later by the user). Extra keyword arguments go to the
    parent constructor.
    """
    super(Dataset, self).__init__(**kwargs)
    self.gui = gui
    self.name = name
    self.opts = DisplayOptions(self)
    self.scalar_display_settings = ScalarDisplaySettings(self)

    #this is effectively load_parc
    #(load_parc redundantly sets the current display, so it is not called)
    self.lab_pos = lab_pos
    self.labnam = labnam
    self.srf = srf
    self.labv = labv
    self.nr_labels = len(labnam)

    #this is load adj, except without initializing nonexistent dataviews
    if adj is not None:
        self.adj = adj
        self.soft_max_edges = soft_max_edges
        #flip adj ord should already be done to the preprocessed adj
        for generate in (self.pos_helper_gen, self.adj_helper_gen):
            generate()

    self.color_legend = ColorLegend()
    self.node_colors_gen()

    mayavi_view = DVMayavi(self)
    self.dv_3d = mayavi_view
    matrix_view = DVMatrix(self)
    self.dv_mat = matrix_view
    circle_view = DVCircle(self)
    self.dv_circ = circle_view

    self.chg_scalar_colorbar()
class Dataset(HasTraits):
    """A parcellation, its adjacency matrix, and the views that display them.

    The dataset owns the node/edge arrays derived from the parcellation and
    adjacency matrix, the three dataviews (3D, matrix, circle), the display
    options, and any statistical data (modules, node scalars, graph stats)
    attached to it. Messages for the user are routed through `gui`.
    """

    ########################################################################
    # FUNDAMENTALLY NECESSARY DATA
    ########################################################################

    name = Str  #give this dataset a name
    gui = Any   #symbolic reference to a modular cvu

    nr_labels = Int
    nr_edges = Int

    labnam = List(Str)
    #adjlabfile=File
    #the adjlabfile is not needed. this is only kept on hand to pass it
    #around if specified as CLI arg. it is only used upon loading an adjmat.
    #so just have adjmat loading be a part of dataset creation and get rid of
    #keeping track of this
    adj = Any   #NxN np.ndarray
    adj_thresdiag = Property(depends_on='adj')  #NxN np.ndarray

    @cached_property
    def _get_adj_thresdiag(self):
        """Return a copy of adj whose diagonal holds the smallest nonzero
        value of adj."""
        adjt = self.adj.copy()
        adjt[np.where(np.eye(self.nr_labels))] = np.min(adjt[np.where(adjt)])
        return adjt

    starts = Any    #Ex3 np.ndarray
    vecs = Any      #Ex3 np.ndarray
    edges = Any     #Ex2 np.ndarray(int)

    srf = Instance(SurfData)

    #labv=List(Instance(mne.Label))
    #all that is needed from this is a map of name->vertex
    #this is a considerable portion of the data contained in a label but still
    #only perhaps 15%. To make lightweight, extract this from labv
    #TODO convert it to that
    labv = Dict     #is an OrderedDict in parcellation order

    lab_pos = Any   #Nx3 np.ndarray

    #########################################################################
    # CRITICAL NONADJUSTABLE DATA WITHOUT WHICH DISPLAY CANNOT EXIST
    #########################################################################

    dv_3d = Either(Instance(DataView), None)
    dv_mat = Either(Instance(DataView), None)
    dv_circ = Either(Instance(DataView), None)

    soft_max_edges = Int
    adjdat = Any    #Ex1 np.ndarray
    left = Any      #Nx1 np.ndarray(bool)
    right = Any     #Nx1 np.ndarray(bool)
    interhemi = Any #Nx1 np.ndarray(bool)
    masked = Any    #Nx1 np.ndarray(bool)

    lhnodes = Property(depends_on='labnam')     #Nx1 np.ndarray(int)
    rhnodes = Property(depends_on='labnam')     #Nx1 np.ndarray(int)

    @cached_property
    def _get_lhnodes(self):
        """Indices of the labels whose name starts with 'l'."""
        return np.where([r[0] == 'l' for r in self.labnam])[0]

    @cached_property
    def _get_rhnodes(self):
        """Indices of the labels whose name starts with 'r'."""
        return np.where([r[0] == 'r' for r in self.labnam])[0]

    #TODO
    node_colors = Any   #Nx3 np.ndarray
    #node_colors represents the colors held by the nodes. the current value of
    #node_colors depends on the current policy (i.e. the current display mode).
    #however, don't take this all too literally. depending on the current
    #policy, the dataviews may choose to ignore what is in node_colors and
    #use some different color.

    #this is always true of Mayavi views, who can't use the node colors at all.
    #because mayavi doesn't play nice with true colors (this could be fixed
    #if mayavi is fixed). it is also true of the other plots in scalar mode,
    #but when scalars are not specified for those dataviews.

    #node_colors_default will be set uniquely for each parcellation and can
    #thus be different for different datasets.

    #group_colors is more subtle; it can in principle be set uniquely for each
    #parcellation as long as the parcellations don't conform to aparc. for
    #instance, destrieux parc has different group colors. right now i'm a long
    #way away from dealing with this but i think in a month it will be prudent
    #to just have the dataset capture both of these variables

    node_colors_default = List
    node_labels_numberless = List(Str)

    group_colors = List
    nr_groups = Int
    group_labels = List(Str)
    color_legend = Instance(ColorLegend)

    module_colors = List
    default_glass_brain_color = Constant((.82, .82, .82))

    #########################################################################
    # ASSOCIATED STATISTICAL AND ANALYTICAL DATA
    #########################################################################

    node_scalars = Dict
    scalar_display_settings = Instance(ScalarDisplaySettings)
    #TODO make modules a dictionary
    modules = List
    nr_modules = Int
    graph_stats = Dict

    #########################################################################
    # ASSOCIATED DISPLAY OPTIONS AND DISPLAY STATE (ADJUSTABLE/TRANSIENT)
    #########################################################################

    opts = Instance(DisplayOptions)
    display_mode = Enum('normal', 'scalar', 'module_single', 'module_multi')
    reset_thresh = Property(Method)

    def _get_reset_thresh(self):
        """Return the threshold-reset method matching opts.thresh_type.

        Returns prop_thresh for 'prop', abs_thresh for 'abs', and None for
        any other value.
        """
        if self.opts.thresh_type == 'prop':
            return self.prop_thresh
        elif self.opts.thresh_type == 'abs':
            return self.abs_thresh

    thresval = Float

    curr_node = Either(Int, None)
    cur_module = Either(Int, 'custom', None)
    custom_module = List

    ########################################################################
    # SETUP
    ########################################################################

    def __init__(self, name, lab_pos, labnam, srf, labv,
                 gui=None, adj=None, soft_max_edges=20000, **kwargs):
        """Store the parcellation, optionally the matrix, and build the views.

        lab_pos, labnam, srf and labv are stored as given and nr_labels is
        taken from len(labnam). When adj is not None it is stored along with
        soft_max_edges and the edge helper arrays are generated; otherwise
        that step is skipped. kwargs are passed to the parent constructor.
        """
        super(Dataset, self).__init__(**kwargs)

        self.gui = gui

        self.name = name

        self.opts = DisplayOptions(self)
        self.scalar_display_settings = ScalarDisplaySettings(self)

        #this is effectively load_parc
        self.lab_pos = lab_pos
        self.labnam = labnam
        self.srf = srf
        self.labv = labv
        self.nr_labels = len(labnam)

        #load_parc redundantly sets the current display but oh well.
        #self.load_parc(lab_pos,labnam,srf,labv,
        #   init_display.subject_name,init_display.parc_name)

        #if adj is None, it means it will be guaranteed to be supplied later
        #by the user

        #this is load adj, except without initializing nonexistent dataviews
        if adj is not None:
            self.adj = adj
            self.soft_max_edges = soft_max_edges
            self.pos_helper_gen()
            #flip adj ord should already be done to the preprocessed adj
            self.adj_helper_gen()

        self.color_legend = ColorLegend()
        self.node_colors_gen()

        self.dv_3d = DVMayavi(self)
        self.dv_mat = DVMatrix(self)
        self.dv_circ = DVCircle(self)

        self.chg_scalar_colorbar()

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name

    def __getitem__(self, key):
        """Return self for key 0 and the name for key 1; KeyError otherwise."""
        if key == 0:
            return self
        elif key == 1:
            return self.name
        else:
            raise KeyError('Invalid indexing to dataset. Dataset indexing '
                'is implemented to appease CheckListEditor and can only be 0 or 1.')

    ########################################################################
    # GEN METHODS
    ########################################################################

    #preconditions: lab_pos has been set.
    def pos_helper_gen(self, reset_scalars=True):
        """Regenerate starts, vecs and edges for every pair of labels.

        Sets nr_labels from len(lab_pos) and nr_edges to n*(n-1)//2. When
        reset_scalars is true node_scalars is emptied. display_mode is put
        back to 'normal'.
        """
        self.nr_labels = n = len(self.lab_pos)
        self.nr_edges = self.nr_labels * (self.nr_labels - 1) // 2
        #self.starts = np.zeros((self.nr_edges,3),dtype=float)
        #self.vecs = np.zeros((self.nr_edges,3),dtype=float)
        #self.edges = np.zeros((self.nr_edges,2),dtype=int)
        #i=0
        #for r2 in xrange(0,self.nr_labels,1):
        #   for r1 in xrange(0,r2,1):
                #self.starts[i,:] = self.lab_pos[r1]
                #self.vecs[i,:] = self.lab_pos[r2]-self.lab_pos[r1]
                #self.edges[i,0],self.edges[i,1] = r1,r2
                #i+=1

        tri_ixes = np.triu(np.ones((n, n)), 1)
        ixes, = np.where(tri_ixes.flat)

        #NOTE(review): starts/vecs are picked in row-major order of the upper
        #triangle, whereas edges (and adjdat in adj_helper_gen) follow the
        #order of the commented-out loop above (r2 outer, r1 inner). Those
        #orders look different for n >= 4 -- confirm against the dataviews.
        A_r = np.tile(self.lab_pos, (n, 1, 1))
        self.starts = np.reshape(A_r, (n * n, 3))[ixes, :]
        self.vecs = np.reshape(
            A_r - np.transpose(A_r, (1, 0, 2)), (n * n, 3))[ixes, :]

        self.edges = np.transpose(np.where(tri_ixes.T))[:, ::-1]

        #pos_helper_gen is now only called from load adj. The reason it is
        #because it can change on all adj changes because of the soft
        #cap. The number of edges can differ between adjmats because of the
        #soft cap and all of the positions need to be recalculated if it does.

        #pos_helper_gen really only has to do with edge positions. Node and
        #surf positions dont depend on it at all.

        #TODO possibly, keep track of the soft cap and do nothing if it hasn't
        #changed
        #RESPONSE: yes but this check should be done in adj_load
        if reset_scalars:
            self.node_scalars = {}
        self.display_mode = 'normal'

    #precondition: adj_helper_gen() must be run after pos_helper_gen()
    def adj_helper_gen(self):
        """Derive the per-edge arrays from adj and apply the soft edge cap.

        Zeroes the diagonal of adj in place, fills adjdat and the
        left/right/interhemi masks, drops all but (roughly) the
        soft_max_edges largest connections, sorts every per-edge array by
        ascending adjdat, and picks a proportional threshold in opts.
        """
        self.nr_edges = self.nr_labels * (self.nr_labels - 1) // 2
        self.adjdat = np.zeros((self.nr_edges), dtype=float)
        self.interhemi = np.zeros((self.nr_edges), dtype=bool)
        self.left = np.zeros((self.nr_edges), dtype=bool)
        self.right = np.zeros((self.nr_edges), dtype=bool)
        self.masked = np.zeros((self.nr_edges), dtype=bool)

        self.adj[xrange(self.nr_labels), xrange(self.nr_labels)] = 0

        #for r2 in xrange(0,self.nr_labels,1):
            #self.adj[r2][r2]=0
            #for r1 in xrange(0,r2,1):
                #self.adjdat[i] = self.adj[r1][r2]
                #self.interhemi[i] = self.labnam[r1][0] != self.labnam[r2][0]
                #self.left[i] = self.labnam[r1][0]==self.labnam[r2][0]=='l'
                #self.right[i] = self.labnam[r1][0]==self.labnam[r2][0]=='r'
                #i+=1

        n = self.nr_labels
        ixes, = np.where(np.triu(np.ones((n, n)), 1).flat)

        #reversing the flat view, indexing, and reversing again selects the
        #strictly-lower-triangle entries in row-major order
        self.adjdat = self.adj.flat[::-1][ixes][::-1]

        from parsing_utils import same_hemi
        sh = np.vectorize(same_hemi)

        L_r = np.tile(self.labnam, (self.nr_labels, 1))

        self.interhemi = np.logical_not(sh(L_r, L_r.T)).flat[::-1][ixes][::-1]
        self.left = sh(L_r, L_r.T, 'l').flat[::-1][ixes][::-1]
        self.right = sh(L_r, L_r.T, 'r').flat[::-1][ixes][::-1]

        #remove all but the soft_max_edges largest connections
        if self.nr_edges > self.soft_max_edges:
            cutoff = sorted(self.adjdat)[self.nr_edges - self.soft_max_edges - 1]
            zi = np.where(self.adjdat >= cutoff)
            # if way way too many edges remain, make it a hard max
            # this happens in DTI data which is very sparse, the cutoff is 0
            if len(zi[0]) > (self.soft_max_edges + 200):
                zi = np.where(self.adjdat > cutoff)

            self.starts = self.starts[zi[0], :]
            self.vecs = self.vecs[zi[0], :]
            self.edges = self.edges[zi[0], :]
            self.adjdat = self.adjdat[zi[0]]

            self.interhemi = self.interhemi[zi[0]]
            self.left = self.left[zi[0]]
            self.right = self.right[zi[0]]

            self.nr_edges = len(self.adjdat)
            self.verbose_msg(str(self.nr_edges) + " total connections")

        #sort the adjdat
        sort_idx = np.argsort(self.adjdat, axis=0)
        self.adjdat = self.adjdat[sort_idx].squeeze()
        self.edges = self.edges[sort_idx].squeeze()

        self.starts = self.starts[sort_idx].squeeze()
        self.vecs = self.vecs[sort_idx].squeeze()

        self.left = self.left[sort_idx].squeeze()
        self.right = self.right[sort_idx].squeeze()
        self.interhemi = self.interhemi[sort_idx].squeeze()
        self.masked = self.masked[sort_idx].squeeze()   #just to prune

        #try to auto-set the threshold to a reasonable value
        if self.nr_edges < 500:
            self.opts.pthresh = .01
        else:
            #explicit float so the proportion is not floored to 0 when true
            #division is not enabled for this module
            thr = float(self.nr_edges - 500) / self.nr_edges
            self.opts.pthresh = thr
        self.opts.thresh_type = 'prop'

        self.display_mode = 'normal'

    def node_colors_gen(self):
        """Compute group labels/colors, default node colors, the color
        legend entries and the base module colors from labnam."""
        #node groups could change upon loading a new parcellation
        hi_contrast_clist = ('#26ed1a', '#eaf60b', '#e726f4', '#002aff',
            '#05d5d5', '#f4a5e0', '#bbb27e', '#641179', '#068c40')

        hi_contrast_cmap = LinearSegmentedColormap.from_list('hi_contrast',
            hi_contrast_clist)

        #labels are assumed to start with lh_ and rh_
        self.node_labels_numberless = [
            n.replace('div', '').strip('1234567890_') for n in self.labnam]
        node_groups = [n[3:] for n in self.node_labels_numberless]

        #put group names in ordered set
        seen_groups = set()
        group_labels = []
        for grp in node_groups:
            if grp not in seen_groups:
                seen_groups.add(grp)
                group_labels.append(grp)
        self.group_labels = group_labels
        self.nr_groups = len(self.group_labels)

        #get map of {node name -> node group}
        grp_ids = dict(zip(self.group_labels, xrange(self.nr_groups)))

        #group colors does not change unless the parcellation is reloaded
        #explicit float: a colormap treats int and float arguments
        #differently, and i/nr_groups must not be floored to 0
        self.group_colors = (
            [hi_contrast_cmap(float(i) / self.nr_groups)
                for i in range(self.nr_groups)])

        #node colors changes constantly, so copy and stash the result
        self.node_colors = [self.group_colors[grp_ids[g]] for g in node_groups]
        self.node_colors_default = list(self.node_colors)

        #create the color legend associated with this dataset
        self.color_legend.entries = [
            LegendEntry(metaregion=label, col=color)
            for label, color in zip(self.group_labels, self.group_colors)]

        #set up some colors that are acceptably high contrast for modules
        #this is unrelated to node colors in any way, for multi-module mode
        self.module_colors = (
            [[255,255,255,255],[204,0,0,255],[51,204,51,255],[66,0,204,255],
            [80,230,230,255],[51,153,255,255],[255,181,255,255],
            [255,163,71,255],[221,221,149,255],[183,230,46,255],
            [77,219,184,255],[255,255,204,255],[0,0,204,255],[204,69,153,255],
            [255,255,0,255],[0,128,0,255],[163,117,25,255],[255,25,117,255]])

    ######################################################################
    # DRAW METHODS
    ######################################################################

    def draw(self):
        """Redraw surfaces, nodes and connections, in that order."""
        self.draw_surfs()
        self.draw_nodes()
        self.draw_conns()

    def draw_surfs(self):
        for data_view in (self.dv_3d, self.dv_mat, self.dv_circ):
            data_view.draw_surfs()

    def draw_nodes(self):
        self.set_node_colors()
        for data_view in (self.dv_3d, self.dv_mat, self.dv_circ):
            data_view.draw_nodes()

    def set_node_colors(self):
        """Recompute node_colors for the current display_mode."""
        #set node_colors
        if self.display_mode == 'normal':
            self.node_colors = list(self.node_colors_default)
        elif self.display_mode == 'scalar':
            #node colors are not used here, instead the scalar value is set directly
            self.node_colors = list(self.node_colors_default)
        elif self.display_mode == 'module_single':
            new_colors = np.tile(.3, self.nr_labels)
            new_colors[self.get_module()] = .8
            self.node_colors = list(self.opts.default_map._pl(new_colors))
        elif self.display_mode == 'module_multi':
            #when there are more modules than colors, add averages of two
            #randomly chosen colors from the first 18 entries
            while self.nr_modules > len(self.module_colors):
                i, j = np.random.randint(18, size=(2,))
                col = (np.array(self.module_colors[i]) + self.module_colors[j]) / 2
                col = np.array(col, dtype=int)
                self.module_colors.append(col.tolist())
            perm = np.random.permutation(len(self.module_colors))
            #mayavi scalars depend on saving the module colors
            self.module_colors = np.array(self.module_colors)[perm].tolist()
            import bct
            ci = bct.ls2ci(self.modules, zeroindexed=True)
            #255.0 keeps the 0..1 scaling from being floored to 0/1
            self.node_colors = (
                (np.array(self.module_colors)[ci]) / 255.0).tolist()

    def draw_conns(self, conservative=False):
        """Redraw connections in every existing view.

        With conservative=True the views get new_edges=None; otherwise the
        edge selection is recomputed with select_conns() first.
        """
        if conservative:
            new_edges = None
        else:
            new_edges, count_edges = self.select_conns()
        for data_view in (self.dv_3d, self.dv_mat, self.dv_circ):
            if data_view is not None:
                data_view.draw_conns(new_edges)

    def select_conns(self):
        """Choose which edges to show and update circle-plot visibility.

        Returns (new_edges, count_edges). new_edges is an (nr_edges, 2) int
        array holding (a, b) for selected edges and (0, 0) for the rest.
        count_edges is the number of selected edges that also fall inside
        [thresval, max(adjdat)]; it is only counted while the circle plot is
        enabled.
        """
        lo = self.thresval
        hi = np.max(self.adjdat)

        #NOTE(review): `and` binds tighter than `or`, so an edge touching
        #curr_node passes even when it is masked. The parentheses only make
        #that existing grouping explicit -- confirm it is intended.
        basic_conds = lambda e, a, b: (
            (not self.masked[e] and self.curr_node is None)
            or self.curr_node in (a, b))

        if self.display_mode == 'module_single':
            #find the right module
            module = self.get_module()
            #attach the right conditions
            if self.opts.module_view_style == 'intramodular':
                conds = lambda e, a, b: (basic_conds(e, a, b) and
                    (a in module and b in module))
            elif self.opts.module_view_style == 'intermodular':
                conds = lambda e, a, b: (basic_conds(e, a, b) and
                    ((a in module) != (b in module)))   #xor
            #was misspelled `module_view_stlye`, so this branch always failed
            elif self.opts.module_view_style == 'both':
                conds = lambda e, a, b: (basic_conds(e, a, b) and
                    (a in module or b in module))
        else:
            conds = basic_conds

        new_edges = np.zeros((self.nr_edges, 2), dtype=int)
        count_edges = 0

        #loop-invariant, so evaluated once rather than for every edge
        circ_on = self.dv_circ is not None and not self.opts.disable_circle

        for e, (a, b) in enumerate(zip(self.edges[:, 0], self.edges[:, 1])):
            if conds(e, a, b):
                new_edges[e] = (a, b)

                #do the threshold checking here. This code breaks the
                #design spec; the dataset is checking the dataview and
                #messing with its internals. obviously, the reason why
                #is that this code runs often and needs to be optimized
                if circ_on:
                    ev = self.adjdat[e]
                    if (lo <= ev <= hi):
                        self.dv_circ.circ_data[e].set_visible(True)
                        ec = self.opts.activation_map._pl((ev - lo) / (hi - lo))
                        self.dv_circ.circ_data[e].set_ec(ec)
                        count_edges += 1
                    else:
                        self.dv_circ.circ_data[e].set_visible(False)
            else:
                new_edges[e] = (0, 0)
                if circ_on:
                    self.dv_circ.circ_data[e].set_visible(False)

        return new_edges, count_edges

    def center_adjmat(self):
        self.dv_mat.center()

    ######################################################################
    # I/O METHODS (LOADING, SAVING)
    ######################################################################

    def load_parc(self, lab_pos, labnam, srf, labv):
        """Replace the parcellation, drop scalars and adj, rebuild the views."""
        self.lab_pos = lab_pos
        self.labnam = labnam
        self.srf = srf
        self.labv = labv
        self.nr_labels = len(labnam)

        #there is no need to call pos_helper_gen here! pos_helper_gen
        #only has to do with edges. previously it also reset scalars, but we
        #don't do that in load_adj so there is no reason for it
        #self.pos_helper_gen()
        self.node_scalars = {}

        self.color_legend = ColorLegend()
        self.node_colors_gen()

        self.adj = None     #whatever adj was before, it is now the wrong size

        self.reset_dataviews()

    def load_adj(self, adj, soft_max_edges, reqrois, suppress_extra_rois):
        """Store a new adjacency matrix, regenerate edges and redisplay.

        reqrois and suppress_extra_rois are forwarded to the circle view's
        supply_adj().
        """
        self.adj = adj
        self.soft_max_edges = soft_max_edges

        #it is necessary to rerun pos_helper_gen() on every load because the
        #number of edges
        #is not constant from one adjmat to another and which edges are thrown
        #away under the soft cap may differ. pos_helper_gen is really all about
        #edge positions. we wouldnt have to do this if *all* previously
        #subcutoff are still subcutoff (which is unlikely).

        #we could also potentially avoid having to do this if we knew that the
        #parcellation didnt change and only contained nr_edges < soft_max. But
        #its not worth bothering
        self.pos_helper_gen()

        #flip adj ord should already be done to the preprocessed adj
        self.adj_helper_gen()

        self.dv_3d.vectors_clear()
        self.display_mode = 'normal'
        self.dv_3d.supply_adj()
        self.dv_mat.supply_adj()
        self.dv_circ.supply_adj(reqrois=reqrois,
            suppress_extra_rois=suppress_extra_rois)

        self.display_all()

    #This method takes a TractographyChooserParameters
    def load_tractography(self, params):
        """Check the required params and hand them to dv_3d.tracks_gen().

        An error dialog is shown, and nothing else happens, when the track
        file, the B0 volume, or the subject/subjects_dir is missing.
        """
        if not params.track_file:
            self.error_dialog('You must specify a valid tractography file')
            return
        if not params.b0_volume:
            self.error_dialog('You must specify a B0 volume from which the registration'
                ' to the diffusion space can be computed')
            return
        if not params.subjects_dir or not params.subject:
            self.error_dialog('You must specify the freesurfer reconstruction for the '
                'individual subject for registration to the surface space.')
            return
        self.dv_3d.tracks_gen(params)

    #This method takes a GeneralMatrixChooserParameters
    def load_modules_or_scalars(self, params):
        """Load a per-node vector and store it as modules or as a scalar.

        Failures while reading or reordering the file are reported through
        error_dialog and end the load; nothing is raised to the caller for
        the exception types handled below.
        """
        if not params.mat:
            self.error_dialog('You must specify a valid matrix file')
            return
        if params.whichkind == 'scalars' and not params.measure_name:
            self.error_dialog('Cannot leave scalar name blank. cvu uses '
                'this value as a dictionary index')
            return

        import preprocessing
        try:
            ci = preprocessing.loadmat(params.mat, field=params.field_name)
        except (CVUError, IOError) as e:
            self.error_dialog(str(e))
            return

        if params.mat_order:
            try:
                init_ord, bads = preprocessing.read_ordering_file(
                    params.mat_order)
            except (IndexError, UnicodeDecodeError) as e:
                self.error_dialog(str(e))
                return

            #delete the bads
            if not params.ignore_deletes:
                ci = np.delete(ci, bads)

            #perform the swapping
            try:
                ci_ord = preprocessing.adj_sort(init_ord, self.labnam)
            except CVUError as e:
                self.error_dialog(str(e))
                return
            except KeyError as e:
                self.error_dialog('Field not found: %s'%str(e))
                return
            ci = ci[ci_ord]

        try:
            ci = np.reshape(ci, (self.nr_labels,))
        except ValueError as e:
            self.error_dialog('The %s file is of size %i after deletions, but '
                'the dataset has %i regions' %
                (params.whichkind, len(ci), self.nr_labels))
            return

        if params.whichkind == 'modules':
            import bct
            self.modules = bct.ci2ls(ci)
            self.nr_modules = len(self.modules)
        elif params.whichkind == 'scalars':
            self.save_scalar(params.measure_name, ci)
            params._dataset_plusplus()

    #this method destroys the current dataviews and resets them entirely
    def reset_dataviews(self):
        #in principle it might be useful to do some more cleanup here
        self.display_mode = 'normal'

        self.dv_3d = DVMayavi(self)
        self.dv_mat = DVMatrix(self)
        self.dv_circ = DVCircle(self)

        self.chg_scalar_colorbar()
        #scalar colorbar loading is tied to the surface and not to nodes
        #because the surface always has the same color scheme and the nodes
        #don't. but it can't be in surfs_gen because the surf can get gen'd
        #when switching from cracked to glass. so it is here.

    #handles the scaling and size checking for new scalar datasets
    def save_scalar(self, name, scalars, passive=False):
        """Rescale `scalars` to [0, 1] and store it in node_scalars[name].

        The squeezed input must have shape (nr_labels,). If it does not,
        nothing is stored and the problem is reported via verbose_msg when
        passive is true, otherwise via error_dialog.
        """
        if np.squeeze(scalars).shape != (self.nr_labels,):
            if passive:
                self.verbose_msg("%s: Only Nx1 vectors can be saved as scalars"%name)
                return
            else:
                self.error_dialog("%s: Only Nx1 vectors can be saved as scalars"%name)
                #print np.squeeze(scalars).shape, self.nr_labels
                return
        ci = scalars.ravel().copy()
        #NOTE(review): a constant vector makes max-min zero here -- confirm
        #how such input should be handled
        ci = (ci - np.min(ci)) / (np.max(ci) - np.min(ci))
        self.node_scalars.update({name: ci})

    #this function takes a SnapshotParameters object and returns a
    #continuation -- a closure -- which saves the snapshot. The CVU object
    #spawns the "Really overwrite file" window if the file exists, and then
    #calls the continuation, or else just calls the continuation directly.
    def snapshot(self, params):
        def save_continuation():
            try:
                if params.whichplot == '3D brain':
                    self.dv_3d.snapshot(params)
                elif params.whichplot == 'connection matrix':
                    self.dv_mat.snapshot(params)
                elif params.whichplot == 'circle plot':
                    self.dv_circ.snapshot(params)
            except IOError as e:
                self.error_dialog(str(e))
            except KeyError as e:
                self.error_dialog('The library making the snapshot supports'
                    ' multiple file types and doesnt know which one you want.'
                    ' Please specify a file extension to disambiguate.')
        return save_continuation

    #this function takes a MakeMovieParameters object and returns a
    #continuation which records and takes the movie. The CVU object is again
    #responsible for thinking about the "really overwrite file" case.
    def make_movie(self, params):
        def save_continuation():
            self.dv_3d.make_movie(params)
        return save_continuation

    def make_movie_finish(self, params):
        self.dv_3d.make_movie_finish(params)

    ######################################################################
    # VISUALIZATION INTERACTIONS
    ######################################################################

    def display_all(self):
        """Return to 'normal' mode with no node or module selected; redraw."""
        self.display_mode = 'normal'
        self.curr_node = None
        self.cur_module = None
        self.center_adjmat()
        self.draw()

    def display_node(self, n):
        """Show only connections of node n; out-of-range n is ignored."""
        if n < 0 or n >= self.nr_labels:
            return
        self.curr_node = n
        self.draw_conns()

    def display_scalars(self):
        self.display_mode = 'scalar'
        self.draw_surfs()
        self.draw_nodes()

    def display_module(self, module):
        self.display_mode = 'module_single'
        self.curr_node = None
        self.cur_module = module
        self.draw()     #draw surf is needed to unset surf color

    def display_multi_module(self):
        if not self.modules:
            self.error_dialog('No modules defined')
            return
        self.display_mode = 'module_multi'
        self.draw_nodes()

    def calculate_modules(self, thres):
        """Compute modules with bct.modularity_und on adj thresholded at
        `thres` (entries below it are set to 0 in a copy)."""
        import bct
        thres_adj = self.adj.copy()
        thres_adj[thres_adj < thres] = 0
        self.verbose_msg('Threshold for modularity calculation: %s'%str(thres))
        modvec, _ = bct.modularity_und(thres_adj)
        self.modules = bct.ci2ls(modvec)
        self.nr_modules = len(self.modules)

    def calculate_graph_stats(self, thres):
        """Compute graph statistics on thresholded adj and save each result
        as a node scalar (passively, so wrongly shaped ones are skipped)."""
        import graph
        import bct
        thres_adj = self.adj.copy()
        thres_adj[thres_adj < thres] = 0
        self.verbose_msg('Threshold for graph calculations: %s'%str(thres))
        try:
            self.graph_stats = graph.do_summary(thres_adj,
                bct.ls2ci(self.modules),
                self.opts.intermediate_graphopts_list)

            for name, arr in self.graph_stats.iteritems():
                self.save_scalar(name, arr, passive=True)
        except CVUError:
            self.error_dialog("Community structure required for some of "
                "the calculations specified. Try calculating modules first.")

    #save_graphstat_to_scalar is principally interaction between window and
    #dataset manager. it should call a generic save_scalar method, same as
    #loading a scalar from a file

    ######################################################################
    # OPTIONS
    ######################################################################

    def prop_thresh(self):
        """Set thresval to the adjdat entry at index
        round(pthresh*nr_edges - 1)."""
        try:
            self.thresval = float(self.adjdat[
                int(round(self.opts.pthresh * self.nr_edges - 1))])
        except TraitError as e:
            if self.opts.pthresh > 1:
                self.warning_dialog("%s\nThreshold set to maximum"%str(e))
            elif self.opts.pthresh < 0:
                self.warning_dialog("%s\nThreshold set to minimum"%str(e))
            else:
                self.error_dialog("Programming error")

    def abs_thresh(self):
        """Set thresval to opts.athresh, clamped to the first and last
        entries of adjdat (with a warning when clamped)."""
        self.thresval = self.opts.athresh
        if self.adjdat[self.nr_edges - 1] < self.opts.athresh:
            self.thresval = self.adjdat[self.nr_edges - 1]
            self.warning_dialog("Threshold over maximum! Set to maximum.")
        elif self.adjdat[0] > self.opts.athresh:
            self.thresval = self.adjdat[0]
            self.warning_dialog("Threshold under minimum! Set to minimum.")

    #recall reset thresh is a cached property
    @on_trait_change('opts:pthresh')
    def chg_pthresh_val(self):
        if self.opts.thresh_type != 'prop':
            return
        self.reset_thresh()
        self.draw_conns(conservative=True)

    @on_trait_change('opts:athresh')
    def chg_athresh_val(self):
        if self.opts.thresh_type != 'abs':
            return
        self.reset_thresh()
        self.draw_conns(conservative=True)

    @on_trait_change('opts:thresh_type')
    def chg_thresh_type(self):
        self.draw_conns(conservative=True)

    @on_trait_change('opts:interhemi_conns_on')
    def chg_interhemi_connmask(self):
        self.masked[self.interhemi] = not self.opts.interhemi_conns_on

    @on_trait_change('opts:lh_conns_on')
    def chg_lh_connmask(self):
        self.masked[self.left] = not self.opts.lh_conns_on

    @on_trait_change('opts:rh_conns_on')
    def chg_rh_connmask(self):
        self.masked[self.right] = not self.opts.rh_conns_on

    #the following options operate on specific views only
    #they may fail if the view is not present (i dont know if this is true)
    @on_trait_change('opts:circ_size')
    def chg_circ_size(self):
        try:
            self.dv_circ.circ.axes[0].set_ylim(0, self.opts.circ_size)
            self.dv_circ.circ.canvas.draw()
        except AttributeError:
            pass

    @on_trait_change('opts:show_floating_text')
    def chg_float_text(self):
        try:
            self.dv_3d.txt.visible = self.opts.show_floating_text
        except AttributeError:
            pass

    @on_trait_change('opts:scalar_colorbar')
    def chg_scalar_colorbar(self):
        try:
            self.dv_3d.set_colorbar(self.opts.scalar_colorbar,
                self.dv_3d.syrf_lh, orientation='vertical')
        except AttributeError:
            pass

    @on_trait_change('opts:render_style')
    def chg_render_style(self):
        try:
            self.dv_3d.set_surf_render_style(self.opts.render_style)
        except AttributeError:
            pass

    @on_trait_change('opts:surface_visibility')
    def chg_surf_opacity(self):
        try:
            for syrf in (self.dv_3d.syrf_lh, self.dv_3d.syrf_rh):
                syrf.actor.property.opacity = self.opts.surface_visibility
        except AttributeError:
            pass

    @on_trait_change('opts:lh_nodes_on')
    def chg_lh_nodemask(self):
        try:
            self.dv_3d.nodes_lh.visible = self.opts.lh_nodes_on
        except AttributeError:
            pass

    @on_trait_change('opts:rh_nodes_on')
    def chg_rh_nodemask(self):
        try:
            self.dv_3d.nodes_rh.visible = self.opts.rh_nodes_on
        except AttributeError:
            pass

    @on_trait_change('opts:lh_surfs_on')
    def chg_lh_surfmask(self):
        try:
            self.dv_3d.syrf_lh.visible = self.opts.lh_surfs_on
        except AttributeError:
            pass

    @on_trait_change('opts:rh_surfs_on')
    def chg_rh_surfmask(self):
        try:
            self.dv_3d.syrf_rh.visible = self.opts.rh_surfs_on
        except AttributeError:
            pass

    @on_trait_change('opts:conns_colors_on')
    def chg_conns_colors(self):
        try:
            if self.opts.conns_colors_on:
                self.dv_3d.vectors.glyph.color_mode = 'color_by_scalar'
            else:
                self.dv_3d.vectors.glyph.color_mode = 'no_coloring'
        except AttributeError:
            pass

    @on_trait_change('opts:conns_colorbar')
    def chg_conns_colorbar(self):
        try:
            self.dv_3d.set_colorbar(self.opts.conns_colorbar,
                self.dv_3d.vectors, orientation='horizontal')
        except AttributeError:
            pass

    @on_trait_change('opts:conns_width')
    def chg_conns_width(self):
        try:
            self.dv_3d.vectors.actor.property.line_width = self.opts.conns_width
        except AttributeError:
            pass

    #In the four colormap handlers below a redraw failure is ignored only
    #while the map is set to 'file' with no file name chosen; any other
    #failure is re-raised. `except Exception` (instead of a bare except)
    #lets KeyboardInterrupt/SystemExit pass through untouched.
    @on_trait_change('opts:default_map.[cmap,reverse,fname,threshold]')
    def chg_default_map(self):
        try:
            self.draw_nodes()
        except Exception:
            map_def = self.opts.default_map
            if not (map_def.cmap == 'file' and not map_def.fname):
                raise

    @on_trait_change('opts:scalar_map.[cmap,reverse,fname,threshold]')
    def chg_scalar_map(self):
        try:
            self.draw_surfs()
            self.draw_nodes()
        except Exception:
            map_sca = self.opts.scalar_map
            if not (map_sca.cmap == 'file' and not map_sca.fname):
                raise

    @on_trait_change('opts:activation_map.[cmap,reverse,fname,threshold]')
    def chg_activation_map(self):
        #we don't touch the circle plot here since circle redraw is expensive
        try:
            self.dv_3d.draw_conns()
        except Exception:
            map_act = self.opts.activation_map
            if not (map_act.cmap == 'file' and not map_act.fname):
                raise

    @on_trait_change('opts:connmat_map.[cmap,reverse,fname,threshold]')
    def chg_connmat_map(self):
        try:
            self.dv_mat.change_colormap()
        except Exception:
            map_mat = self.opts.connmat_map
            if not (map_mat.cmap == 'file' and not map_mat.fname):
                raise

    ######################################################################
    # MISCELLANEOUS HELPERS
    ######################################################################

    def error_dialog(self, str):
        return self.gui.error_dialog(str)

    def warning_dialog(self, str):
        return self.gui.warning_dialog(str)

    def verbose_msg(self, str):
        return self.gui.verbose_msg(str)

    def get_module(self):
        """Return the node list of the current module.

        Gives custom_module when cur_module is 'custom', modules[cur_module]
        when it is an int, and None otherwise.
        """
        if self.cur_module == 'custom':
            return self.custom_module
        elif isinstance(self.cur_module, int) and self.modules is not None:
            return self.modules[self.cur_module]
class Dataset(HasTraits): ######################################################################## # FUNDAMENTALLY NECESSARY DATA ######################################################################## name = Str #give this dataset a name gui = Any #symbolic reference to a modular cvu nr_labels = Int nr_edges = Int labnam = List(Str) #adjlabfile=File #the adjlabfile is not needed. this is only kept on hand to pass it #around if specified as CLI arg. it is only used upon loading an adjmat. #so just have adjmat loading be a part of dataset creation and get rid of #keeping track of this adj = Any #NxN np.ndarray adj_thresdiag = Property(depends_on='adj') #NxN np.ndarray @cached_property def _get_adj_thresdiag(self): adjt = self.adj.copy() adjt[np.where(np.eye(self.nr_labels))] = np.min(adjt[np.where(adjt)]) return adjt starts = Any #Ex3 np.ndarray vecs = Any #Ex3 np.ndarray edges = Any #Ex2 np.ndarray(int) srf = Instance(SurfData) #labv=List(Instance(mne.Label)) #all that is needed from this is a map of name->vertex #this is a considerable portion of the data contained in a label but still #only perhaps 15%. 
To make lightweight, extract this from labv #TODO convert it to that labv = Dict #is an OrderedDict in parcellation order lab_pos = Any #Nx3 np.ndarray ######################################################################### # CRITICAL NONADJUSTABLE DATA WITHOUT WHICH DISPLAY CANNOT EXIST ######################################################################### dv_3d = Either(Instance(DataView), None) dv_mat = Either(Instance(DataView), None) dv_circ = Either(Instance(DataView), None) soft_max_edges = Int adjdat = Any #Ex1 np.ndarray left = Any #Nx1 np.ndarray(bool) right = Any #Nx1 np.ndarray(bool) interhemi = Any #Nx1 np.ndarray(bool) masked = Any #Nx1 np.ndarray(bool) lhnodes = Property(depends_on='labnam') #Nx1 np.ndarray(int) rhnodes = Property(depends_on='labnam') #Nx1 np.ndarray(int) @cached_property def _get_lhnodes(self): return np.where(map(lambda r: r[0] == 'l', self.labnam))[0] @cached_property def _get_rhnodes(self): return np.where(map(lambda r: r[0] == 'r', self.labnam))[0] node_colors = Any #Nx3 np.ndarray #node_colors represents the colors held by the nodes. the current value of #node_colors depends on the current policy (i.e. the current display mode). #however, don't take this all too literally. depending on the current #policy, the dataviews may choose to ignore what is in node_colors and #use some different color. #this is always true of Mayavi views, who can't use the node colors at all. #because mayavi doesn't play nice with true colors (this could be fixed #if mayavi is fixed). it is also true of the other plots in scalar mode, #but when scalars are not specified for those dataviews. #node_colors_default will be set uniquely for each parcellation and can #thus be different for different datasets. #group_colors is more subtle; it can in principle be set uniquely for each #parcellation as long as the parcellations don't conform to aparc. for #instance, destrieux parc has different group colors. 
#right now i'm a long way away from dealing with this but i think in a month
#it will be prudent to just have the dataset capture both of these variables
node_colors_default = List
node_labels_numberless = List(Str)
group_colors = List
nr_groups = Int
group_labels = List(Str)
color_legend = Instance(ColorLegend)
module_colors = List
default_glass_brain_color = Constant((.82, .82, .82))

#########################################################################
# ASSOCIATED STATISTICAL AND ANALYTICAL DATA
#########################################################################

node_scalars = Dict
scalar_display_settings = Instance(ScalarDisplaySettings)
#TODO make modules a dictionary
modules = List
nr_modules = Int
graph_stats = Dict

#########################################################################
# ASSOCIATED DISPLAY OPTIONS AND DISPLAY STATE (ADJUSTABLE/TRANSIENT)
#########################################################################

opts = Instance(DisplayOptions)
display_mode = Enum('normal', 'scalar', 'module_single', 'module_multi')
#reset_thresh evaluates to the bound threshold-setting method that matches
#the current opts.thresh_type; callers invoke it as self.reset_thresh()
reset_thresh = Property(Method)


def _get_reset_thresh(self):
    """Return prop_thresh or abs_thresh according to opts.thresh_type.

    Returns None when thresh_type is neither 'prop' nor 'abs'.
    """
    if self.opts.thresh_type == 'prop':
        return self.prop_thresh
    elif self.opts.thresh_type == 'abs':
        return self.abs_thresh
    return None


thresval = Float

curr_node = Either(Int, None)
cur_module = Either(Int, 'custom', None)
custom_module = List

########################################################################
# SETUP
########################################################################


def __init__(self, name, lab_pos, labnam, srf, labv, gui=None, adj=None,
             soft_max_edges=20000, **kwargs):
    """Build a dataset from a parcellation and, optionally, a matrix.

    name -- display name of the dataset
    lab_pos, labnam, srf, labv -- parcellation data, stored as-is
    gui -- object providing error_dialog/warning_dialog/verbose_msg
    adj -- adjacency matrix; if None it must be supplied later by the user
    soft_max_edges -- soft cap on the number of edges kept from adj
    kwargs -- forwarded to the superclass constructor
    """
    super(Dataset, self).__init__(**kwargs)

    self.gui = gui
    self.name = name
    self.opts = DisplayOptions(self)
    self.scalar_display_settings = ScalarDisplaySettings(self)

    #this is effectively load_parc
    self.lab_pos = lab_pos
    self.labnam = labnam
    self.srf = srf
    self.labv = labv
    self.nr_labels = len(labnam)

    #load_parc redundantly sets the current display but oh well.
    #self.load_parc(lab_pos,labnam,srf,labv,
    #    init_display.subject_name,init_display.parc_name)

    #if adj is None, it means it will be guaranteed to be supplied later
    #by the user

    #this is load adj, except without initializing nonexistent dataviews
    if adj is not None:
        self.adj = adj
        self.soft_max_edges = soft_max_edges
        self.pos_helper_gen()
        #flip adj ord should already be done to the preprocessed adj
        self.adj_helper_gen()

    self.color_legend = ColorLegend()
    self.node_colors_gen()

    self.dv_3d = DVMayavi(self)
    self.dv_mat = DVMatrix(self)
    self.dv_circ = DVCircle(self)

    self.chg_scalar_colorbar()


def __repr__(self):
    return 'Dataset: %s' % self.name


def __getitem__(self, key):
    """Index 0 yields the dataset itself and index 1 its name.

    Raises KeyError for any other key.
    """
    if key == 0:
        return self
    elif key == 1:
        return self.name
    else:
        raise KeyError(
            'Invalid indexing to dataset. Dataset indexing '
            'is implemented to appease CheckListEditor and can only be 0 or 1.'
        )


########################################################################
# GEN METHODS
########################################################################


#preconditions: lab_pos has been set.
def pos_helper_gen(self, reset_scalars=True):
    """Recompute the per-edge geometry (starts, vecs, edges) from lab_pos.

    Also resets nr_labels and nr_edges to the full n*(n-1)/2 edge count,
    clears node_scalars when reset_scalars is true, and puts the display
    back in 'normal' mode.
    """
    self.nr_labels = n = len(self.lab_pos)
    self.nr_edges = self.nr_labels * (self.nr_labels - 1) // 2

    #reference (loop) formulation that the vectorized code below replaced:
    #self.starts = np.zeros((self.nr_edges,3),dtype=float)
    #self.vecs = np.zeros((self.nr_edges,3),dtype=float)
    #self.edges = np.zeros((self.nr_edges,2),dtype=int)
    #i=0
    #for r2 in xrange(0,self.nr_labels,1):
    #    for r1 in xrange(0,r2,1):
    #        self.starts[i,:] = self.lab_pos[r1]
    #        self.vecs[i,:] = self.lab_pos[r2]-self.lab_pos[r1]
    #        self.edges[i,0],self.edges[i,1] = r1,r2
    #        i+=1

    tri_ixes = np.triu(np.ones((n, n)), 1)
    ixes, = np.where(tri_ixes.flat)

    A_r = np.tile(self.lab_pos, (n, 1, 1))
    # NOTE(review): starts/vecs are picked in row-major order of the upper
    # triangle, while edges (and adjdat in adj_helper_gen) are ordered via
    # the transposed/reversed triangle. These orderings look different for
    # n > 3 -- confirm against the reference loop above before relying on it.
    self.starts = np.reshape(A_r, (n * n, 3))[ixes, :]
    self.vecs = np.reshape(A_r - np.transpose(A_r, (1, 0, 2)),
                           (n * n, 3))[ixes, :]

    self.edges = np.transpose(np.where(tri_ixes.T))[:, ::-1]

    #pos_helper_gen is now only called from load adj. The reason it is
    #because it can change on all adj changes because of the soft
    #cap. The number of edges can differ between adjmats because of the
    #soft cap and all of the positions need to be recalculated if it does.

    #pos_helper_gen really only has to do with edge positions. Node and
    #surf positions dont depend on it at all.

    #TODO possibly, keep track of the soft cap and do nothing if it hasn't
    #changed
    #RESPONSE: yes but this check should be done in adj_load
    if reset_scalars:
        self.node_scalars = {}
    self.display_mode = 'normal'


#precondition: adj_helper_gen() must be run after pos_helper_gen()
def adj_helper_gen(self):
    """Derive the per-edge arrays from self.adj and apply the soft edge cap.

    Zeroes the diagonal of self.adj in place, fills adjdat / interhemi /
    left / right, drops all but (roughly) the soft_max_edges strongest
    edges, sorts every per-edge array by ascending weight, and picks an
    initial proportional threshold.
    """
    self.nr_edges = self.nr_labels * (self.nr_labels - 1) // 2
    self.masked = np.zeros((self.nr_edges), dtype=bool)

    #self-connections are never displayed
    np.fill_diagonal(self.adj, 0)

    #reference (loop) formulation that the vectorized code below replaced:
    #for r2 in xrange(0,self.nr_labels,1):
    #    self.adj[r2][r2]=0
    #    for r1 in xrange(0,r2,1):
    #        self.adjdat[i] = self.adj[r1][r2]
    #        self.interhemi[i] = self.labnam[r1][0] != self.labnam[r2][0]
    #        self.left[i] = self.labnam[r1][0]==self.labnam[r2][0]=='l'
    #        self.right[i] = self.labnam[r1][0]==self.labnam[r2][0]=='r'
    #        i+=1

    n = self.nr_labels
    ixes, = np.where(np.triu(np.ones((n, n)), 1).flat)

    self.adjdat = self.adj.flat[::-1][ixes][::-1]

    from parsing_utils import same_hemi
    sh = np.vectorize(same_hemi)

    L_r = np.tile(self.labnam, (self.nr_labels, 1))

    self.interhemi = np.logical_not(sh(L_r, L_r.T)).flat[::-1][ixes][::-1]
    self.left = sh(L_r, L_r.T, 'l').flat[::-1][ixes][::-1]
    self.right = sh(L_r, L_r.T, 'r').flat[::-1][ixes][::-1]

    #remove all but the soft_max_edges largest connections
    if self.nr_edges > self.soft_max_edges:
        cutoff = np.sort(self.adjdat)[self.nr_edges - self.soft_max_edges - 1]
        zi = np.where(self.adjdat >= cutoff)
        # if way way too many edges remain, make it a hard max
        # this happens in DTI data which is very sparse, the cutoff is 0
        if len(zi[0]) > (self.soft_max_edges + 200):
            zi = np.where(self.adjdat > cutoff)

        keep = zi[0]
        self.starts = self.starts[keep, :]
        self.vecs = self.vecs[keep, :]
        self.edges = self.edges[keep, :]
        self.adjdat = self.adjdat[keep]

        self.interhemi = self.interhemi[keep]
        self.left = self.left[keep]
        self.right = self.right[keep]

        self.nr_edges = len(self.adjdat)
        self.verbose_msg(str(self.nr_edges) + " total connections")

    #sort the adjdat
    sort_idx = np.argsort(self.adjdat, axis=0)
    self.adjdat = self.adjdat[sort_idx].squeeze()
    self.edges = self.edges[sort_idx].squeeze()

    self.starts = self.starts[sort_idx].squeeze()
    self.vecs = self.vecs[sort_idx].squeeze()

    self.left = self.left[sort_idx].squeeze()
    self.right = self.right[sort_idx].squeeze()
    self.interhemi = self.interhemi[sort_idx].squeeze()
    self.masked = self.masked[sort_idx].squeeze()  #just to prune

    #try to auto-set the threshold to a reasonable value
    if self.nr_edges < 500:
        self.opts.pthresh = .01
    else:
        #explicit float so the result does not depend on
        #`from __future__ import division` being in effect
        thr = float(self.nr_edges - 500) / self.nr_edges
        self.opts.pthresh = thr
    self.opts.thresh_type = 'prop'

    self.display_mode = 'normal'


def node_colors_gen(self):
    """Derive node groups from the label names and assign group colors.

    Sets node_labels_numberless, group_labels, nr_groups, group_colors,
    node_colors, node_colors_default, the color legend entries, and the
    base palette used for multi-module display.
    """
    #node groups could change upon loading a new parcellation
    hi_contrast_clist = ('#26ed1a', '#eaf60b', '#e726f4', '#002aff',
                         '#05d5d5', '#f4a5e0', '#bbb27e', '#641179',
                         '#068c40')

    hi_contrast_cmap = LinearSegmentedColormap.from_list(
        'hi_contrast', hi_contrast_clist)

    #labels are assumed to start with lh_ and rh_
    self.node_labels_numberless = [
        lab.replace('div', '').strip('1234567890_') for lab in self.labnam]
    node_groups = [lab[3:] for lab in self.node_labels_numberless]

    #put group names in ordered set
    #n_set=set()
    #self.group_labels=(
    #    [i for i in node_groups if i not in n_set and not n_set.add(i)])

    node_groups_hemi1 = [
        lab[3:] for lab in self.node_labels_numberless[:len(self.lhnodes)]]
    node_groups_hemi2 = [
        lab[3:] for lab in self.node_labels_numberless[-len(self.rhnodes):]]

    #ordered, duplicate-free list of the first hemisphere's groups
    seen = set()
    group_labels = []
    for grp in node_groups_hemi1:
        if grp not in seen:
            seen.add(grp)
            group_labels.append(grp)

    #splice in groups that only occur in the second hemisphere, right
    #after the most recently seen group that both hemispheres share
    last_grp = None
    for grp in node_groups_hemi2:
        if grp not in group_labels:
            if last_grp is None:
                #list.insert takes (index, item); the arguments used to
                #be swapped, which raised TypeError
                group_labels.insert(0, grp)
            else:
                group_labels.insert(group_labels.index(last_grp) + 1, grp)
        else:
            last_grp = grp

    self.group_labels = group_labels
    self.nr_groups = len(self.group_labels)

    #get map of {node name -> node group}
    grp_ids = {grp: ix for ix, grp in enumerate(self.group_labels)}

    #group colors does not change unless the parcellation is reloaded
    #(explicit float so the colormap is sampled over [0, 1) even without
    #`from __future__ import division`)
    self.group_colors = [
        hi_contrast_cmap(float(i) / self.nr_groups)
        for i in range(self.nr_groups)]

    #node colors changes constantly, so copy and stash the result
    self.node_colors = [self.group_colors[grp_ids[grp]]
                        for grp in node_groups]
    self.node_colors_default = list(self.node_colors)

    #create the color legend associated with this dataset
    self.color_legend.entries = [
        LegendEntry(metaregion=label, col=color)
        for label, color in zip(self.group_labels, self.group_colors)]

    #set up some colors that are acceptably high contrast for modules
    #this is unrelated to node colors in any way, for multi-module mode
    self.module_colors = [
        [255, 255, 255, 255], [204, 0, 0, 255], [51, 204, 51, 255],
        [66, 0, 204, 255], [80, 230, 230, 255], [51, 153, 255, 255],
        [255, 181, 255, 255], [255, 163, 71, 255], [221, 221, 149, 255],
        [183, 230, 46, 255], [77, 219, 184, 255], [255, 255, 204, 255],
        [0, 0, 204, 255], [204, 69, 153, 255], [255, 255, 0, 255],
        [0, 128, 0, 255], [163, 117, 25, 255], [255, 25, 117, 255],
    ]


######################################################################
# DRAW METHODS
######################################################################


def draw(self, skip_circ=False):
    """Redraw surfaces, nodes and connections in every dataview."""
    self.draw_surfs()
    self.draw_nodes(skip_circ=skip_circ)
    self.draw_conns(skip_circ=skip_circ)


def draw_surfs(self):
    """Ask each dataview to redraw its surfaces."""
    for data_view in (self.dv_3d, self.dv_mat, self.dv_circ):
        data_view.draw_surfs()


def draw_nodes(self, skip_circ=False):
    """Refresh node_colors and redraw nodes; optionally skip the circle."""
    self.set_node_colors()
    for data_view in (self.dv_3d, self.dv_mat, self.dv_circ):
        if skip_circ and data_view is self.dv_circ:
            continue
        data_view.draw_nodes()


def set_node_colors(self):
    """Recompute node_colors for the current display_mode."""
    #set node_colors
    if self.display_mode == 'normal':
        self.node_colors = list(self.node_colors_default)
    elif self.display_mode == 'scalar':
        #node colors are not used here, instead the scalar value is set
        #directly
        self.node_colors = list(self.node_colors_default)
    elif self.display_mode == 'module_single':
        #.8 for members of the selected module, .3 for everything else
        new_colors = np.tile(.3, self.nr_labels)
        new_colors[self.get_module()] = .8
        self.node_colors = list(self.opts.default_map._pl(new_colors))
    elif self.display_mode == 'module_multi':
        #when there are more modules than palette entries, extend the
        #palette by averaging two random colors out of the first 18
        while self.nr_modules > len(self.module_colors):
            i, j = np.random.randint(18, size=(2, ))
            col = (np.array(self.module_colors[i]) +
                   self.module_colors[j]) // 2
            col = np.array(col, dtype=int)
            self.module_colors.append(col.tolist())
        #perm=np.random.permutation(len(self.module_colors))
        #mayavi scalars depend on saving the module colors
        #self.module_colors=np.array(self.module_colors)[perm].tolist()
        import bct
        ci = bct.ls2ci(self.modules, zeroindexed=True)
        #scale 0-255 palette entries down to 0-1 floats
        self.node_colors = (np.array(self.module_colors)[ci] /
                            255.0).tolist()


def draw_conns(self, conservative=False, skip_circ=False):
    """Redraw connections in the dataviews.

    conservative -- skip edge re-selection and pass None to the views
    skip_circ -- leave the circle view untouched
    """
    if conservative:
        new_edges = None
    else:
        new_edges, _ = self.select_conns(skip_circ=skip_circ)
    for data_view in (self.dv_3d, self.dv_mat, self.dv_circ):
        if skip_circ and data_view is self.dv_circ:
            continue
        elif data_view is not None:
            data_view.draw_conns(new_edges)


def select_conns(self, skip_circ=False):
    """Choose which edges are shown and update the circle plot in place.

    Returns (new_edges, count_edges): new_edges is an nr_edges x 2 int array
    holding (a, b) for selected edges and (0, 0) otherwise; count_edges is
    the number of circle-plot edges made visible.

    Raises ValueError if opts.module_view_style is not recognized while in
    'module_single' mode.
    """
    disable_circle = (skip_circ or self.opts.circle_render == 'disabled')
    lo = self.thresval
    hi = np.max(self.adjdat)

    # NOTE(review): parenthesized to show the precedence Python already
    # applied to the original expression. The intent may have been
    # `not masked and (curr_node is None or curr_node in (a, b))` -- confirm.
    basic_conds = lambda e, a, b: ((not self.masked[e] and
                                    self.curr_node is None) or
                                   self.curr_node in (a, b))

    if self.display_mode == 'module_single':
        #find the right module
        module = self.get_module()
        #attach the right conditions
        style = self.opts.module_view_style
        if style == 'intramodular':
            conds = lambda e, a, b: (basic_conds(e, a, b) and
                                     (a in module and b in module))
        elif style == 'intermodular':
            conds = lambda e, a, b: (basic_conds(e, a, b) and
                                     ((a in module) != (b in module)))  #xor
        elif style == 'both':
            conds = lambda e, a, b: (basic_conds(e, a, b) and
                                     (a in module or b in module))
        else:
            raise ValueError('Unknown module view style: %r' % (style,))
    else:
        conds = basic_conds

    new_edges = np.zeros((self.nr_edges, 2), dtype=int)
    count_edges = 0
    update_circle = self.dv_circ is not None and not disable_circle
    for e, (a, b) in enumerate(zip(self.edges[:, 0], self.edges[:, 1])):
        if conds(e, a, b):
            new_edges[e] = (a, b)

            #do the threshold checking here. This code breaks the
            #design spec; the dataset is checking the dataview and
            #messing with its internals. obviously, the reason why
            #is that this code runs often and needs to be optimized
            if update_circle:
                ev = self.adjdat[e]
                if (lo <= ev <= hi):
                    self.dv_circ.circ_data[e].set_visible(True)
                    ec = self.opts.activation_map._pl(
                        (ev - lo) / (hi - lo))
                    self.dv_circ.circ_data[e].set_ec(ec)
                    count_edges += 1
                else:
                    self.dv_circ.circ_data[e].set_visible(False)
        else:
            new_edges[e] = (0, 0)
            if update_circle:
                self.dv_circ.circ_data[e].set_visible(False)

    return new_edges, count_edges


def center_adjmat(self):
    """Delegate to the matrix view's center()."""
    self.dv_mat.center()


######################################################################
# I/O METHODS (LOADING, SAVING)
######################################################################


def _load_parc(self, lab_pos, labnam, srf, labv):
    """Replace the parcellation, discard adj/scalars and rebuild the views."""
    self.lab_pos = lab_pos
    self.labnam = labnam
    self.srf = srf
    self.labv = labv
    self.nr_labels = len(labnam)

    #there is no need to call pos_helper_gen here! pos_helper_gen
    #only has to do with edges. previously it also reset scalars, but we
    #don't do that in load_adj so there is no reason for it
    #self.pos_helper_gen()
    self.node_scalars = {}

    self.color_legend = ColorLegend()
    self.node_colors_gen()

    self.adj = None  #whatever adj was before, it is now the wrong size

    self.reset_dataviews()


def _load_adj(self, adj, soft_max_edges, reqrois, suppress_extra_rois):
    """Install a new adjacency matrix and redisplay everything.

    reqrois and suppress_extra_rois are forwarded to the circle view's
    supply_adj.
    """
    self.adj = adj
    self.soft_max_edges = soft_max_edges

    #it is necessary to rerun pos_helper_gen() on every load because the
    #number of edges
    #is not constant from one adjmat to another and which edges are thrown
    #away under the soft cap may differ. pos_helper_gen is really all about
    #edge positions. we wouldnt have to do this if *all* previously
    #subcutoff are still subcutoff (which is unlikely).

    #we could also potentially avoid having to do this if we knew that the
    #parcellation didnt change and only contained nr_edges < soft_max. But
    #its not worth bothering
    self.pos_helper_gen()

    #flip adj ord should already be done to the preprocessed adj
    self.adj_helper_gen()

    self.dv_3d.vectors_clear()
    self.display_mode = 'normal'

    self.dv_3d.supply_adj()
    self.dv_mat.supply_adj()

    if self.opts.circle_render == 'asynchronous':
        #first set up the 3D brain properly and then set the circle
        #to generate itself in a background thread
        self.display_all(skip_circ=True)
        self.dv_3d.zaxis_view()

        def threadsafe_circle_setup():
            self.dv_circ.supply_adj(
                reqrois=reqrois, suppress_extra_rois=suppress_extra_rois)
            self.select_conns()
            self.dv_circ.draw_conns()

        Thread(target=threadsafe_circle_setup).start()
    else:
        #otherwise set up the circle and display everything in a
        #single thread. If the circle is disabled this will not cause
        #problems
        self.dv_circ.supply_adj(reqrois=reqrois,
                                suppress_extra_rois=suppress_extra_rois)
        self.display_all()
        self.dv_3d.zaxis_view()


#This method takes a TractographyChooserParameters
def load_tractography(self, params):
    """Validate params and hand them to the 3D view's tracks_gen."""
    if not params.track_file:
        self.error_dialog('You must specify a valid tractography file')
        return
    if not params.b0_volume:
        self.error_dialog(
            'You must specify a B0 volume from which the '
            'registration to the diffusion space can be computed')
        return
    if not params.subjects_dir or not params.subject:
        self.error_dialog(
            'You must specify the freesurfer reconstruction '
            'for the individual subject for registration to the surface '
            'space.')
        return

    self.dv_3d.tracks_gen(params)


#This method takes a GeneralMatrixChooserParameters
def load_modules_or_scalars(self, params):
    """Load a per-node vector from file as modules or as a named scalar.

    Every failure is reported through error_dialog and aborts the load.
    """
    if not params.mat:
        self.error_dialog('You must specify a valid matrix file')
        return
    if params.whichkind == 'scalars' and not params.measure_name:
        self.error_dialog('Cannot leave scalar name blank. cvu uses '
                          'this value as a dictionary index')
        return

    import preprocessing
    try:
        ci = preprocessing.loadmat(params.mat, field=params.field_name,
                                   is_adjmat=False)
    except (CVUError, IOError) as e:
        self.error_dialog(str(e))
        return

    if params.mat_order:
        try:
            init_ord, bads = preprocessing.read_ordering_file(
                params.mat_order)
        except (IndexError, UnicodeDecodeError) as e:
            self.error_dialog(str(e))
            return

        #delete the bads
        if not params.ignore_deletes:
            ci = np.delete(ci, bads)

        #perform the swapping
        try:
            ci_ord = preprocessing.adj_sort(init_ord, self.labnam)
        except CVUError as e:
            self.error_dialog(str(e))
            return
        except KeyError as e:
            self.error_dialog('Field not found: %s' % str(e))
            return
        ci = ci[ci_ord]

    try:
        ci = np.reshape(ci, (self.nr_labels, ))
    except ValueError:
        #np.size counts every element; len() would only report the first
        #dimension of a 2-D array
        self.error_dialog('The %s file is of size %i after deletions, but '
                          'the dataset has %i regions' %
                          (params.whichkind, np.size(ci), self.nr_labels))
        return

    if params.whichkind == 'modules':
        import bct
        self.modules = bct.ci2ls(ci)
        self.nr_modules = len(self.modules)
    elif params.whichkind == 'scalars':
        self.save_scalar(params.measure_name, ci)
        params._increment_scalar_count()


#this method destroys the current dataviews and resets them entirely
def reset_dataviews(self):
    """Replace all three dataviews with fresh ones in 'normal' mode."""
    #in principle it might be useful to do some more cleanup here
    self.display_mode = 'normal'

    self.dv_3d = DVMayavi(self)
    self.dv_mat = DVMatrix(self)
    self.dv_circ = DVCircle(self)

    #scalar colorbar loading is tied to the surface and not to nodes
    #because the surface always has the same color scheme and the nodes
    #don't. but it can't be in surfs_gen because the surf can get gen'd
    #when switching from cracked to glass. so it is here.
    self.chg_scalar_colorbar()
#handles the scaling and size checking for new scalar datasets
def save_scalar(self, name, scalars, passive=False):
    """Store `scalars` under `name` in node_scalars after a size check.

    name -- key under which the values are stored
    scalars -- array-like that must squeeze to shape (nr_labels,)
    passive -- when true a size mismatch is only reported via verbose_msg,
        otherwise it is reported via error_dialog

    Nothing is stored when the size check fails.
    """
    if np.squeeze(scalars).shape != (self.nr_labels, ):
        report = self.verbose_msg if passive else self.error_dialog
        report("%s: Only Nx1 vectors can be saved as scalars" % name)
        #print np.squeeze(scalars).shape, self.nr_labels
        return
    #np.ravel (unlike the .ravel() method) also accepts plain sequences,
    #which the size check above already allows; copy so that the stored
    #values never alias the caller's array
    ci = np.ravel(scalars).copy()
    #ci=(ci-np.min(ci))/(np.max(ci)-np.min(ci))
    self.node_scalars.update({name: ci})


#this function takes a SnapshotParameters object and returns a
#continuation -- a closure -- which saves the snapshot. The CVU object
#spawns the "Really overwrite file" window if the file exists, and then
#calls the continuation, or else just calls the continuation directly.
def snapshot(self, params):
    """Return a zero-argument closure that saves the requested snapshot.

    The closure dispatches on params.whichplot to the matching dataview and
    reports IOError / KeyError through error_dialog.
    """

    def save_continuation():
        try:
            if params.whichplot == '3D brain':
                self.dv_3d.snapshot(params)
            elif params.whichplot == 'connection matrix':
                self.dv_mat.snapshot(params)
            elif params.whichplot == 'circle plot':
                self.dv_circ.snapshot(params)
        except IOError as e:
            self.error_dialog(str(e))
        except KeyError:
            self.error_dialog(
                'The library making the snapshot supports'
                ' multiple file types and doesnt know which one you want.'
                ' Please specify a file extension to disambiguate.')

    return save_continuation


#this function takes a MakeMovieParameters object and returns a
#continuation which records and takes the movie. The CVU object is again
#responsible for thinking about the "really overwrite file" case.
def make_movie(self, params):
    """Return a zero-argument closure that starts the 3D view's movie."""

    def save_continuation():
        self.dv_3d.make_movie(params)

    return save_continuation


def make_movie_finish(self, params):
    """Delegate to the 3D view's make_movie_finish."""
    self.dv_3d.make_movie_finish(params)


######################################################################
# VISUALIZATION INTERACTIONS
######################################################################


def display_all(self, skip_circ=False):
    """Return to 'normal' mode with no node/module selected and redraw."""
    self.display_mode = 'normal'
    self.curr_node = None
    self.cur_module = None
    self.center_adjmat()
    self.draw(skip_circ=skip_circ)


def display_node(self, n):
    """Select node n and redraw connections; out-of-range n is ignored."""
    if n < 0 or n >= self.nr_labels:
        return
    self.curr_node = n
    self.draw_conns()


def display_scalars(self):
    """Switch to 'scalar' mode and redraw surfaces and nodes."""
    self.display_mode = 'scalar'
    self.draw_surfs()
    self.draw_nodes()


def display_module(self, module):
    """Switch to 'module_single' mode for `module` and redraw everything."""
    self.display_mode = 'module_single'
    self.curr_node = None
    self.cur_module = module
    self.draw()  #draw surf is needed to unset surf color


def display_multi_module(self):
    """Switch to 'module_multi' mode; reports an error without modules."""
    if not self.modules:
        self.error_dialog('No modules defined')
        return
    self.display_mode = 'module_multi'
    self.draw_surfs()
    self.draw_nodes()


def calculate_modules(self, thres):
    """Compute modules from a copy of adj with entries below thres zeroed."""
    import bct
    import graph
    thres_adj = self.adj.copy()
    thres_adj[thres_adj < thres] = 0
    self.verbose_msg('Threshold for modularity calculation: %s' % str(thres))
    modvec = graph.calculate_modules(thres_adj)
    self.modules = bct.ci2ls(modvec)
    self.nr_modules = len(self.modules)


def calculate_graph_stats(self, thres):
    """Compute graph statistics on the thresholded adj and save them.

    Each resulting statistic is passed to save_scalar in passive mode.
    A CVUError raised along the way is reported through error_dialog.
    """
    import bct
    import graph
    thres_adj = self.adj.copy()
    thres_adj[thres_adj < thres] = 0
    self.verbose_msg('Threshold for graph calculations: %s' % str(thres))
    try:
        self.graph_stats = graph.do_summary(
            thres_adj, bct.ls2ci(self.modules),
            self.opts.intermediate_graphopts_list)

        #items() rather than the Python-2-only iteritems()
        for name, arr in self.graph_stats.items():
            self.save_scalar(name, arr, passive=True)
    except CVUError:
        self.error_dialog(
            "Community structure required for some of "
            "the calculations specified. Try calculating modules first.")


#save_graphstat_to_scalar is principally interaction between window and
#dataset manager.
#it should call a generic save_scalar method, same as
#loading a scalar from a file

######################################################################
# OPTIONS
######################################################################


def prop_thresh(self):
    """Set thresval from the proportional threshold opts.pthresh.

    The threshold is the adjdat entry at position
    round(pthresh * nr_edges - 1), clamped to the valid index range.
    """
    try:
        ix = int(round(self.opts.pthresh * self.nr_edges - 1))
        #clamp: a proportion near 0 used to fall back to entry 0 already;
        #a proportion above 1 used to index past the end of adjdat
        ix = min(max(ix, 0), self.nr_edges - 1)
        self.thresval = float(self.adjdat[ix])
    except TraitError as e:
        if self.opts.pthresh > 1:
            self.warning_dialog("%s\nThreshold set to maximum" % str(e))
        elif self.opts.pthresh < 0:
            self.warning_dialog("%s\nThreshold set to minimum" % str(e))
        else:
            self.error_dialog("Programming error")


def abs_thresh(self):
    """Set thresval from opts.athresh, clamped to the range of adjdat.

    Relies on the first and last entries of adjdat being its extremes
    and warns whenever clamping happens.
    """
    self.thresval = self.opts.athresh
    if self.adjdat[self.nr_edges - 1] < self.opts.athresh:
        self.thresval = self.adjdat[self.nr_edges - 1]
        self.warning_dialog("Threshold over maximum! Set to maximum.")
    elif self.adjdat[0] > self.opts.athresh:
        self.thresval = self.adjdat[0]
        self.warning_dialog("Threshold under minimum! Set to minimum.")


#recall reset_thresh is a property whose value is the threshold-setting
#method for the current threshold type, hence the call syntax below
@on_trait_change('opts:pthresh')
def chg_pthresh_val(self):
    if self.opts.thresh_type != 'prop':
        return
    self.reset_thresh()
    self.draw_conns(conservative=True)


@on_trait_change('opts:athresh')
def chg_athresh_val(self):
    if self.opts.thresh_type != 'abs':
        return
    self.reset_thresh()
    self.draw_conns(conservative=True)


@on_trait_change('opts:thresh_type')
def chg_thresh_type(self):
    self.draw_conns(conservative=True)


@on_trait_change('opts:interhemi_conns_on')
def chg_interhemi_connmask(self):
    self.masked[self.interhemi] = not self.opts.interhemi_conns_on


@on_trait_change('opts:lh_conns_on')
def chg_lh_connmask(self):
    self.masked[self.left] = not self.opts.lh_conns_on


@on_trait_change('opts:rh_conns_on')
def chg_rh_connmask(self):
    self.masked[self.right] = not self.opts.rh_conns_on


#the following options operate on specific views only
#they may fail if the view is not present (i dont know if this is true)
#so each handler deliberately ignores AttributeError
@on_trait_change('opts:tube_conns')
def chg_tube_conns(self):
    try:
        self.dv_3d.set_tubular_properties()
    except AttributeError:
        pass


@on_trait_change('opts:circ_size')
def chg_circ_size(self):
    try:
        self.dv_circ.circ.axes[0].set_ylim(0, self.opts.circ_size)
        self.dv_circ.circ.canvas.draw()
    except AttributeError:
        pass


@on_trait_change('opts:show_floating_text')
def chg_float_text(self):
    try:
        self.dv_3d.txt.visible = self.opts.show_floating_text
    except AttributeError:
        pass


@on_trait_change('opts:scalar_colorbar')
def chg_scalar_colorbar(self):
    try:
        self.dv_3d.set_colorbar(self.opts.scalar_colorbar,
                                self.dv_3d.syrf_lh,
                                orientation='vertical')
    except AttributeError:
        pass


@on_trait_change('opts:render_style')
def chg_render_style(self):
    try:
        self.dv_3d.set_surf_render_style(self.opts.render_style)
    except AttributeError:
        pass


@on_trait_change('opts:surface_visibility')
def chg_surf_opacity(self):
    try:
        for syrf in (self.dv_3d.syrf_lh, self.dv_3d.syrf_rh):
            syrf.actor.property.opacity = self.opts.surface_visibility
    except AttributeError:
        pass


@on_trait_change('opts:lh_nodes_on')
def chg_lh_nodemask(self):
    try:
        self.dv_3d.nodes_lh.visible = self.opts.lh_nodes_on
    except AttributeError:
        pass


@on_trait_change('opts:rh_nodes_on')
def chg_rh_nodemask(self):
    try:
        self.dv_3d.nodes_rh.visible = self.opts.rh_nodes_on
    except AttributeError:
        pass


@on_trait_change('opts:lh_surfs_on')
def chg_lh_surfmask(self):
    try:
        self.dv_3d.syrf_lh.visible = self.opts.lh_surfs_on
    except AttributeError:
        pass


@on_trait_change('opts:rh_surfs_on')
def chg_rh_surfmask(self):
    try:
        self.dv_3d.syrf_rh.visible = self.opts.rh_surfs_on
    except AttributeError:
        pass


@on_trait_change('opts:conns_colors_on')
def chg_conns_colors(self):
    try:
        if self.opts.conns_colors_on:
            self.dv_3d.vectors.glyph.color_mode = 'color_by_scalar'
        else:
            self.dv_3d.vectors.glyph.color_mode = 'no_coloring'
    except AttributeError:
        pass


@on_trait_change('opts:conns_colorbar')
def chg_conns_colorbar(self):
    try:
        self.dv_3d.set_colorbar(self.opts.conns_colorbar,
                                self.dv_3d.vectors,
                                orientation='horizontal')
    except AttributeError:
        pass


@on_trait_change('opts:conns_width')
def chg_conns_width(self):
    try:
        self.dv_3d.vectors.actor.property.line_width = self.opts.conns_width
    except AttributeError:
        pass


#the colormap handlers below swallow a failed redraw only while the map is
#set to 'file' with no file name chosen yet; any other failure is re-raised
@on_trait_change('opts:default_map.[cmap,reverse,fname,threshold]')
def chg_default_map(self):
    try:
        self.draw_nodes()
    except Exception:
        map_def = self.opts.default_map
        if map_def.cmap == 'file' and not map_def.fname:
            pass
        else:
            raise


@on_trait_change('opts:scalar_map.[cmap,reverse,fname,threshold]')
def chg_scalar_map(self):
    try:
        self.draw_surfs()
        self.draw_nodes()
    except Exception:
        map_sca = self.opts.scalar_map
        if map_sca.cmap == 'file' and not map_sca.fname:
            pass
        else:
            raise


@on_trait_change('opts:activation_map.[cmap,reverse,fname,threshold]')
def chg_activation_map(self):
    #we don't touch the circle plot here since circle redraw is expensive
    try:
        self.dv_3d.draw_conns()
    except Exception:
        map_act = self.opts.activation_map
        if map_act.cmap == 'file' and not map_act.fname:
            pass
        else:
            raise
#a failed colormap change is swallowed only while the map is set to 'file'
#with no file name chosen yet; any other failure is re-raised
@on_trait_change('opts:connmat_map.[cmap,reverse,fname,threshold]')
def chg_connmat_map(self):
    try:
        self.dv_mat.change_colormap()
    except Exception:
        map_mat = self.opts.connmat_map
        if map_mat.cmap == 'file' and not map_mat.fname:
            pass
        else:
            raise


######################################################################
# MISCELLANEOUS HELPERS
######################################################################

#the three helpers below forward a message to the gui object. their
#parameter is named `str` (shadowing the builtin inside the function);
#the name is kept so keyword callers keep working


def error_dialog(self, str):
    return self.gui.error_dialog(str)


def warning_dialog(self, str):
    return self.gui.warning_dialog(str)


def verbose_msg(self, str):
    return self.gui.verbose_msg(str)


def get_module(self):
    """Return the members of the currently selected module.

    Yields custom_module when cur_module is 'custom', modules[cur_module]
    when cur_module is an int and modules is set, and None otherwise.
    """
    if self.cur_module == 'custom':
        return self.custom_module
    elif isinstance(self.cur_module, int) and self.modules is not None:
        return self.modules[self.cur_module]
    return None