Beispiel #1
0
def model2mesh(basedir, savedir, ptcode, ctcode, cropname, modelname='modelavgreg',
               radius=0.4):
    """Load a stent model, convert it to a mesh and write it as STL.

    The model is loaded with loadmodel(basedir, ptcode, ctcode, cropname,
    modelname). It is written to savedir as
    '<ptcode>_<ctcode>_<cropname>_<modelname>.stl'. Before writing, the z
    coordinate of the vertices and normals is negated.

    radius -- strut thickness passed to create_mesh. The default 0.4 gives
        ~0.75 mm diameter, according to the original comment.
    """
    from stentseg.stentdirect.stentgraph import create_mesh

    # vv.meshWrite can also write .obj .ssdf .bsdf; this function writes .stl
    filename = '%s_%s_%s_%s.stl' % (ptcode, ctcode, cropname, modelname)
    # Load the stent model and mesh it
    s = loadmodel(basedir, ptcode, ctcode, cropname, modelname)
    mesh = create_mesh(s.model, radius)
    # z is negative in the original dicom: flip it, and flip the normals along
    # with it so front and back faces change consistently
    mesh._vertices[:, -1] *= -1
    mesh._normals[:, -1] *= -1
    vv.meshWrite(os.path.join(savedir, filename), mesh)
Beispiel #2
0
def DrawModelAxes(vol,
                  graph=None,
                  ax=None,
                  axVis=False,
                  meshColor=None,
                  getLabel=False,
                  mc='b',
                  lc='g',
                  mw=7,
                  lw=0.6,
                  **kwargs):
    """Draw a volume, optionally with a model graph, in axes with preset styling.

    vol -- volume passed to show_ctvolume and pick3d
    graph -- None (volume only), a graph without edges (e.g. the picked seeds
        sd._nodes1), or a model graph (e.g. sd._nodes2 or sd._nodes3)
    ax -- axes to draw in (e.g. a1, a2 or a3); defaults to vv.gca()
    axVis -- whether the axis is visible
    meshColor -- None, or a faceColor (e.g. 'g') to also draw the graph as mesh
    getLabel -- for a graph with edges: return the pick3d label when true
    mc, lc, mw, lw -- marker colour, line colour, marker width and line width
        passed to graph.Draw
    kwargs -- passed on to show_ctvolume; when graph is None, removeStent
        defaults to False unless given here

    Returns the pick3d label when graph is None, when graph has no edges, or
    when getLabel is true; otherwise None.
    """
    #todo: *args voor vol in drawModelAxes of **kwargs[key] in functies hieronder
    if ax is None:
        ax = vv.gca()
    ax.MakeCurrent()
    ax.daspect = 1, 1, -1
    ax.axis.axisColor = 1, 1, 1
    ax.bgcolor = 0, 0, 0
    ax.axis.visible = axVis
    vv.xlabel('x (mm)')
    vv.ylabel('y (mm)')
    vv.zlabel('z (mm)')
    if graph is None:
        # volume only; setdefault avoids a duplicate-keyword TypeError when
        # the caller passes removeStent in kwargs
        kwargs.setdefault('removeStent', False)
        show_ctvolume(vol, graph, axis=ax, **kwargs)
        return pick3d(vv.gca(), vol)
    if hasattr(graph, 'number_of_edges') and graph.number_of_edges() == 0:
        # nodes only (get label from picked seeds sd._nodes1)
        show_ctvolume(vol, graph, axis=ax, **kwargs)
        label = pick3d(vv.gca(), vol)
        graph.Draw(mc=mc, lc=lc, mw=mw, lw=lw)
        return label
    if meshColor is not None:
        bm = create_mesh(graph, 0.5)  # argument is strut thickness
        m = vv.mesh(bm)
        m.faceColor = meshColor  # e.g. 'g'
    show_ctvolume(vol, graph, axis=ax, **kwargs)
    graph.Draw(mc=mc, lc=lc, mw=mw, lw=lw)
    label = pick3d(vv.gca(), vol)
    if getLabel:
        return label
    return None
Beispiel #3
0
# renal_left, renal_right = foo.readRenalsExcel(sheet_renals_obs, ptcode, ctcode1)
# renal1 = renal_left

## Load (dynamic) stent models, vessel, ct
# Each item is loaded for ctcode1; the second set (vol2, model2, vessel2, ...)
# is only loaded when ctcode2 is set (truthy).
# Load static CT image to add as reference
s = loadvol(basedir, ptcode, ctcode1, cropname, 'avgreg')
vol1 = s.vol
if ctcode2:
    s = loadvol(basedir, ptcode, ctcode2, cropname, 'avgreg')
    vol2 = s.vol

# load stent model and build its mesh (meshradius is the mesh thickness)
s2 = loadmodel(basedir, ptcode, ctcode1, cropname, modelname)
model1 = s2.model
modelmesh1 = create_mesh(model1, meshradius)
if ctcode2:
    s2 = loadmodel(basedir, ptcode, ctcode2, cropname, modelname)
    model2 = s2.model
    modelmesh2 = create_mesh(model2, meshradius)

# Load vessel mesh (output Mimics)
vessel1 = loadmesh(basedirstl, ptcode, vesselname1)  #inverts Z
if ctcode2:
    vessel2 = loadmesh(basedirstl, ptcode, vesselname2)  #inverts Z
# get pointset from STL; invertZ=False, presumably because loadmesh already
# inverted Z above
ppvessel1 = points_from_mesh(vessel1, invertZ=False)  # removes duplicates
if ctcode2:
    ppvessel2 = points_from_mesh(vessel2, invertZ=False)  # removes duplicates
## Create centerline: input start/end
Beispiel #4
0
# Instantiate stentdirect segmenter object for volume vol with parameters p
#sd = StentDirect_old(vol, p)
sd = StentDirect(vol, p)

# Perform the three steps of stentDirect (results in sd._nodes1/2/3)
sd.Step1()
sd.Step2()
# sd._nodes2 = stentgraph.StentGraph()
# sd._nodes2.Unpack(ssdf.load('/home/almar/tmp.ssdf'))
sd.Step3()

# Create a mesh object for visualization (argument is strut thickness);
# older graph objects have a CreateMesh method, newer ones use create_mesh
if hasattr(sd._nodes3, 'CreateMesh'):
    bm = sd._nodes3.CreateMesh(0.6)  # old
else:
    bm = create_mesh(sd._nodes3, 0.6) # new


# Create figure
vv.figure(2); vv.clf()

# Show volume and segmented stent as a graph
a1 = vv.subplot(131)
t = vv.volshow(vol)
t.clim = 0, 3000
#sd._nodes1.Draw(mc='g', mw = 6)    # draw seeded nodes
#sd._nodes2.Draw(mc='g')            # draw seeded and MCP connected nodes

# Show cleaned up graph (result of Step3)
a2 = vv.subplot(132)
sd._nodes3.Draw(mc='g', lc='b')
Beispiel #5
0
                        r'F:\LSPEAS_ssdf_backup', r'G:\LSPEAS_ssdf_backup')
    
    # Select dataset to load (patient, scan phase, crop and model name)
    ptcode = 'LSPEAS_021'
    ctcode = 'discharge'
    cropname = 'ring'
    modelname = 'modelavgreg'
    
    # Load static CT image to add as reference
    s = loadvol(basedir, ptcode, ctcode, cropname, 'avgreg')
    vol = s.vol
    
    # Load the stent model and mesh
    s2 = loadmodel(basedir, ptcode, ctcode, cropname, modelname)
    model = s2.model
    modelmesh = create_mesh(model, 0.6)  # Param is thickness
    
    # Visualization settings
    showAxis = True  # True or False
    showVol  = 'MIP'  # MIP or ISO or 2D or None
    ringpart = True # True; False
    nstruts = 8
    clim0  = (0,3000)
    # clim0 = -550,500
    clim2 = (0,4)
    radius = 0.07
    dimensions = 'xyz'
    isoTh = 250

    ## Visualize with GUI (figure 3, placed at a fixed screen position)
    f = vv.figure(3); vv.clf()
    f.position = 968.00, 30.00,  944.00, 1002.00
def _show_pulsatility_labels(header, kind, stats):
    """Fill the text labels t1-t6 with one pulsatility result and show them.

    header -- text for t1 (identifies the analysed nodes / node pairs)
    kind -- analysis name used as prefix, e.g. 'Node-to-node'
    stats -- return value of point_to_point_pulsatility, read as:
        stats[0][0] min, stats[4][0] max, stats[2] median,
        stats[1] and stats[3] Q1 and Q3, stats[5][0] pulsatility
    Uses the module-level labels t1-t6.
    """
    t1.text = header
    t2.text = kind + ' Min: %1.2f mm' % stats[0][0]
    t3.text = kind + ' Max: %1.2f mm' % stats[4][0]
    t4.text = kind + ' Median: %1.2f mm' % stats[2]
    t5.text = kind + ' Q1 and Q3: %1.2f | %1.2f mm' % (stats[1], stats[3])
    t6.text = '\b{' + kind + ' Pulsatility: %1.2f mm}' % stats[5][0]
    t1.visible, t2.visible, t3.visible = True, True, True
    t4.visible, t5.visible, t6.visible = True, True, True


def on_key(event):
    """Keyboard callback for the interactive pulsatility analysis.

    Keys handled:
    - DOWN / UP: hide / show the labels t1-t6 and the node points.
    - 'n': snap the picked point (label) onto the model graph. The point is
      added as a node with its deforms, and the clickable node points are
      rebuilt.
    - ENTER: analyse the selected nodes and update the labels:
        - 2 nodes: node-to-node
        - 3 nodes: midpoint-to-node
        - 4 nodes: midpoint-to-midpoint
      The result is appended to storeOutput, then the nodes are coloured
      green and the selection is cleared.
    - ESCAPE: show the volume and the model mesh, then write storeOutput to
      Excel in exceldir.

    Uses module-level state: model, vol, label, a, t0-t6, clim, node_points,
    selected_nodes, storeOutput, exceldir.
    """
    global node_points
    if event.key == vv.KEY_DOWN:
        # hide nodes and labels
        t1.visible, t2.visible, t3.visible = False, False, False
        t4.visible, t5.visible, t6.visible = False, False, False
        for node_point in node_points:
            node_point.visible = False
    if event.key == vv.KEY_UP:
        # show nodes and labels
        t1.visible, t2.visible, t3.visible = True, True, True
        t4.visible, t5.visible, t6.visible = True, True, True
        for node_point in node_points:
            node_point.visible = True
    if event.text == 'n':
        # add clickable point: point on graph closest to picked point (SHIFT+R-click)
        view = a.GetView()
        for node_point in node_points:
            node_point.visible = False
        snapOut = _utils_GUI.snap_picked_point_to_graph(model, vol, label)  # x,y,z
        pickedOnGraph = snapOut[0]
        n1, n2 = snapOut[1]
        pickedOnGraphIndex = snapOut[2]
        pickedOnGraphDeforms = model.edge[n1][n2]['pathdeforms'][pickedOnGraphIndex]
        model.add_node(pickedOnGraph, deforms=pickedOnGraphDeforms)
        node_points = _utils_GUI.interactive_node_points(model, scale=0.7)
        _utils_GUI.node_points_callbacks(node_points, selected_nodes, t0=t0)
        # visualize the new point
        vv.plot(pickedOnGraph[0], pickedOnGraph[1], pickedOnGraph[2],
                mc='y', ms='o', mw=9, alpha=0.5)
        a.SetView(view)
    if event.key == vv.KEY_ENTER:
        # only 2, 3 or 4 selected nodes define an analysis
        assert len(selected_nodes) in (2, 3, 4)
        # Node_to_node analysis
        if len(selected_nodes) == 2:
            # get nodes
            selectn1 = selected_nodes[0].node
            selectn2 = selected_nodes[1].node
            # get index of nodes which are in fixed order
            n1index = selected_nodes[0].nr
            n2index = selected_nodes[1].nr
            # get deforms of nodes
            n1Deforms = model.node[selectn1]['deforms']
            n2Deforms = model.node[selectn2]['deforms']
            # get pulsatility
            output = point_to_point_pulsatility(selectn1, n1Deforms, selectn2, n2Deforms)
            # update labels
            _show_pulsatility_labels('\b{Node pair}: %i - %i' % (n1index, n2index),
                                     'Node-to-node', output)
            # Store output including index/nr of nodes
            output.insert(0, [n1index])  # at the start
            output.insert(1, [n2index])
            output[8].insert(0, [n1index])
            output[9].insert(0, [n2index])
            if output not in storeOutput:
                storeOutput.append(output)
        # Midpoint_to_node analysis
        if len(selected_nodes) == 3:
            # find the selected edge to get its midpoint: the first pair of
            # selected nodes (in selection order) connected in the model
            output = None
            remaining = selected_nodes.copy()
            for node1 in selected_nodes:
                remaining.remove(node1)  # check each combination once and not to self
                for node2 in remaining:
                    if model.has_edge(node1.node, node2.node):
                        # get midpoint of edge and its deforms
                        output = get_midpoint_deforms_edge(model, node1.node, node2.node)
                        break
                if output is not None:
                    break  # edge found, stop searching
            if output is None:
                raise RuntimeError('None of the 3 selected nodes are connected by an edge')
            # get index of nodepair and midpoint and its deforms
            nodepair1 = output[0]
            midpoint1IndexPath = output[1]
            midpoint1 = output[2]
            midpoint1Deforms = output[3]
            # get the selected node that is not part of the edge
            for i, node in enumerate(selected_nodes):
                if node.nr not in nodepair1:
                    n3 = node
                    break
            # get deforms for node
            n3Deforms = model.node[n3.node]['deforms']
            # get pulsatility; first selected first in output
            if i > 0:  # single node was not selected first
                output2 = point_to_point_pulsatility(midpoint1,
                            midpoint1Deforms, n3.node, n3Deforms)
            else:
                output2 = point_to_point_pulsatility(n3.node, n3Deforms,
                            midpoint1, midpoint1Deforms)
            # visualize midpoint
            view = a.GetView()
            vv.plot(midpoint1[0], midpoint1[1], midpoint1[2],
                    mc='m', ms='o', mw=8, alpha=0.5)
            a.SetView(view)
            # update labels
            _show_pulsatility_labels(
                '\b{Node pairs}: (%i %i) - (%i)' % (nodepair1[0], nodepair1[1], n3.nr),
                'Midpoint-to-node', output2)
            # Store output including index nodes
            if i > 0:
                output2.insert(0, nodepair1)  # at the start
                output2.insert(1, [n3.nr])
                output2[8].insert(0, midpoint1IndexPath)
                output2[9].insert(0, [n3.nr])
            else:
                output2.insert(0, [n3.nr])  # at the start
                output2.insert(1, nodepair1)
                output2[8].insert(0, [n3.nr])
                output2[9].insert(0, midpoint1IndexPath)
            if output2 not in storeOutput:
                storeOutput.append(output2)
        # Midpoint_to_midpoint analysis
        if len(selected_nodes) == 4:
            outputs = list()
            # get midpoints for the two edges; nodepairs from order selected
            for i in (0, 2):
                n1 = selected_nodes[i].node
                n2 = selected_nodes[i + 1].node
                assert model.has_edge(n1, n2)
                # get midpoint of edge and its deforms
                output = get_midpoint_deforms_edge(model, n1, n2)
                midpoint = output[2]
                # store for both edges
                outputs.append(output)
                # visualize midpoint
                view = a.GetView()
                vv.plot(midpoint[0], midpoint[1], midpoint[2],
                        mc='m', ms='o', mw=8, alpha=0.5)
                a.SetView(view)
            assert len(outputs) == 2  # two midpoints should be found
            # get midpoints and deforms
            nodepair1 = outputs[0][0]
            midpoint1IndexPath = outputs[0][1]
            midpoint1 = outputs[0][2]
            midpoint1Deforms = outputs[0][3]
            nodepair2 = outputs[1][0]
            midpoint2IndexPath = outputs[1][1]
            midpoint2 = outputs[1][2]
            midpoint2Deforms = outputs[1][3]
            # get pulsatility midp to midp
            output2 = point_to_point_pulsatility(midpoint1,
                                midpoint1Deforms, midpoint2, midpoint2Deforms)
            # # get max pulsatility between points on the paths
            # outputmaxP.append(edge_to_edge_max_pulsatility(model, nodepair1, nodepair2))
            # update labels
            _show_pulsatility_labels(
                '\b{Node pairs}: (%i %i) - (%i %i)' % (nodepair1[0], nodepair1[1],
                                                       nodepair2[0], nodepair2[1]),
                'Midpoint-to-midpoint', output2)
            # Store output including nodepairs of the midpoints
            output2.insert(0, nodepair1)  # indices at the start
            output2.insert(1, nodepair2)
            output2[8].insert(0, midpoint1IndexPath)
            output2[9].insert(0, midpoint2IndexPath)
            if output2 not in storeOutput:
                storeOutput.append(output2)
        # Visualize analyzed nodes and deselect
        for node in selected_nodes:
            node.faceColor = (0, 1, 0, 0.8)  # make green when analyzed
        selected_nodes.clear()
    if event.key == vv.KEY_ESCAPE:
        # FINISH, STORE TO EXCEL
        # visualize
        view = a.GetView()
        vv.volshow(vol, clim=clim, renderStyle='mip')
        # show mesh of model without deform coloring
        modelmesh = create_mesh(model, 0.4)  # Param is thickness
        m = vv.mesh(modelmesh)
        m.faceColor = (0, 1, 0, 1)  # green
        a.SetView(view)
        # Store to EXCEL
        storeOutputToExcel(storeOutput, exceldir)
        for node_point in node_points:
            node_point.visible = False  # show that store is ready
Beispiel #7
0
def interactiveClusterRemoval(graph,
                              radius=0.7,
                              axVis=False,
                              faceColor=(0.5, 1.0, 0.3),
                              selectColor=(1.0, 0.3, 0.3)):
    """ interactiveClusterRemoval(graph, radius=0.7, axVis=False,
                faceColor=(0.5,1.0,0.3), selectColor=(1.0,0.3,0.3))

    Manually delete clusters in the graph. The graph is shown as a set of
    meshes, one per connected cluster. Single-node clusters are not shown,
    because create_mesh cannot convert them. Hovering over a mesh colours it
    with selectColor. Pressing DELETE then destroys that mesh and removes the
    nodes of its cluster from graph (in place). Use sd._nodes3 for graph when
    in segmentation.

    Returns the axes in which the meshes are drawn.
    """
    import visvis as vv
    import networkx as nx
    from stentseg.stentdirect.stentgraph import create_mesh

    # Get clusters of nodes
    clusters = list(nx.connected_components(graph))

    # Build meshes. meshClusters[i] is the cluster drawn by meshes[i]. Because
    # single nodes are skipped, a mesh index is NOT an index into clusters.
    meshes = []
    meshClusters = []
    for cluster in clusters:
        # skip single nodes as these cannot be converted with create_mesh
        if len(cluster) == 1:
            continue
        # copy of the graph with all other clusters removed
        g = graph.copy()
        for c in clusters:
            if not c == cluster:
                g.remove_nodes_from(c)

        # Convert to mesh (this takes a while) and store
        meshes.append(create_mesh(g, radius=radius))
        meshClusters.append(cluster)

    # Define callback functions
    def meshEnterEvent(event):
        event.owner.faceColor = selectColor

    def meshLeaveEvent(event):
        event.owner.faceColor = faceColor

    def figureKeyEvent(event):
        if event.key == vv.KEY_DELETE:
            m = event.owner.underMouse
            # only act on the cluster meshes created here (they carry an index)
            if hasattr(m, 'faceColor') and hasattr(m, 'index'):
                m.Destroy()
                graph.remove_nodes_from(meshClusters[m.index])

    # Visualize
    a = vv.gca()
    fig = a.GetFigure()
    for i, bm in enumerate(meshes):
        m = vv.mesh(bm)
        m.faceColor = faceColor
        m.eventEnter.Bind(meshEnterEvent)
        m.eventLeave.Bind(meshLeaveEvent)
        m.hitTest = True
        m.index = i  # index into meshClusters
    # Bind event handlers to figure
    fig.eventKeyDown.Bind(figureKeyEvent)
    a.SetLimits()
    a.bgcolor = 'k'
    a.axis.axisColor = 'w'
    a.axis.visible = axVis
    a.daspect = 1, 1, -1

    # Prevent the callback functions from going out of scope
    a._callbacks = meshEnterEvent, meshLeaveEvent, figureKeyEvent

    # Done return axes
    return a
Beispiel #8
0
    # and reliably calculate volume.
    filename = '{}_{}_neck.stl'.format(ptcode, ctcode1)
    vesselMesh = loadmesh(basedirMesh, ptcode[-3:], filename)  #inverts Z
    vv.processing.unwindFaces(vesselMesh)
    vesselMesh = meshlib.Mesh(vesselMesh._vertices)
    vesselMesh.ensure_closed()
    ppvessel = PointSet(vesselMesh.get_flat_vertices())  # Must be flat!

# Load ring model
# Only loaded when modelmesh1 does not exist yet (NameError). Presumably this
# avoids reloading when the cell is re-run interactively. The mesh is only
# built when drawRingMesh is set: a plain mesh, or, when ringMeshDisplacement
# is set, a mesh colour-coded by absolute displacement.
try:
    modelmesh1
except NameError:
    s1 = loadmodel(basedir, ptcode, ctcode1, cropname, modelname)
    if drawRingMesh:
        if not ringMeshDisplacement:
            modelmesh1 = create_mesh(s1.model, 0.7)  # Param is thickness
        else:
            modelmesh1 = create_mesh_with_abs_displacement(s1.model,
                                                           radius=0.7,
                                                           dim=dimensions)

# Load vessel centerline (excel terarecon) (is very fast)
# The values returned by load_excel_centerline are stacked as columns
# (np.column_stack) into one array for the PointSet.
centerline = PointSet(
    np.column_stack(
        load_excel_centerline(basedirCenterline,
                              vol1,
                              ptcode,
                              ctcode1,
                              filename=None)))

## Setup visualization
Beispiel #9
0
    def __init__(self,
                 ptcode,
                 ctcode,
                 basedir,
                 showVol='mip',
                 meshWithColors=False,
                 motion='amplitude',
                 clim2=(0, 2)):
        """
        Script to show the dynamic stent model. [ nellix]

        Loads the following for the 'prox' crop of ptcode/ctcode from basedir:
        - the deformation fields ('deforms');
        - the stent and vessel centerline models
          ('centerline_total_modelavgreg_deforms' and
          'centerline_total_modelvesselavgreg_deforms');
        - the 'avgreg' volume.
        The models are merged into one graph and meshed. The mesh is animated
        with the forward deformations over the volume in figure 1, and an ECG
        slider is added.

        showVol -- renderStyle passed to vv.volshow (e.g. 'mip')
        meshWithColors -- if True, colour the mesh by motion (clim2, jet
            colormap, colorbar). Missing pathdeforms are computed first.
        motion -- motion measure for create_mesh_with_abs_displacement
            ('amplitude' or 'sum')
        clim2 -- colour limits of the motion colouring
        """
        import pirt
        import visvis as vv

        from stentseg.utils.datahandling import loadvol, loadmodel
        from pirt.utils.deformvis import DeformableMesh
        from stentseg.stentdirect.stentgraph import create_mesh
        from stentseg.stentdirect import stentgraph
        from stentseg.motion.vis import create_mesh_with_abs_displacement
        from stentseg.motion.dynamic import incorporate_motion_nodes, incorporate_motion_edges
        from lspeas.utils.ecgslider import runEcgSlider

        cropname = 'prox'
        # params
        nr = 1  # figure number
        dimension = 'xyz'
        clim0 = (0, 2000)
        motionPlay = 9, 1  # each x ms, a step of x %

        s = loadvol(basedir, ptcode, ctcode, cropname, what='deforms')
        m = loadmodel(basedir, ptcode, ctcode, cropname,
                      'centerline_total_modelavgreg_deforms')
        v = loadmodel(basedir,
                      ptcode,
                      ctcode,
                      cropname,
                      modelname='centerline_total_modelvesselavgreg_deforms')
        s2 = loadvol(basedir, ptcode, ctcode, cropname, what='avgreg')
        # NOTE(review): the first sampling component is replaced by the second
        # (presumably z by y, in z,y,x order); kept as before — confirm intent.
        # The sampling is read directly instead of deep-copying the volume.
        orgSampling = tuple(s2.vol.sampling)
        s2.vol.sampling = [orgSampling[1], orgSampling[1], orgSampling[2]]
        s2.sampling = s2.vol.sampling
        vol = s2.vol

        # merge stent (m) and vessel (v) models into one graph for dynamic
        # visualization (node and edge attributes included)
        model_total = stentgraph.StentGraph()
        for struct in (m, v):
            for key in dir(struct):
                if key.startswith('model'):
                    model_total.add_nodes_from(struct[key].nodes(data=True))
                    model_total.add_edges_from(struct[key].edges(data=True))

        # Load deformations (forward, for mesh); every 2nd sample per axis
        deformkeys = [key for key in dir(s) if key.startswith('deform')]
        deforms = [[field[::2, ::2, ::2] for field in s[key]]
                   for key in deformkeys]

        # These deforms are forward mapping; turn them into DeformationFields.
        # Forward mapping deforms are the ones to deform meshes with (they
        # displace vertices). Backward (inverse) deforms would only be needed
        # for deforming textures, which are not shown here, so they are not
        # computed.
        deforms_f = [pirt.DeformationFieldForward(*f) for f in deforms]

        # Create mesh
        if meshWithColors:
            try:
                modelmesh = create_mesh_with_abs_displacement(model_total,
                                                              radius=0.7,
                                                              dim=dimension,
                                                              motion=motion)
            except KeyError:
                print('Centerline model has no pathdeforms so we create them')
                # use unsampled deforms, as backward deforms for the model
                deformsB = [pirt.DeformationFieldBackward(*s[key])
                            for key in deformkeys]
                # Combine ...
                incorporate_motion_nodes(model_total, deformsB, s2.origin)
                convert_paths_to_PointSet(model_total)
                incorporate_motion_edges(model_total, deformsB, s2.origin)
                convert_paths_to_ndarray(model_total)
                modelmesh = create_mesh_with_abs_displacement(model_total,
                                                              radius=0.7,
                                                              dim=dimension,
                                                              motion=motion)
        else:
            modelmesh = create_mesh(model_total, 0.7, fullPaths=True)

        ## Start vis
        f = vv.figure(nr)
        vv.clf()
        if nr == 1:
            f.position = 8.00, 30.00, 944.00, 1002.00
        else:
            f.position = 968.00, 30.00, 944.00, 1002.00
        a = vv.gca()
        a.axis.axisColor = 1, 1, 1
        a.axis.visible = False
        a.bgcolor = 0, 0, 0
        a.daspect = 1, 1, -1
        vv.volshow(vol, clim=clim0, renderStyle=showVol, axes=a)
        vv.xlabel('x (mm)')
        vv.ylabel('y (mm)')
        vv.zlabel('z (mm)')
        if meshWithColors:
            dim = '3D' if dimension == 'xyz' else dimension
            vv.title(
                'Model for chEVAS %s  (color-coded %s of movement in %s in mm)'
                % (ptcode[8:], motion, dim))
        else:
            vv.title('Model for chEVAS %s' % (ptcode[8:]))

        # Create deformable mesh
        dm = DeformableMesh(a, modelmesh)  # in x,y,z
        dm.SetDeforms(*[list(reversed(deform))
                        for deform in deforms_f])  # from z,y,x to x,y,z
        if meshWithColors:
            dm.clim = clim2
            dm.colormap = vv.CM_JET  #todo: use colormap Viridis or Magma as JET is not linear (https://bids.github.io/colormap/)
            vv.colorbar()
        else:
            dm.faceColor = 'g'

        # Run mesh
        a.SetLimits()
        dm.MotionPlay(motionPlay[0],
                      motionPlay[1])  # (10, 0.2) = each 10 ms do a step of 20%
        dm.motionSplineType = 'B-spline'
        dm.motionAmplitude = 0.5

        ## run ecgslider
        runEcgSlider(dm, f, a, motionPlay)
Beispiel #10
0
def interactiveCenterlineID(s,
                            ptcode,
                            ctcode,
                            basedir,
                            cropname,
                            modelname,
                            radius=0.7,
                            axVis=True,
                            faceColor=(0.5, 1.0, 0.3),
                            selectColor=(1.0, 0.3, 0.3)):
    """ showGraphAsMesh(graph, radius=0.7, 
                faceColor=(0.5,1.0,0.3), selectColor=(1.0,0.3, 0.3) )
    
    Manual identidy centerlines; s contains models. 
    Show the given graphs as a mesh, or to be more precize as a set of meshes 
    representing the centerlines in the struct. 
    By holding the mouse over a mesh, it can be selected, after which it can be 
    identified by pressing enter.
    
    Returns the axes in which the meshes are drawn and s2 with new named models.
    
    """
    import visvis as vv
    import networkx as nx
    from stentseg.stentdirect import stentgraph
    from stentseg.stentdirect.stentgraph import create_mesh
    # from nellix._select_centerline_points import _Select_Centerline_Points
    import copy
    import numpy as np

    print("Move mouse over centerlines and press ENTER to identify")
    print(
        "Give either NelL, NelR, LRA, RRA, SMA for stents or vLRA, vRRA, vSMA for transition vessel-stent"
    )
    print("Press ESCAPE to save ssdf and finish")
    # Get clusters of nodes from each centerline
    clusters = []
    meshes = []
    s2 = copy.copy(s)
    for key in s:
        if key.startswith('model'):
            clusters.append(s[key])
            del s2[key]
            # Convert to mesh (this takes a while)
            bm = create_mesh(s[key], radius=radius, fullPaths=False)
            # Store
            meshes.append(bm)

    ppallcenterlines = []
    ppendsall = []
    for key2 in s:
        if key2.startswith('ppC'):  # 'ppCenterline1'  2 ..
            ppallcenterlines.append(s[key2])
            ppendsall.append(np.array(
                (s[key2][0, :], s[key2][-1, :])))  # first and last point cll
            # now remove from ssdf
            del s2[key2]
    ppallcenterlines = np.asarray(ppallcenterlines)

    centerlines = [None] * len(clusters)

    # Define callback functions
    def meshEnterEvent(event):
        event.owner.faceColor = selectColor

    def meshLeaveEvent(event):
        if event.owner.hitTest:  # True
            event.owner.faceColor = faceColor
        else:
            event.owner.faceColor = 'b'

    def figureKeyEvent(event):
        if event.key == vv.KEY_ENTER:
            m = event.owner.underMouse
            if hasattr(m, 'faceColor'):
                m.faceColor = 'y'
                dialog_output = get_index_name()
                name = dialog_output
                model = clusters[m.index]
                # get pp corresponding to model
                ends = []
                for n in sorted(model.nodes()):
                    if model.degree(n) == 1:
                        ends.append(n)
                if len(ends) > 2:
                    raise RuntimeError(
                        'Centerline has more than 2 nodes with 1 neighbour')
                ends = np.asarray(ends, dtype='float32')
                # add selected cll to s2
                if name in [
                        'NelL', 'NelR', 'LRA', 'RRA', 'SMA', 'vLRA', 'vRRA',
                        'vSMA'
                ]:
                    s2['model' + name] = model
                    for i, ppend in enumerate(
                            ppendsall):  # pp for each centerline
                        if ends[0] in ppend:  # should not matter which end
                            print('ppCenterline was added to s2')
                            s2['ppCenterline' + name] = ppallcenterlines[i]
                    m.hitTest = False
                else:
                    print(
                        "Name entered not known, give either NelL, NelR, LRA, RRA, SMA, 'vLRA', 'vRRA', 'vSMA'"
                    )
        if event.key == vv.KEY_ESCAPE:
            # Save ssdf
            filename = '%s_%s_%s_%s.ssdf' % (ptcode, ctcode, cropname,
                                             modelname + '_id')
            s3 = copy.deepcopy(s2)  # do not change s2
            for key in s3:
                if key.startswith('model'):
                    s3[key] = s3[key].pack()
            vv.ssdf.save(os.path.join(basedir, ptcode, filename), s3)
            print("Finished, ssdf {} saved to disk".format(filename))
            # fig.Destroy() # warning?

    # Visualize
    a = vv.gca()
    fig = a.GetFigure()
    for i, bm in enumerate(meshes):
        m = vv.mesh(bm)
        m.faceColor = faceColor
        m.eventEnter.Bind(meshEnterEvent)
        m.eventLeave.Bind(meshLeaveEvent)
        m.hitTest = True
        m.index = i
    # Bind event handlers to figure
    fig.eventKeyDown.Bind(figureKeyEvent)
    a.SetLimits()
    a.bgcolor = 'k'
    a.axis.axisColor = 'w'
    a.axis.visible = axVis
    a.daspect = 1, 1, -1

    # Prevent the callback functions from going out of scope
    a._callbacks = meshEnterEvent, meshLeaveEvent, figureKeyEvent

    # Done return axes and s2 with new named centerlines
    return a, s2
Beispiel #11
0
# Load the model(s): with one name in ringnames that model is used directly,
# with several names they are merged into one StentGraph
s = loadmodel(basedir, ptcode, ctcode, cropname, modelname)
if len(ringnames)==1: # show entire model or 1 ring
    model = s[ringnames[0]]
else:
    # merge ring models into one graph for dynamic visualization
    model = stentgraph.StentGraph()
    for key in ringnames:
        model.add_nodes_from(s[key].nodes(data=True)) # also attributes
        model.add_edges_from(s[key].edges(data=True))

# Build the mesh. meshWithColors selects the colour coding:
# - 'displacement': colour-coded by absolute displacement
# - 'curvature': colour-coded by the 'path_curvature_change' values
# - anything else: a plain mesh
if meshWithColors=='displacement':
    modelmesh = create_mesh_with_abs_displacement(model, radius = 0.7, dim = dimension, motion = motion)
elif meshWithColors=='curvature':
    modelmesh = create_mesh_with_values(model, valueskey='path_curvature_change', radius=0.7)
else:
    modelmesh = create_mesh(model, 1.0)  # Param is thickness

# Load static CT image to add as reference; fall back to the 'ring' crop
# when loading the 'stent' crop raises FileNotFoundError
try:
    s2 = loadvol(basedir, ptcode, ctcode, 'stent', staticref)
except FileNotFoundError:
    s2 = loadvol(basedir, ptcode, ctcode, 'ring', staticref)
vol = s2.vol


## Start vis
f = vv.figure(nr); vv.clf()
# figure 1 is placed at the left of the screen, other figures further right
if nr == 1:
    f.position = 8.00, 30.00,  1216.00, 960.00
else:
    f.position = 968.00, 30.00,  1216.00, 960.00