Ejemplo n.º 1
0
def loadmodel(basedir, ptcode, ctcode, cropname, modelname='modelavgreg'):
    """ Load a stored stent model and return its ssdf struct.

    The file '<ptcode>_<ctcode>_<cropname>_<modelname>.ssdf' is looked up in
    the folder basedir/ptcode first; if it is not found there, it is loaded
    directly from basedir.
    Every entry whose key starts with 'model', 'seeds' or 'landmark' is
    replaced in the struct by an unpacked StentGraph.
    Mind: these graphs must be packed again before the struct can be saved.
    """
    filename = '%s_%s_%s_%s.ssdf' % (ptcode, ctcode, cropname, modelname)
    try:
        s = ssdf.load(os.path.join(basedir, ptcode, filename))
    except FileNotFoundError:
        # fall back to a flat layout without a per-patient folder
        s = ssdf.load(os.path.join(basedir, filename))

    from stentseg.stentdirect import stentgraph
    graph_prefixes = ('model', 'seeds', 'landmark')
    for key in dir(s):
        if not key.startswith(graph_prefixes):
            continue
        graph = stentgraph.StentGraph()
        graph.unpack(s[key])
        s[key] = graph
    return s
Ejemplo n.º 2
0
def getTopBottomNodesZring(model, nTop=5, nBot=5):
    """Get top and bottom nodes of an Endurant proximal ring, based on z.

    The nodes are sorted by ascending z. The first nTop nodes form the "top"
    set and the last nBot nodes form the "bottom" set.

    Parameters
    ----------
    model : StentGraph
        graph whose nodes are coordinate tuples with z at index 2
    nTop : int
        number of nodes taken from the low-z end of the sorted nodes
    nBot : int
        number of nodes taken from the high-z end of the sorted nodes

    Returns
    -------
    (m_nTop, m_nBot) : two StentGraphs that contain only the selected nodes
    """
    nSorted = np.asarray(sorted(model.nodes(), key=lambda x: x[2]))  # z ascending
    top_nodes = nSorted[:nTop]
    # Take the LAST nBot nodes. Slicing as nSorted[nBot:] would return every
    # node after the first nBot, which is only correct for rings with exactly
    # 2*nBot nodes. max(..., 0) also makes nBot=0 give no nodes.
    bot_nodes = nSorted[max(len(nSorted) - nBot, 0):]
    m_nTop = stentgraph.StentGraph()
    m_nBot = stentgraph.StentGraph()
    for p in top_nodes:
        m_nTop.add_node(tuple(p.flat))
    for p in bot_nodes:
        m_nBot.add_node(tuple(p.flat))

    return m_nTop, m_nBot
Ejemplo n.º 3
0
def pp_to_graph(pp, type='oneEdge'):
    """ PointSet to graph, with the points connected as one edge or with
    consecutive edges. Returns a StentGraph. Can be used for centerline output.

    pp : PointSet (ordered points)
    type : 'oneEdge' -> every point becomes a node, and one edge from the
           first to the last point carries the whole ordered path (path=pp);
           any other value -> each pair of consecutive points is joined by
           an edge whose path is a 2-point float32 PointSet.
    An empty pp gives an empty graph.
    """
    from stentseg.stentdirect import stentgraph
    graph = stentgraph.StentGraph()
    if len(pp) == 0:
        # nothing to connect (previously raised NameError for 'oneEdge')
        return graph
    if type == 'oneEdge':
        # add single nodes
        for p in pp:
            graph.add_node(tuple(p.flat))
        # The single edge runs from the FIRST to the last point. Using the
        # loop variable here would give the last point and create a self-loop.
        pstart = tuple(pp[0].flat)
        pend = tuple(pp[-1].flat)
        graph.add_edge(pstart, pend, path=pp)
    else:
        for i, p in enumerate(pp[:-1]):
            n1 = tuple(p.flat)
            n2 = tuple(pp[i + 1].flat)
            # create path between nodes as PointSet
            path = PointSet(3, dtype=np.float32)
            path.append(n1)
            path.append(n2)
            graph.add_edge(n1, n2, path=path)

    return graph
def test_weak6_exception():
    r"""Cross connection inside a tetragon: edge 1-2 must NOT be removed.

          6 ------ 2 --- 7 ---
                  / \  /
           4 --- 1---3 --- 5 ---

    Rationale: a cross connection within a tetragon can produce this layout.
    Edge 1-2 must be recognized as part of the tetragon regardless of the
    values of edge 2-3, so edge 7-3 must always be considered
    (solved with *continue* statements in function prune_weak_exception).
    Returns the constructed graph.
    """
    # (node1, node2, cost, ctvalue)
    edge_table = [
        (1, 2, 4, 20),
        (1, 4, 1, 50),
        (1, 3, 2, 50),
        (2, 3, 1, 50),
        (2, 6, 1, 50),
        (2, 7, 2, 50),
        (3, 7, 1, 50),
        (5, 3, 2, 40),
    ]
    graph = stentgraph.StentGraph()
    for a, b, cost, ctvalue in edge_table:
        graph.add_edge(a, b, cost=cost, ctvalue=ctvalue)

    # pruning step currently disabled: prune_weak(graph, 2, 80)

    assert graph.number_of_edges() == 8
    return graph
def test_weak5_exception():
    r"""Weak edge 2-3 next to a tetragon whose closing edge 5-7 is optional.

          6 ------ 2 --- 7 ---
                  / \    (\)
           4 --- 1---3 --- 5 ---

    Result: edge 2-3 does not cause problems, whether it is strong or weak
    and of high or low CT value. When 1-2 and 2-3 are weak and of low CT value:
      - 2-3 is not removed when 5-7 has a high CT value (strong tetragon);
      - 2-3 is removed when 5-7 has a low CT value (weak tetragon) or when
        5-7 does not exist.
    In this setup edge 5-7 is left out (7 edges; 8 when it is added).
    Returns the constructed graph.
    """
    # (node1, node2, cost, ctvalue)
    edge_table = [
        (1, 2, 4, 20),
        (1, 4, 1, 50),
        (1, 3, 2, 50),
        (2, 3, 3, 30),
        (2, 6, 1, 50),
        (2, 7, 2, 50),
        # optional tetragon edge: (5, 7, 1, 50)
        (5, 3, 2, 40),
    ]
    graph = stentgraph.StentGraph()
    for a, b, cost, ctvalue in edge_table:
        graph.add_edge(a, b, cost=cost, ctvalue=ctvalue)

    # pruning step currently disabled: prune_weak(graph, 2, 80)

    assert graph.number_of_edges() == 7  # 8 when edge 5-7 is included
    return graph
def test_redundant1_exception():
    r"""Hooks triangle with a redundant edge: keep 1-3, remove 3-7.

          6 ------ 2 --- 7 ---
                  / \  /
           4 --- 1---3 --- 5 ---

    Rationale: hooks also form a triangle, and sometimes a redundant edge
    (3-7) belongs to another triangle that is connected to the 'hooks'
    triangle.
    Returns the constructed graph.
    """
    # (node1, node2, cost, ctvalue)
    edge_table = [
        (1, 3, 4, 20),
        (1, 4, 2, 50),
        (1, 2, 1, 50),
        (2, 3, 1, 50),
        (2, 6, 2, 50),
        (2, 7, 2, 40),
        (3, 7, 3, 30),
        (5, 3, 2, 50),
    ]
    graph = stentgraph.StentGraph()
    for a, b, cost, ctvalue in edge_table:
        graph.add_edge(a, b, cost=cost, ctvalue=ctvalue)

    # pruning step currently disabled: prune_weak(graph, 2, 80)

    assert graph.number_of_edges() == 8
    return graph
def test_weak4_exception():
    """Tetragon 1-2-3-4 with a weak 1-2 edge: edge 1-2 must NOT be removed.

           5 --- 4 --- 3 ---7
                /     /
        6 --- 1 --- 2 --- 8

    Edge 2-3 has a low CT value.
    Returns the constructed graph.
    """
    # (node1, node2, cost, ctvalue)
    edge_table = [
        (1, 2, 4, 20),
        (1, 4, 2, 50),
        (1, 6, 1, 50),
        (2, 3, 2, 35),  # low CT
        (2, 8, 2, 50),
        (3, 4, 2, 50),
        (3, 7, 2, 50),
        (4, 5, 1, 50),
    ]
    graph = stentgraph.StentGraph()
    for a, b, cost, ctvalue in edge_table:
        graph.add_edge(a, b, cost=cost, ctvalue=ctvalue)

    # pruning step currently disabled: prune_weak(graph, 2, 80)

    assert graph.number_of_edges() == 8
    return graph
Ejemplo n.º 8
0
    def _onFinish(self, event):
        """ Handler for 'Next/Save' (and figure close): store the landmarks.

        Packs the current landmark graph (self.graph) into self.s_landmarks
        under the key 'landmarks<phase>'. When self.what == 'phases' it then
        moves on to the next phase (self.phase + 10): it resets the point
        selection, loads self.s['vol<phase>'], starts a new empty graph,
        redraws self.ax and shows volume 0% with its stored landmarks in the
        reference axes self.axref. Finally it calls saveLandmarkModel to
        write the landmarks to disk.

        event : the visvis event that triggered the handler (not used).

        NOTE(review): the phase is advanced unconditionally, so the press
        after the last phase looks up a 'vol<phase>' that self.s may not
        contain -- confirm how the final phase is expected to end.
        """
        self._finished = True
        print(self.points)
        phase = self.phase
        # store model and pack
        storegraph = self.graph
        self.s_landmarks['landmarks{}'.format(
            phase)] = storegraph.pack()  # key e.g. landmarks0, landmarks10, ...
        if self.what == 'phases':
            # go to next phase
            self.phase += 10  # next phase
            self.points = []  # empty for new selected points
            self.nodepoints = []
            self.pointindex = 0
            self.vol = self.s['vol{}'.format(self.phase)]  # set new vol
            self.graph = stentgraph.StentGraph()  # new empty graph
            self._updateTitle()
            self._updateTextIndex()
            self.ax.Clear()  # clear the axes. Removing all wobjects
            self.label = DrawModelAxes(self.vol, ax=self.ax)

            # draw vol and graph of 0% in axref (reference for the new phase)
            model = stentgraph.StentGraph()
            model.unpack(self.s_landmarks.landmarks0)
            self.axref.Clear()
            DrawModelAxes(self.s.vol0, model, ax=self.axref)
            self.axref.visible = True
            vv.title(
                'CT Volume 0% for LSPEAS {} with selected landmarks'.format(
                    self.ptcode[7:]))

            # give both axes the same camera object (presumably so that they
            # stay aligned while navigating)
            self.ax.camera = self.axref.camera

        # === Store landmarks graph ssdf ===
        dirsave = self.dirsave
        ptcode = self.ptcode
        ctcode = self.ctcode
        cropname = self.cropname
        what = self.what
        # NOTE(review): self is passed as the first argument; presumably
        # saveLandmarkModel reads self.s_landmarks -- verify its signature.
        saveLandmarkModel(self, dirsave, ptcode, ctcode, cropname, what)

        print('Next/Finish was pressed - Landmarks stored')
        return
Ejemplo n.º 9
0
def isosurface_to_graph(pp):
    """ Store isosurface points in a graph as seed points.

    pp : PointSet of isosurface vertex points
    Returns a StentGraph that has one node per point of pp (each node is the
    tuple of that point's coordinates) and no edges.
    """
    modelnodes = stentgraph.StentGraph()
    for p in pp:
        modelnodes.add_node(tuple(p.flat))
    return modelnodes
Ejemplo n.º 10
0
def get_model_struts(model, nstruts=8):
    """Get the struts between rings R1 and R2.

    The struts are detected from edge length and z-orientation. Among the
    edges of the model without hooks whose length is between 4 and 12, the
    nstruts edges whose unit direction has the largest |z| component are
    taken as struts.
    Runs _get_model_hooks. Note: `model` is modified in place, because nodes
    that have a self-loop and degree 2 (a 3rd/4th ring) are removed from it.

    Returns
    -------
    model_struts : graph with only the struts
    model_hooks : graph with only the hooks
    model_R1R2 : model without hooks and without struts
    model_h_s : hooks plus struts
    model_noHooks : model without hooks

    If fewer than nstruts candidate edges are found, all candidates are used
    and a message is printed.
    """
    from stentseg.stentdirect.stentgraph import _edge_length
    from stentseg.stentdirect import stentgraph
    import numpy as np

    # remove 3rd or 4th ring if there: a node whose only connection is its
    # own self-loop has degree 2 and is not connected to other nodes
    for node in list(model.nodes_with_selfloops()):
        if model.degree(node) == 2:
            model.remove_node(node)
    # remove hooks if still there
    model_hooks, model_noHooks = _get_model_hooks(model)
    # initialize
    model_h_s = model_hooks.copy()  # struts added to hook model
    model_struts = stentgraph.StentGraph()

    # candidate struts: OLB21 struts are 4.5-5.5mm, OLB34 9-9.5mm
    candidates = []  # (|z| of unit direction, n1, n2)
    for n1, n2 in model_noHooks.edges():
        if n1 == n2:
            continue  # a self-loop has no direction (would divide by zero)
        e_length = _edge_length(model, n1, n2)
        if 4 < e_length < 12:
            vector = np.subtract(n1, n2)  # nodes are x,y,z
            vlength = np.sqrt(vector[0]**2 + vector[1]**2 + vector[2]**2)
            direction = abs(vector / vlength)
            candidates.append((direction[2], n1, n2))

    if len(candidates) < nstruts:
        print('get_model_struts: found {} candidate struts, expected {}'.format(
            len(candidates), nstruts))

    # Highest z-direction first. Sorting indices (stable) keeps edges with
    # equal z-direction distinct; looking edges up by value picked the first
    # match twice. The nodes are used as-is instead of being rebuilt from a
    # float array.
    zdirs = np.array([c[0] for c in candidates], dtype=float)
    for idx in np.argsort(-zdirs, kind='stable')[:nstruts]:
        _, n1, n2 = candidates[idx]
        add_nodes_edge_to_newmodel(model_struts, model, n1, n2)
        add_nodes_edge_to_newmodel(model_h_s, model, n1, n2)

    model_R1R2 = model_noHooks.copy()
    model_R1R2.remove_edges_from(model_struts.edges())

    return model_struts, model_hooks, model_R1R2, model_h_s, model_noHooks
Ejemplo n.º 11
0
def makeLandmarkModelDynamic(basedir,
                             ptcode,
                             ctcode,
                             cropname,
                             what='landmarksavgreg',
                             savedir=None):
    """ Make a stored landmark model dynamic with the deforms from
    registration, and store it to disk (overwriting the file).

    The deforms are read via loadvol(basedir, ..., 'deforms'). The landmark
    graph is read from key `what` of '<ptcode>_<ctcode>_<cropname>_<what>.ssdf'
    in savedir (basedir when savedir is None). The dynamic graph is written
    back into that same file under the key 'model', together with the
    registration params as 'paramsreg'.
    """
    #todo: change in default and merge with branch landmarks?
    import pirt
    from stentseg.motion.dynamic import (incorporate_motion_nodes,
                                         incorporate_motion_edges)
    from visvis import ssdf
    import os

    savedir = basedir if savedir is None else savedir

    # Registration deforms: one backward deformation field per 'deform*' key
    s = loadvol(basedir, ptcode, ctcode, cropname, 'deforms')
    deforms = [pirt.DeformationFieldBackward(*s[key])
               for key in dir(s) if key.startswith('deform')]
    paramsreg = s.params

    # Stored landmark model (same file is overwritten below)
    fname = '%s_%s_%s_%s.ssdf' % (ptcode, ctcode, cropname, what)
    fpath = os.path.join(savedir, fname)
    s2 = ssdf.load(fpath)
    model = stentgraph.StentGraph()
    model.unpack(s2[what])

    # Attach the motion to nodes and edges
    incorporate_motion_nodes(model, deforms, s2.origin)
    incorporate_motion_edges(model, deforms, s2.origin)

    # Save back
    s2.model = model.pack()
    s2.paramsreg = paramsreg
    ssdf.save(fpath, s2)
    print('saved to disk to {}.'.format(fpath))
Ejemplo n.º 12
0
def _get_model_hooks(model):
    """Split the hooks off a stent model.

    A hook is an edge with an end node of degree 1 in `model`.

    Returns
    -------
    (model_hooks, model_noHooks)
        model_hooks : new graph with only the hook edges, copied with their
            node and edge attributes (via add_nodes_edge_to_newmodel)
        model_noHooks : copy of model without the degree-1 end nodes, and
            thus without the hook edges

    `model` itself is not modified.
    """
    from stentseg.stentdirect import stentgraph

    model_noHooks = model.copy()
    model_hooks = stentgraph.StentGraph()  # graph for hooks
    # Degrees are read from the unmodified `model`, so removing one hook
    # cannot turn its neighbour into a new degree-1 "end node".
    for n in sorted(model.nodes()):
        if model.degree(n) == 1:
            neighbour = next(iter(model.edge[n]))  # the only adjacent node
            add_nodes_edge_to_newmodel(model_hooks, model, n, neighbour)
            model_noHooks.remove_node(n)  # also removes the connecting edge

    return model_hooks, model_noHooks
def test_weak7_exception():
    r"""Multiple tetragons around weak edge 1-2: edge 1-2 must NOT be removed.

           6 --- 4 --- 3 ---7
                /     /
        5 --- 1 --- 2 --- 8
         \     \    /
          ------ 9

    Rationale: when several tetragons exist for n1-n2, the edge n1-n2 could
    be removed twice (which raises an error) or be removed wrongly in a
    first instance --> code changed.
    Returns the constructed graph.
    """
    # (node1, node2, cost, ctvalue)
    edge_table = [
        (1, 2, 4, 20),  # weak
        (1, 4, 2, 50),
        (1, 5, 1, 50),
        (1, 9, 5, 30),  # very weak
        (2, 3, 2, 50),
        (2, 8, 2, 50),
        (2, 9, 3, 50),  # moderate weak
        (3, 4, 2, 50),
        (3, 7, 2, 50),
        (4, 6, 1, 50),
        (9, 5, 3, 30),  # moderate weak
    ]
    graph = stentgraph.StentGraph()
    for a, b, cost, ctvalue in edge_table:
        graph.add_edge(a, b, cost=cost, ctvalue=ctvalue)

    assert graph.number_of_edges() == 11

    # pruning step currently disabled: prune_weak_exception(graph, 2, 40)

    return graph
Ejemplo n.º 14
0
def get_graph_in_phase(graph, phasenr):
    """ Return a new graph with the model positioned in phase `phasenr`.

    For every edge, each point of its 'path' is displaced by
    pathdeforms[i][phasenr] (the deform stored for point i in that phase).
    The displaced first and last path points become the nodes of the new
    edge. The edge carries the displaced path and the original
    'pathdeforms', both converted with np.asarray.
    """
    from stentseg.stentdirect import stentgraph
    import numpy as np

    model_phase = stentgraph.StentGraph()
    for n1, n2 in graph.edges():
        edge_attrs = graph.edge[n1][n2]
        path = edge_attrs['path']
        path_deforms = edge_attrs['pathdeforms']
        displaced = [point + path_deforms[i][phasenr]
                     for i, point in enumerate(path)]
        start_node = tuple(displaced[0])
        end_node = tuple(displaced[-1])
        model_phase.add_edge(start_node,
                             end_node,
                             path=np.asarray(displaced),
                             pathdeforms=np.asarray(path_deforms))
    return model_phase
Ejemplo n.º 15
0
# These deforms are forward mapping. Turn into DeformationFields.
# Also get the backwards mapping variants (i.e. the inverse deforms).
# The forward mapping deforms should be used to deform meshes (since
# the information is used to displace vertices). The backward mapping
# deforms should be used to deform textures (since they are used in
# interpolating the texture data).
deforms_f = [pirt.DeformationFieldForward(*f) for f in deforms]
deforms_b = [f.as_backward() for f in deforms_f]

# Load the stent model and mesh
s = loadmodel(basedir, ptcode, ctcode, cropname, modelname)
if len(ringnames)==1: # a single key: use that model directly (entire model or 1 ring)
    model = s[ringnames[0]]
else:
    # merge ring models into one graph for dynamic visualization
    model = stentgraph.StentGraph()
    for key in ringnames:
        model.add_nodes_from(s[key].nodes(data=True)) # also attributes
        model.add_edges_from(s[key].edges(data=True))

# Build the mesh. meshWithColors selects the color coding:
# 'displacement' -> absolute displacement (in `dimension`, using `motion`),
# 'curvature'    -> the 'path_curvature_change' values of the edges,
# anything else  -> a plain, uncolored mesh.
if meshWithColors=='displacement':
    modelmesh = create_mesh_with_abs_displacement(model, radius = 0.7, dim = dimension, motion = motion)
elif meshWithColors=='curvature':
    modelmesh = create_mesh_with_values(model, valueskey='path_curvature_change', radius=0.7)
else:
    modelmesh = create_mesh(model, 1.0)  # Param is thickness

# Load static CT image to add as reference
try:
    s2 = loadvol(basedir, ptcode, ctcode, 'stent', staticref)
except FileNotFoundError:
Ejemplo n.º 16
0
    def __init__(self,
                 dirsave,
                 ptcode,
                 ctcode,
                 cropname,
                 s,
                 what='phases',
                 axes=None,
                 **kwargs):
        """ Set up the interactive landmark-selection figure.

        dirsave : folder passed on to saveLandmarkModel when storing
        ptcode, ctcode, cropname : identify the patient scan
        s : struct from loadvol; holds vol0 (per-phase volumes) or vol
        what : 'phases' to step through the phases starting at 0; any other
               value (e.g. 'avgreg') is used as the phase label itself
        axes, kwargs : accepted but not used here

        Creates figure 1 with a working axes (self.ax) and a hidden
        reference axes (self.axref), the text labels and the Select / Undo /
        Next/Save / Default Zoom buttons, and binds their handlers. Closing
        the figure also triggers _onFinish.
        """

        self.fig = vv.figure(1)
        vv.clf()
        self.fig.position = 0.00, 29.00, 1680.00, 973.00
        self.defaultzoom = 0.025  # check current zoom with foo.ax.GetView()
        self.what = what
        self.ptcode = ptcode
        self.dirsave = dirsave
        self.ctcode = ctcode
        self.cropname = cropname
        if self.what == 'phases':
            self.phase = 0
        else:
            self.phase = self.what  # avgreg
        # self.vol = s.vol0
        self.s = s  # s with vol(s)
        self.s_landmarks = vv.ssdf.new()  # collects the packed landmark graphs
        self.graph = stentgraph.StentGraph()
        self.points = []  # selected points
        self.nodepoints = []
        self.pointindex = 0  # for selected points
        # the phase volume 0 when s has it, otherwise the single volume
        try:
            self.vol = s.vol0  # when phases
        except AttributeError:
            self.vol = s.vol  # when avgreg

        self.ax = vv.subplot(121)
        self.axref = vv.subplot(122)

        self.label = DrawModelAxes(self.vol,
                                   ax=self.ax)  # label of clicked point
        self.axref.bgcolor = 0, 0, 0
        self.axref.visible = False  # shown from the second phase on (_onFinish)

        # create axis for buttons
        a_select = vv.Wibject(self.ax)  # on self.ax or fig?
        a_select.position = 0.55, 0.7, 0.6, 0.5  # x, y, w, h

        # Create text objects
        self._labelcurrentIndexT = vv.Label(a_select)  # for text title
        self._labelcurrentIndexT.position = 125, 180
        self._labelcurrentIndexT.text = ' Total selected ='
        self._labelcurrentIndex = vv.Label(a_select)
        self._labelcurrentIndex.position = 225, 180

        # Create Select button
        self._select = False
        self._butselect = vv.PushButton(a_select)
        self._butselect.position = 10, 150
        self._butselect.text = 'Select'

        # Create Back button
        self._back = False
        self._butback = vv.PushButton(a_select)
        self._butback.position = 125, 150
        self._butback.text = 'Undo'

        # Create Next/Save button
        self._finished = False
        self._butclose = vv.PushButton(a_select)
        self._butclose.position = 10, 230
        self._butclose.text = 'Next/Save'

        # # Create Save landmarks button
        # self._save = False
        # self._butsave = vv.PushButton(a_select)
        # self._butsave.position = 125,230
        # self._butsave.text = 'Save|Finished'

        # Create Reset-View button
        self._resetview = False
        self._butresetview = vv.PushButton(a_select)
        self._butresetview.position = 10, 180
        self._butresetview.text = 'Default Zoom'  # back to default zoom

        # bind event handlers (closing the figure also stores the landmarks)
        self.fig.eventClose.Bind(self._onFinish)
        self._butclose.eventPress.Bind(self._onFinish)
        self._butselect.eventPress.Bind(self._onSelect)
        self._butback.eventPress.Bind(self._onBack)
        self._butresetview.eventPress.Bind(self._onView)
        # self._butsave.eventPress.Bind(self._onSave)

        self._updateTextIndex()
        self._updateTitle()
Ejemplo n.º 17
0
def add_nodes_edge_to_newmodel(modelnew, model, n, neighbour):
    """ Copy edge n-neighbour of `model`, with both end nodes, into `modelnew`.

    The two nodes are added with their 'deforms' attribute. The edge is added
    with its 'cost', 'ctvalue', 'path' and 'pathdeforms' attributes.
    Returns None; `modelnew` is modified in place.
    """
    edge_attrs = model.edge[n][neighbour]
    cost = edge_attrs['cost']
    ctvalue = edge_attrs['ctvalue']
    path = edge_attrs['path']
    pathdeforms = edge_attrs['pathdeforms']
    for node in (n, neighbour):
        modelnew.add_node(node, deforms=model.node[node]['deforms'])
    modelnew.add_edge(n, neighbour, cost=cost, ctvalue=ctvalue,
                      path=path, pathdeforms=pathdeforms)


# Create graph for hooks: every edge with an end node of degree 1 is a hook
model_hooks = stentgraph.StentGraph()

remove = True  # remove the hook end nodes (and so the hook edges) from model
hooknodes = list() # remember nodes that belong to hooks (the non-end node)
# Collect the end nodes BEFORE removing anything. Removing a node while
# iterating lowers its neighbour's degree, so the neighbour could be taken
# for a hook as well, depending on iteration order.
endnodes = [n for n in model.nodes() if model.degree(n) == 1]
for n in endnodes:
    neighbour = list(model.edge[n].keys())[0]  # the only adjacent node
    add_nodes_edge_to_newmodel(model_hooks, model, n, neighbour)
    hooknodes.append(neighbour)
if remove:
    for n in endnodes:
        model.remove_node(n)  # also removes the hook edge

# Pop remaining degree 2 nodes
stentgraph.pop_nodes(model) 
Ejemplo n.º 18
0
# Draw the volume with axes; `label` is presumably the label of the clicked
# point, later converted with label2worldcoordinates in on_key
label = DrawModelAxes(vol, clim=clim, showVol=showVol, axVis=True)

vv.xlabel('x (mm)')
vv.ylabel('y (mm)')
vv.zlabel('z (mm)')
vv.title('CT Volume %i%% for LSPEAS %s  -  %s' % (phase, ptcode[7:], ctcode))

# bind rotate view (a, d, z, x active keys)
fig.eventKeyDown.Bind(lambda event: _utils_GUI.RotateView(event))

# instantiate stentdirect segmenter object
p = getDefaultParams()
sd2 = StentDirect(vol, p)
# initialize: start with an empty graph of manually selected nodes
sd2._nodes1 = stentgraph.StentGraph()
nr = 0  # running number given to each manually added node (see on_key)


def on_key(event):
    if event.key == vv.KEY_CONTROL:
        global nr
        coordinates = np.asarray(label2worldcoordinates(label),
                                 dtype=np.float32)  # x,y,z
        n2 = tuple(coordinates.flat)
        sd2._nodes1.add_node(n2, number=nr)
        print(nr)
        if nr > 0:
            for n in list(sd2._nodes1.nodes()):
                if sd2._nodes1.node[n]['number'] == nr - 1:
                    path = [n2, n]
Ejemplo n.º 19
0
    def __init__(self,
                 ptcode,
                 ctcode,
                 basedir,
                 showVol='mip',
                 meshWithColors=False,
                 motion='amplitude',
                 clim2=(0, 2)):
        """
        Show the dynamic stent and vessel centerline model [nellix].

        Loads the 'prox' crop of the registration deforms, the centerline
        models 'centerline_total_modelavgreg_deforms' and
        'centerline_total_modelvesselavgreg_deforms', and the 'avgreg'
        volume. All 'model*' graphs are merged into one graph, which is shown
        as a deformable mesh on top of the volume, with motion playback and
        an ECG slider.

        ptcode, ctcode, basedir : identify and locate the patient data
        showVol : render style of the volume (e.g. 'mip')
        meshWithColors : if true, color-code the mesh by displacement
                         (motion pathdeforms are computed if missing)
        motion : passed to create_mesh_with_abs_displacement (e.g. 'amplitude')
        clim2 : color limits of the color-coded mesh
        """

        import os
        import pirt
        import visvis as vv

        from stentseg.utils.datahandling import select_dir, loadvol, loadmodel
        from pirt.utils.deformvis import DeformableTexture3D, DeformableMesh
        from stentseg.stentdirect.stentgraph import create_mesh
        from stentseg.stentdirect import stentgraph
        from stentseg.utils.visualization import show_ctvolume
        from stentseg.motion.vis import create_mesh_with_abs_displacement
        import copy
        from stentseg.motion.dynamic import incorporate_motion_nodes, incorporate_motion_edges
        from lspeas.utils.ecgslider import runEcgSlider
        from stentseg.utils import _utils_GUI

        import numpy as np

        cropname = 'prox'
        # params
        nr = 1
        # motion = 'amplitude'  # amplitude or sum
        dimension = 'xyz'
        showVol = showVol  # MIP or ISO or 2D or None
        clim0 = (0, 2000)
        # clim2 = (0,2)
        motionPlay = 9, 1  # each x ms, a step of x %

        s = loadvol(basedir, ptcode, ctcode, cropname, what='deforms')
        m = loadmodel(basedir, ptcode, ctcode, cropname,
                      'centerline_total_modelavgreg_deforms')
        v = loadmodel(basedir,
                      ptcode,
                      ctcode,
                      cropname,
                      modelname='centerline_total_modelvesselavgreg_deforms')
        s2 = loadvol(basedir, ptcode, ctcode, cropname, what='avgreg')
        vol_org = copy.deepcopy(s2.vol)
        # NOTE(review): the sampling is overwritten with the SECOND sampling
        # value used for the first two axes -- confirm this is intended
        s2.vol.sampling = [
            vol_org.sampling[1], vol_org.sampling[1], vol_org.sampling[2]
        ]
        s2.sampling = s2.vol.sampling
        vol = s2.vol

        # merge models into one for dynamic visualization
        model_total = stentgraph.StentGraph()
        for key in dir(m):
            if key.startswith('model'):
                model_total.add_nodes_from(
                    m[key].nodes(data=True))  # also attributes
                model_total.add_edges_from(m[key].edges(data=True))
        for key in dir(v):
            if key.startswith('model'):
                model_total.add_nodes_from(
                    v[key].nodes(data=True))  # also attributes
                model_total.add_edges_from(v[key].edges(data=True))

        # Load deformations (forward for mesh)
        deformkeys = []
        for key in dir(s):
            if key.startswith('deform'):
                deformkeys.append(key)
        deforms = [s[key] for key in deformkeys]
        # keep every 2nd sample in each of the three field dimensions
        deforms = [[field[::2, ::2, ::2] for field in fields]
                   for fields in deforms]

        # These deforms are forward mapping. Turn into DeformationFields.
        # Also get the backwards mapping variants (i.e. the inverse deforms).
        # The forward mapping deforms should be used to deform meshes (since
        # the information is used to displace vertices). The backward mapping
        # deforms should be used to deform textures (since they are used in
        # interpolating the texture data).
        deforms_f = [pirt.DeformationFieldForward(*f) for f in deforms]
        deforms_b = [f.as_backward() for f in deforms_f]

        # Create mesh
        if meshWithColors:
            try:
                modelmesh = create_mesh_with_abs_displacement(model_total,
                                                              radius=0.7,
                                                              dim=dimension,
                                                              motion=motion)
            except KeyError:
                print('Centerline model has no pathdeforms so we create them')
                # use unsampled deforms
                deforms2 = [s[key] for key in deformkeys]
                # deforms as backward for model
                deformsB = [
                    pirt.DeformationFieldBackward(*fields)
                    for fields in deforms2
                ]
                # set sampling to original
                # for i in range(len(deformsB)):
                #         deformsB[i]._field_sampling = tuple(s.sampling)
                # not needed because we use unsampled deforms
                # Combine ...
                # NOTE(review): convert_paths_to_PointSet/_ndarray are not
                # imported in this method; they must exist at module level
                incorporate_motion_nodes(model_total, deformsB, s2.origin)
                convert_paths_to_PointSet(model_total)
                incorporate_motion_edges(model_total, deformsB, s2.origin)
                convert_paths_to_ndarray(model_total)
                modelmesh = create_mesh_with_abs_displacement(model_total,
                                                              radius=0.7,
                                                              dim=dimension,
                                                              motion=motion)
        else:
            modelmesh = create_mesh(model_total, 0.7, fullPaths=True)

        ## Start vis
        f = vv.figure(nr)
        vv.clf()
        if nr == 1:
            f.position = 8.00, 30.00, 944.00, 1002.00
        else:
            f.position = 968.00, 30.00, 944.00, 1002.00
        a = vv.gca()
        a.axis.axisColor = 1, 1, 1
        a.axis.visible = False
        a.bgcolor = 0, 0, 0
        a.daspect = 1, 1, -1
        t = vv.volshow(vol, clim=clim0, renderStyle=showVol, axes=a)
        vv.xlabel('x (mm)')
        vv.ylabel('y (mm)')
        vv.zlabel('z (mm)')
        if meshWithColors:
            # NOTE(review): `dim` is only defined for dimension == 'xyz';
            # any other dimension would raise NameError below
            if dimension == 'xyz':
                dim = '3D'
            vv.title(
                'Model for chEVAS %s  (color-coded %s of movement in %s in mm)'
                % (ptcode[8:], motion, dim))
        else:
            vv.title('Model for chEVAS %s' % (ptcode[8:]))

        colorbar = True
        # Create deformable mesh
        dm = DeformableMesh(a, modelmesh)  # in x,y,z
        dm.SetDeforms(*[list(reversed(deform))
                        for deform in deforms_f])  # from z,y,x to x,y,z
        if meshWithColors:
            dm.clim = clim2
            dm.colormap = vv.CM_JET  #todo: use colormap Viridis or Magma as JET is not linear (https://bids.github.io/colormap/)
            if colorbar:
                vv.colorbar()
            colorbar = False
        else:
            dm.faceColor = 'g'

        # Run mesh
        a.SetLimits()
        # a.SetView(viewringcrop)
        dm.MotionPlay(motionPlay[0],
                      motionPlay[1])  # (10, 0.2) = each 10 ms do a step of 20%
        dm.motionSplineType = 'B-spline'
        dm.motionAmplitude = 0.5

        ## run ecgslider
        # NOTE(review): `ecg` and `dm` are locals and are not stored on self;
        # verify that nothing relies on keeping these references alive
        ecg = runEcgSlider(dm, f, a, motionPlay)
Ejemplo n.º 20
0
    def __init__(self,
                 ptcode,
                 ctcode,
                 StartPoints,
                 EndPoints,
                 basedir,
                 modelname='modelavgreg'):
        """ with start and endpoints provided, calculate centerline and save as
        ssdf in basedir as model and dynamic model
        """
        #todo: name of dynamic model is now deforms, unclear, should be dynamic
        #import numpy as np
        import visvis as vv
        import numpy as np
        import os
        import copy

        from stentseg.utils import PointSet, _utils_GUI
        from stentseg.utils.centerline import (find_centerline,
                                               points_from_nodes_in_graph,
                                               points_from_mesh,
                                               smooth_centerline)
        from stentseg.utils.datahandling import loadmodel, loadvol
        from stentseg.utils.visualization import show_ctvolume
        from stentseg.utils.picker import pick3d

        stentnr = len(StartPoints)

        cropname = 'prox'
        what = modelname
        what_vol = 'avgreg'
        vismids = True
        m = loadmodel(basedir, ptcode, ctcode, cropname, what)
        s = loadvol(basedir, ptcode, ctcode, cropname, what_vol)
        s.vol.sampling = [s.sampling[1], s.sampling[1], s.sampling[2]]
        s.sampling = s.vol.sampling

        start1 = StartPoints.copy()
        ends = EndPoints.copy()

        from stentseg.stentdirect import stentgraph
        ppp = points_from_nodes_in_graph(m.model)

        allcenterlines = []  # for pp
        allcenterlines_nosmooth = []  # for pp
        centerlines = []  # for stentgraph
        nodes_total = stentgraph.StentGraph()
        for j in range(stentnr):
            if j == 0 or not start1[j] == ends[j - 1]:
                # if first stent or when stent did not continue with this start point
                nodes = stentgraph.StentGraph()
                centerline = PointSet(3)  # empty

            # Find main centerline
            # if j > 3: # for stent with midpoints
            #     centerline1 = find_centerline(ppp, start1[j], ends[j], step= 1,
            #     ndist=10, regfactor=0.5, regsteps=10, verbose=False)

            #else:
            centerline1 = find_centerline(ppp,
                                          start1[j],
                                          ends[j],
                                          step=1,
                                          ndist=10,
                                          regfactor=0.5,
                                          regsteps=1,
                                          verbose=False)
            # centerline1 is a PointSet

            print('Centerline calculation completed')

            # ========= Maaike =======
            smoothfactor = 15  # Mirthe used 2 or 4

            # check if cll continued here from last end point
            if not j == 0 and start1[j] == ends[j - 1]:
                # yes we continued
                ppart = centerline1[:
                                    -1]  # cut last but do not cut first point as this is midpoint
            else:
                # do not use first points, as they are influenced by user selected points
                ppart = centerline1[1:-1]

            for p in ppart:
                centerline.append(p)

            # if last stent or stent does not continue with next start-endpoint
            if j == stentnr - 1 or not ends[j] == start1[j + 1]:
                # store non-smoothed for vis
                allcenterlines_nosmooth.append(centerline)
                pp = smooth_centerline(centerline, n=smoothfactor)
                # add pp to list
                allcenterlines.append(pp)  # list with PointSet per centerline
                self.allcenterlines = allcenterlines

                # add pp as nodes
                for i, p in enumerate(pp):
                    p_as_tuple = tuple(p.flat)
                    nodes.add_node(p_as_tuple)
                    nodes_total.add_node(p_as_tuple)
                # add pp as one edge so that pathpoints are in fixed order
                pstart = tuple(pp[0].flat)
                pend = tuple(pp[-1].flat)
                nodes.add_edge(pstart, pend, path=pp)
                nodes_total.add_edge(pstart, pend, path=pp)
                # add final centerline nodes model to list
                centerlines.append(nodes)

            # ========= Maaike =======

        ## Store segmentation to disk

        # Build struct
        s2 = vv.ssdf.new()
        s2.sampling = s.sampling
        s2.origin = s.origin
        s2.stenttype = m.stenttype
        s2.croprange = m.croprange
        for key in dir(m):
            if key.startswith('meta'):
                suffix = key[4:]
                s2['meta' + suffix] = m['meta' + suffix]
        s2.what = what
        s2.params = s.params  #reg
        s2.paramsseeds = m.params
        s2.stentType = 'nellix'
        s2.StartPoints = StartPoints
        s2.EndPoints = EndPoints
        # keep centerlines as pp also [Maaike]
        s2.ppallCenterlines = allcenterlines
        for k in range(len(allcenterlines)):
            suffix = str(k)
            pp = allcenterlines[k]
            s2['ppCenterline' + suffix] = pp

        s3 = copy.deepcopy(s2)
        s3['model'] = nodes_total.pack()

        # Store model for each centerline
        for j in range(len(centerlines)):
            suffix = str(j)
            model = centerlines[j]
            s2['model' + suffix] = model.pack()

        # Save model with seperate centerlines.
        filename = '%s_%s_%s_%s.ssdf' % (ptcode, ctcode, cropname,
                                         'centerline_' + what)
        vv.ssdf.save(os.path.join(basedir, ptcode, filename), s2)
        print('saved to disk as {}.'.format(filename))

        # Save model with combined centerlines
        filename = '%s_%s_%s_%s.ssdf' % (ptcode, ctcode, cropname,
                                         'centerline_total_' + what)
        vv.ssdf.save(os.path.join(basedir, ptcode, filename), s3)
        print('saved to disk as {}.'.format(filename))

        # remove intermediate centerline points
        # start1 = map(tuple, start1)
        # ends = map(tuple, ends)
        startpoints_clean = copy.deepcopy(start1)
        endpoints_clean = copy.deepcopy(ends)
        duplicates = list(set(start1) & set(ends))
        for i in range(len(duplicates)):
            startpoints_clean.remove(duplicates[i])
            endpoints_clean.remove(duplicates[i])

        #Visualize
        f = vv.figure(10)
        vv.clf()
        a1 = vv.subplot(121)
        a1.daspect = 1, 1, -1

        vv.plot(ppp, ms='.', ls='', alpha=0.6, mw=2)
        for j in range(len(startpoints_clean)):
            vv.plot(PointSet(list(startpoints_clean[j])),
                    ms='.',
                    ls='',
                    mc='g',
                    mw=20)  # startpoint green
            vv.plot(PointSet(list(endpoints_clean[j])),
                    ms='.',
                    ls='',
                    mc='r',
                    mw=20)  # endpoint red
        for j in range(len(allcenterlines)):
            vv.plot(allcenterlines[j], ms='.', ls='', mw=10, mc='y')
        vv.title('Centerlines and seed points')
        vv.xlabel('x (mm)')
        vv.ylabel('y (mm)')
        vv.zlabel('z (mm)')
        # for j in range(len(allcenterlines_nosmooth)):
        #     vv.plot(allcenterlines_nosmooth[j], ms='o', ls='', mw=10, mc='c', alpha=0.6)

        a2 = vv.subplot(122)
        a2.daspect = 1, 1, -1

        vv.plot(ppp, ms='.', ls='', alpha=0.6, mw=2)
        # vv.volshow(s.vol, clim=clim, renderStyle = 'mip')
        t = show_ctvolume(s.vol,
                          axis=a2,
                          showVol='ISO',
                          clim=(0, 2500),
                          isoTh=250,
                          removeStent=False,
                          climEditor=True)
        label = pick3d(vv.gca(), s.vol)
        for j in range(len(startpoints_clean)):
            vv.plot(PointSet(list(startpoints_clean[j])),
                    ms='.',
                    ls='',
                    mc='g',
                    mw=20,
                    alpha=0.6)  # startpoint green
            vv.plot(PointSet(list(endpoints_clean[j])),
                    ms='.',
                    ls='',
                    mc='r',
                    mw=20,
                    alpha=0.6)  # endpoint red
        for j in range(len(allcenterlines)):
            vv.plot(allcenterlines[j], ms='o', ls='', mw=10, mc='y', alpha=0.6)

        # show midpoints (e.g. duplicates)
        if vismids:
            for p in duplicates:
                vv.plot(p[0], p[1], p[2], mc='m', ms='o', mw=10, alpha=0.6)

        a2.axis.visible = False

        vv.title('Centerlines and seed points')

        a1.camera = a2.camera

        f.eventKeyDown.Bind(
            lambda event: _utils_GUI.RotateView(event, [a1, a2]))
        f.eventKeyDown.Bind(
            lambda event: _utils_GUI.ViewPresets(event, [a1, a2]))

        # Pick node for midpoint to redo get_centerline
        self.pickedCLLpoint = _utils_GUI.Event_pick_graph_point(
            nodes_total, s.vol, label, nodesOnly=True)  # x,y,z
        # use key p to select point

        #===============================================================================
        vv.figure(11)
        vv.gca().daspect = 1, 1, -1
        t = show_ctvolume(s.vol,
                          showVol='ISO',
                          clim=(0, 2500),
                          isoTh=250,
                          removeStent=False,
                          climEditor=True)
        label2 = pick3d(vv.gca(), s.vol)
        for j in range(len(startpoints_clean)):
            vv.plot(PointSet(list(startpoints_clean[j])),
                    ms='.',
                    ls='',
                    mc='g',
                    mw=20,
                    alpha=0.6)  # startpoint green
            vv.plot(PointSet(list(endpoints_clean[j])),
                    ms='.',
                    ls='',
                    mc='r',
                    mw=20,
                    alpha=0.6)  # endpoint red
        vv.xlabel('x (mm)')
        vv.ylabel('y (mm)')
        vv.zlabel('z (mm)')
        #===============================================================================

        ## Make model dynamic (and store/overwrite to disk)
        import pirt
        from stentseg.motion.dynamic import incorporate_motion_nodes, incorporate_motion_edges

        # Load deforms
        filename = '%s_%s_%s_%s.ssdf' % (ptcode, ctcode, cropname, 'deforms')
        s1 = vv.ssdf.load(os.path.join(basedir, ptcode, filename))
        deformkeys = []
        for key in dir(s1):
            if key.startswith('deform'):
                deformkeys.append(key)
        deforms = [s1[key] for key in deformkeys]
        deforms = [
            pirt.DeformationFieldBackward(*fields) for fields in deforms
        ]
        for i in range(len(deforms)):
            deforms[i]._field_sampling = tuple(s1.sampling)
        paramsreg = s1.params

        # Load model
        s2 = loadmodel(basedir, ptcode, ctcode, cropname, 'centerline_' + what)
        s3 = loadmodel(basedir, ptcode, ctcode, cropname,
                       'centerline_total_' + what)

        # Combine ...
        for key in dir(s2):
            if key.startswith('model'):
                incorporate_motion_nodes(s2[key], deforms, s.origin)
                incorporate_motion_edges(s2[key], deforms, s.origin)
                model = s2[key]
                s2[key] = model.pack()
        # Combine ...
        for key in dir(s3):
            if key.startswith('model'):
                incorporate_motion_nodes(s3[key], deforms, s.origin)
                incorporate_motion_edges(s3[key], deforms, s.origin)
                model = s3[key]
                s3[key] = model.pack()

        # Save
        s2.paramsreg = paramsreg
        filename = '%s_%s_%s_%s.ssdf' % (ptcode, ctcode, cropname,
                                         'centerline_' + what + '_deforms')
        vv.ssdf.save(os.path.join(basedir, ptcode, filename), s2)
        print('saved to disk as {}.'.format(filename))

        # Save
        s3.paramsreg = paramsreg
        filename = '%s_%s_%s_%s.ssdf' % (
            ptcode, ctcode, cropname, 'centerline_total_' + what + '_deforms')
        vv.ssdf.save(os.path.join(basedir, ptcode, filename), s3)
        print('saved to disk as {}.'.format(filename))
Ejemplo n.º 21
0
def on_key(event):
    """KEY commands for user interaction with the graph model sd._nodes3.

    UP/DOWN = show/hide nodes (DOWN also hides the edge info texts)
    ENTER   = restore edge from sd._nodes2 [select 2 nodes]
    DELETE  = remove edge [select 2 nodes] or pop node [select 1 node]
    ALT     = clean graph: pop, crossings, corner (a backup is kept for undo)
    ESCAPE  = FINISH: refine, smooth (a backup is kept for undo)
    u       = undo: restore the backup made by the last ALT/ESCAPE
    r       = restore the picked node from sd._nodes2
    s       = additional smoothing of all paths
    e       = smooth the edge between the 2 selected nodes
    w       = deselect all selected nodes
    q       = interactive cluster removal

    Relies on module-level state created by the calling script (vv, sd, a3,
    vol, clim, showVol, t1/t2/t3, selected_nodes, label, stentType, ...).
    """
    global node_points
    global nodes3copy

    def _redraw(**draw_kwargs):
        """Redraw sd._nodes3 in a3 keeping the camera; return new node points."""
        view = a3.GetView()
        a3.Clear()
        options = dict(clim=clim, showVol=showVol, mw=8, lw=0.2,
                       climEditor=False)
        options.update(draw_kwargs)  # e.g. meshColor/lc for the final model
        DrawModelAxes(vol, sd._nodes3, a3, **options)
        points = _utils_GUI.interactive_node_points(sd._nodes3, scale=0.6)
        _utils_GUI.node_points_callbacks(points, selected_nodes, pick=False)
        a3.SetView(view)
        return points

    def _show_edge_info(cost, ctvalue, length, path, color):
        """Deselect nodes, show the edge info texts and draw path in color."""
        for selected in selected_nodes:
            selected.faceColor = 'b'
        selected_nodes.clear()
        t1.text = 'Edge ctvalue: \b{%1.2f HU}' % ctvalue
        t2.text = 'Edge cost: \b{%1.7f }' % cost
        t3.text = 'Edge length: \b{%1.2f mm}' % length
        for text in (t1, t2, t3):
            text.visible = True
        view = a3.GetView()
        # visvis meshes do not work with PointSet, hence visvis Pointset
        line = vv.solidLine(Pointset(path), radius=0.2)
        line.faceColor = color
        a3.SetView(view)

    def _selected_edge(graph):
        """Return (node1, node2) of the 2 selected nodes if graph has that
        edge; otherwise print why and return None."""
        if len(selected_nodes) != 2:
            print('Select exactly 2 nodes first')
            return None
        select1 = selected_nodes[0].node
        select2 = selected_nodes[1].node
        if not graph.has_edge(select1, select2):
            print('No edge between the selected nodes')
            return None
        return select1, select2

    if event.key == vv.KEY_DOWN:
        # hide nodes and edge info texts
        t1.visible = False
        t2.visible = False
        t3.visible = False
        for node_point in node_points:
            node_point.visible = False
    if event.key == vv.KEY_UP:
        # show nodes
        for node_point in node_points:
            node_point.visible = True
    if event.key == vv.KEY_ENTER:
        # restore edge from the unpruned graph sd._nodes2
        edge = _selected_edge(sd._nodes2)
        if edge is not None:
            select1, select2 = edge
            attrs = sd._nodes2.edge[select1][select2]
            c, ct, p = attrs['cost'], attrs['ctvalue'], attrs['path']
            sd._nodes3.add_edge(select1, select2, cost=c, ctvalue=ct, path=p)
            l = stentgraph._edge_length(sd._nodes3, select1, select2)
            _show_edge_info(c, ct, l, p, 'g')
    if event.key == vv.KEY_DELETE:
        if len(selected_nodes) == 2:
            # remove edge
            edge = _selected_edge(sd._nodes3)
            if edge is not None:
                select1, select2 = edge
                attrs = sd._nodes3.edge[select1][select2]
                c, ct, p = attrs['cost'], attrs['ctvalue'], attrs['path']
                # the length must be measured before the edge is removed
                l = stentgraph._edge_length(sd._nodes3, select1, select2)
                sd._nodes3.remove_edge(select1, select2)
                _show_edge_info(c, ct, l, p, 'r')
        elif len(selected_nodes) == 1:
            # pop node
            select1 = selected_nodes[0].node
            stentgraph._pop_node(sd._nodes3, select1)  # asserts degree == 2
            selected_nodes[0].faceColor = 'w'
            selected_nodes.clear()
    if event.key == vv.KEY_ALT:
        # backup to restore with 'u'
        nodes3copy = sd._nodes3.copy()
        # clean nodes
        if stentType == 'anacondaRing':
            stentgraph.add_nodes_at_crossings(sd._nodes3)
        # pop before corner detect or angles can not be found
        stentgraph.pop_nodes(sd._nodes3)
        stentgraph.add_corner_nodes(sd._nodes3,
                                    th=sd._params.graph_angleVector,
                                    angTh=sd._params.graph_angleTh)
        # because removing edges/add nodes can create degree 2 nodes
        stentgraph.pop_nodes(sd._nodes3)
        stentgraph.prune_tails(sd._nodes3, sd._params.graph_trimLength)
        stentgraph.prune_clusters(sd._nodes3, 3)  # remove residual nodes/clusters
        # visualize result
        node_points = _redraw()
        print('----Press ESCAPE to FINISH model----')
    if event.text == 'u':
        # undo: restore the backup made by the last ALT/ESCAPE
        backup = globals().get('nodes3copy')
        if backup is None:
            print('Nothing to undo yet (press ALT or ESCAPE first)')
        else:
            # restore a copy so later edits cannot corrupt the backup itself
            sd._nodes3 = backup.copy()
            node_points = _redraw()
    if event.key == vv.KEY_ESCAPE:
        # backup to restore with 'u'
        nodes3copy = sd._nodes3.copy()
        # ESCAPE will FINISH model
        stentgraph.pop_nodes(sd._nodes3)
        sd._nodes3 = sd._RefinePositions(sd._nodes3)  # subpixel locations
        stentgraph.smooth_paths(sd._nodes3, 4)
        # create mesh and visualize
        node_points = _redraw(meshColor='g', lc='w')
        print(
            '----DO NOT FORGET TO SAVE THE MODEL TO DISK; RUN _SAVE_SEGMENTATION----'
        )
    if event.text == 'r':
        # restore the picked node from sd._nodes2
        pickedNode = _utils_GUI.snap_picked_point_to_graph(
            sd._nodes2, vol, label, nodesOnly=True)  # x,y,z
        sd._nodes3.add_node(pickedNode)
        node_points = _redraw()
    if event.text == 's':
        # additional smooth
        stentgraph.smooth_paths(sd._nodes3, 2)
        node_points = _redraw()
    if event.text == 'e':
        # smooth selected edge by smoothing it in a temporary one-edge graph
        edge = _selected_edge(sd._nodes3)
        if edge is not None:
            select1, select2 = edge
            edgegraph = stentgraph.StentGraph()  # empty graph
            edge_info = sd._nodes3.edge[select1][select2]
            edgegraph.add_edge(select1, select2, **edge_info)
            stentgraph.smooth_paths(edgegraph, 4)
            sd._nodes3.edge[select1][select2]['path'] = edgegraph.edge[
                select1][select2]['path']
            node_points = _redraw()
            # rebind the selection to the new node points and color them red
            for node_point in node_points:
                node_point.visible = True
                for i, node in enumerate(selected_nodes):
                    if node_point.node == node.node:
                        selected_nodes[i] = node_point
                        node_point.faceColor = (1, 0, 0)
    if event.text == 'w':
        # deselect all nodes
        for n in selected_nodes:
            n.faceColor = 'b'
        selected_nodes.clear()
    if event.text == 'q':
        view = a3.GetView()
        _utils_GUI.interactiveClusterRemoval(sd._nodes3)
        a3.SetView(view)