Example #1
0
class ZoomOverlay(AbstractOverlay):
    """Overlay linking an index selection on one plot to another plot's range.

    When ``source`` has a two-point index selection, draws a filled,
    outlined shape from the selected span on ``source`` down towards
    ``destination``. When the ``selection`` trait of ``source.controller``
    changes, ``destination``'s index range is set to that selection.

    Adapted from a Chaco example.
    """

    # Plot whose index selection drives the zoom.
    source = Instance(BaseXYPlot)
    # Component whose index range is set from the selection.
    destination = Instance(Component)

    # Appearance of the drawn shape.
    border_color = ColorTrait((0, 0, 0.7, 1))
    border_width = Int(1)
    fill_color = ColorTrait("lightblue")
    alpha = Float(0.3)

    traits_view = View(Group(Item('fill_color', label="Color", style="simple"),
                             Item('border_width',
                                  label="Border Width",
                                  style="custom"),
                             Item('border_color', label="Border Color"),
                             orientation="vertical"),
                       width=500,
                       height=300,
                       resizable=True,
                       title="Configure Settings",
                       buttons=['OK', 'Cancel'])

    #*************************************calculate_points()*************************************
    def calculate_points(self, component):
        """Compute the screen-space geometry of the overlay.

        Args:
            component: The component being overlaid. It is not used here and
                is kept for interface compatibility.

        Returns:
            A ``(left_line, right_line, polygon)`` tuple of point arrays: the
            two outline polylines and the fill polygon.

        Raises:
            TypeError: If ``source`` has no two-point selection, because
                ``_get_selection_screencoords`` then returns None and cannot
                be unpacked. Callers should check for a selection first, as
                ``overlay`` does.
        """
        # Selection range on the source plot, ordered left to right.
        x_start, x_end = self._get_selection_screencoords()
        if x_start > x_end:
            x_start, x_end = x_end, x_start

        # Vertical extent of the source plot.
        y_end = self.source.y
        y_start = self.source.y2

        left_top = array([x_start, y_end])
        left_mid = array([x_start, y_start])
        right_top = array([x_end, y_end])
        right_mid = array([x_end, y_start])

        # NOTE: the original example offset this by one pixel from the
        # destination plot (``self.destination.y + 1``) to avoid overlapping
        # its topmost pixels. This adaptation pins the lower edge of the shape
        # to a fixed screen y of 100 instead.
        y = 100

        # NOTE(review): the left end uses destination.x while the right end
        # uses source.x2. They only line up if both plots share the same
        # horizontal extent; confirm this asymmetry is intended.
        left_end = array([self.destination.x, y])
        right_end = array([self.source.x2, y])

        polygon = array(
            (left_top, left_mid, left_end, right_end, right_mid, right_top))
        left_line = array((left_top, left_mid, left_end))
        right_line = array((right_end, right_mid, right_top))

        return left_line, right_line, polygon

    #*************************************overlay()*************************************
    def overlay(self, component, gc, view_bounds=None, mode="normal"):
        """Draw the overlay into ``gc`` if ``source`` has a selection.

        Args:
            component: The component being overlaid. Drawing is translated to
                its ``position``.
            gc: Graphics context to draw into.
            view_bounds: Unused; part of the overlay drawing interface.
            mode: Unused; part of the overlay drawing interface.
        """
        # Nothing to draw without a two-point selection.
        if self._get_selection_screencoords() is None:
            return

        left_line, right_line, polygon = self.calculate_points(component)

        with gc:
            gc.translate_ctm(*component.position)
            gc.set_alpha(self.alpha)
            gc.set_fill_color(self.fill_color_)
            gc.set_line_width(self.border_width)
            gc.set_stroke_color(self.border_color_)

            # Filled shape.
            gc.begin_path()
            gc.lines(polygon)
            gc.fill_path()

            # Outline along the two slanted sides.
            gc.begin_path()
            gc.lines(left_line)
            gc.lines(right_line)
            gc.stroke_path()

    #*************************************_get_selection_screencoords()*************************************
    def _get_selection_screencoords(self):
        """Return the source index selection mapped to screen coordinates.

        Returns:
            The two selection endpoints mapped through ``source.index_mapper``.
            Returns None if the source index metadata has no "selections"
            entry, or if that entry is None or not exactly two elements long.
        """
        # .get: a missing "selections" key means "no selection", not an error.
        selection = self.source.index.metadata.get("selections")
        if selection is None or len(selection) != 2:
            return None
        return self.source.index_mapper.map_screen(array(selection))

    #*************************************_source_changed()*************************************
    def _source_changed(self, old, new):
        """Move the selection listener from the old source to the new one.

        Args:
            old: Previous ``source`` value (may be None).
            new: New ``source`` value (may be None).

        The listener is only attached or detached when the plot has a
        ``controller``.
        """
        if old is not None and old.controller is not None:
            old.controller.on_trait_change(self._selection_update_handler,
                                           "selection",
                                           remove=True)
        if new is not None and new.controller is not None:
            new.controller.on_trait_change(self._selection_update_handler,
                                           "selection")

    #*************************************_selection_update_handler()*************************************
    def _selection_update_handler(self, value):
        """Set destination's index range to the new selection and redraw.

        Args:
            value: The new selection, reduced with ``amin``/``amax`` to a
                low/high pair, or None when there is no selection.
        """
        if value is not None and self.destination is not None:
            r = self.destination.index_mapper.range
            start, end = amin(value), amax(value)
            r.low = start
            r.high = end

        # Either plot may be unassigned. The range update above already
        # tolerates a missing destination, so the redraws must as well.
        if self.source is not None:
            self.source.request_redraw()
        if self.destination is not None:
            self.destination.request_redraw()
Example #2
0
class TriangleWave(HasTraits):
    """Interactive demo: a triangle wave, its FFT, and a re-synthesized wave.

    The wave shape is set by ``wave_width``, ``length_c`` and ``height_c``.
    Whenever these or ``fftsize`` change, the wave is sampled and its
    spectrum (in dB) is computed. The ``N`` strongest-order frequency
    components are then summed (via ``fft_combine``) to plot a synthesized
    approximation over two periods.
    """

    # Narrowest and widest allowed wave width. A Range cannot mix constants
    # with trait names, so these fixed values are declared as traits.
    low = Float(0.02)
    hi = Float(1.0)

    # Width of the triangle wave.
    wave_width = Range("low", "hi", 0.5)

    # x coordinate of the wave's apex C.
    length_c = Range("low", "wave_width", 0.5)

    # y coordinate of the wave's apex.
    height_c = Float(1.0)

    # Number of samples used for the FFT. It is an Enum so the user can pick
    # from a list (64 .. 2048).
    fftsize = Enum([(2**x) for x in range(6, 12)])

    # Upper x-axis limit of the FFT spectrum plot.
    fft_graph_up_limit = Range(0, 400, 20)

    # Text listing of the FFT peaks.
    peak_list = Str

    # Number of frequency components used to synthesize the wave.
    N = Range(1, 40, 4)

    # Data source shared by both plots.
    plot_data = Instance(AbstractPlotData)

    # Waveform plot.
    plot_wave = Instance(Component)

    # FFT spectrum plot.
    plot_fft = Instance(Component)

    # Container holding both plots.
    container = Instance(Component)

    # UI view. The window size must be given so the plot container
    # initializes correctly.
    view = View(HSplit(
        VSplit(
            VGroup(Item("wave_width", editor=scrubber, label="波形宽度"),
                   Item("length_c", editor=scrubber, label="最高点x坐标"),
                   Item("height_c", editor=scrubber, label="最高点y坐标"),
                   Item("fft_graph_up_limit", editor=scrubber, label="频谱图范围"),
                   Item("fftsize", label="FFT点数"), Item("N", label="合成波频率数")),
            Item("peak_list",
                 style="custom",
                 show_label=False,
                 width=100,
                 height=250)),
        VGroup(Item("container",
                    editor=ComponentEditor(size=(600, 300)),
                    show_label=False),
               orientation="vertical")),
                resizable=True,
                width=800,
                height=600,
                title="三角波FFT演示")

    def _create_plot(self, data, name, type="line"):
        """Create a pannable, box-zoomable Plot of ``data`` titled ``name``.

        The waveform and spectrum plots share this setup, so it lives in one
        helper.
        """
        p = Plot(self.plot_data)
        p.plot(data, name=name, title=name, type=type)
        p.tools.append(PanTool(p))
        zoom = ZoomTool(component=p, tool_mode="box", always_on=False)
        p.overlays.append(zoom)
        p.title = name
        return p

    def __init__(self):
        """Build the plot data, both plots and their container, then compute."""
        super(TriangleWave, self).__init__()

        # No data yet: every entry is empty and only reserves a name for
        # Plot to reference.
        self.plot_data = ArrayPlotData(x=[], y=[], f=[], p=[], x2=[], y2=[])

        # Vertical container stacking the waveform and spectrum plots.
        self.container = VPlotContainer()

        # Waveform plot with two curves: original (x, y) and synthesized (x2, y2).
        self.plot_wave = self._create_plot(("x", "y"), "Triangle Wave")
        self.plot_wave.plot(("x2", "y2"), color="red")

        # Spectrum plot of f (frequency index) against p (amplitude in dB).
        self.plot_fft = self._create_plot(("f", "p"), "FFT", type="scatter")

        self.container.add(self.plot_wave)
        self.container.add(self.plot_fft)

        # Axis titles.
        self.plot_wave.x_axis.title = "Samples"
        self.plot_fft.x_axis.title = "Frequency pins"
        self.plot_fft.y_axis.title = "(dB)"

        # An Enum defaults to its first value (64). Switching to 1024 also
        # fires update_plot, which performs the first computation.
        self.fftsize = 1024

    def _fft_graph_up_limit_changed(self):
        """Copy ``fft_graph_up_limit`` into the spectrum plot's x-range high."""
        self.plot_fft.x_axis.mapper.range.high = self.fft_graph_up_limit

    def _N_changed(self):
        """Re-synthesize the wave when the number of components changes."""
        self.plot_sin_combine()

    # Several traits share one change handler, so @on_trait_change lists them.
    @on_trait_change("wave_width, length_c, height_c, fftsize")
    def update_plot(self):
        """Resample the wave, recompute its spectrum and refresh all outputs."""
        # NOTE(review): this also publishes y_data as a module-level global.
        # It is kept in case other code in the file reads it.
        global y_data
        # Sample one period of the wave.
        x_data = np.arange(0, 1.0, 1.0 / self.fftsize)
        func = self.triangle_func()
        # func returns an object array; convert it to float64.
        # (np.cast was removed in NumPy 2.0; astype is the equivalent.)
        y_data = func(x_data).astype(np.float64)

        # Normalized spectrum.
        fft_parameters = np.fft.fft(y_data) / len(y_data)

        # Amplitude of each non-negative frequency bin in dB, clipped to
        # [-120, 120]. Integer division keeps the slice bound an int on
        # Python 3.
        fft_data = np.clip(
            20 * np.log10(np.abs(fft_parameters))[:self.fftsize // 2 + 1], -120,
            120)

        # Publish the results to the plot data.
        self.plot_data.set_data("x", np.arange(0, self.fftsize))  # x = sample index
        self.plot_data.set_data("y", y_data)
        self.plot_data.set_data("f", np.arange(0, len(fft_data)))  # x = frequency index
        self.plot_data.set_data("p", fft_data)

        # The synthesized wave is shown over 2 periods, indexed by sample.
        self.plot_data.set_data("x2", np.arange(0, 2 * self.fftsize))

        # Re-apply the spectrum plot's x upper limit.
        self._fft_graph_up_limit_changed()

        # List (at most 20) frequencies whose amplitude exceeds -80 dB.
        peak_index = (fft_data > -80)
        peak_value = fft_data[peak_index][:20]
        self.peak_list = "\n".join(
            "%s : %s" % (f, v)
            for f, v in zip(np.flatnonzero(peak_index), peak_value))

        # Keep this FFT result and re-synthesize the wave from it.
        self.fft_parameters = fft_parameters
        self.plot_sin_combine()

    def plot_sin_combine(self):
        """Synthesize 2 periods from the first ``N`` components and plot them."""
        index, data = fft_combine(self.fft_parameters, self.N, 2)
        self.plot_data.set_data("y2", data)

    def triangle_func(self):
        """Return a ufunc evaluating the triangle wave for the current parameters.

        Per period: the wave rises linearly from 0 to ``height_c`` on
        ``[0, length_c)``, falls back to 0 on ``[length_c, wave_width)``, and
        is 0 on ``[wave_width, 1)``.
        """
        c = self.wave_width
        c0 = self.length_c
        hc = self.height_c

        def trifunc(x):
            x = x - int(x)  # the period is 1, so only the fractional part matters
            if x >= c: r = 0.0
            elif x < c0: r = x / c0 * hc
            else: r = (c - x) / (c - c0) * hc
            return r

        # frompyfunc makes trifunc apply element-wise to arrays. Its result
        # is an object array, so callers must convert the dtype.
        return np.frompyfunc(trifunc, 1, 1)
Example #3
0
class A(HasTraits):
    """Traits class with a float ``scale`` and a ``ScaledValue`` trait ``t1``."""
    # Scale factor, default 1.0.
    scale = Float(1.0)
    # NOTE(review): ScaledValue is defined elsewhere. Judging by the name it
    # presumably depends on ``scale``; confirm.
    t1 = ScaledValue
Example #4
0
class SphereSource(ProcessObject):
    """ProcessObject subclass exposing a float ``radius`` (default 5.0)."""
    radius = Float(5.0)
Example #5
0
File: custom.py Project: ghorn/Eg
class Bar(HasTraits):
	"""Traits class with a single float attribute ``foo`` (default 15)."""
	foo=Float(15)
Example #6
0
class Cluster(BaseGraph):
    """ Defines a representation of a subgraph in Graphviz's dot language.

    The cluster is written out as a dot ``subgraph`` (see ``__str__``), and
    its ID is kept prefixed with "cluster" (see ``_ID_changed``).
    """

    #--------------------------------------------------------------------------
    #  Trait definitions:
    #--------------------------------------------------------------------------

    # NOTE(review): the trait definitions below are commented out, yet
    # traits_view still refers to "ID". It is presumably provided by BaseGraph;
    # confirm.
#    ID = Str("cluster", desc="that clusters are encoded as subgraphs whose "
#        "names have the prefix 'cluster'")
#
#    name = Alias("ID", desc="synonym for ID")
#
#    # Nodes in the cluster.
#    nodes = List(Instance(Node))
#
#    # Edges in the cluster.
#    edges = List(Instance(Edge))
#
#    # Subgraphs of the cluster.
#    subgraphs = List(Instance("godot.subgraph.Subgraph"))
#
#    # Separate rectangular layout regions.
#    clusters = List(Instance("godot.cluster.Cluster"))

    # Parent graph in the graph hierarchy.
#    parent = Instance("godot.graph:Graph")

    # Root graph instance.
#    root = Instance("godot.graph:Graph")

    #--------------------------------------------------------------------------
    #  Xdot trait definitions:
    #--------------------------------------------------------------------------

    # For a given graph object, one will typically use a draw directive before
    # the label directive. For example, for a node, one would first use the
    # commands in _draw_ followed by the commands in _ldraw_.
    # NOTE(review): these traits are commented out here, but traits_view shows
    # "_draw_" and "_ldraw_". They are presumably inherited; confirm.
#    _draw_ = Str(desc="xdot drawing directive")
#
#    # Label draw directive.
#    _ldraw_ = Str(desc="xdot label drawing directive")

    #--------------------------------------------------------------------------
    #  Graphviz dot language trait definitions:
    #--------------------------------------------------------------------------

    # When attached to the root graph, this color is used as the background for
    # the entire canvas. When a cluster attribute, it is used as the initial
    # background for the cluster. If a cluster has a filled
    # <html:a rel="attr">style</html:a>, the
    # cluster's <html:a rel="attr">fillcolor</html:a> will overlay the
    # background color.
    #
    # If no background color is specified for the root graph, no graphics
    # operations are performed on the background. This works fine for
    # PostScript but for bitmap output, all bits are initialized to something.
    # This means that when the bitmap output is included in some other
    # document, all of the bits within the bitmap's bounding box will be
    # set, overwriting whatever color or graphics were already on the page.
    # If this effect is not desired, and you only want to set bits explicitly
    # assigned in drawing the graph, set <html:a rel="attr">bgcolor</html:a>="transparent".
    bgcolor = Color("white", desc="color used as the background for the "
        "entire canvas", label="Background Color", graphviz=True)

    # Basic drawing color for graphics, not text. For the latter, use the
    # <html:a rel="attr">fontcolor</html:a> attribute.
    #
    # For edges, the value can either be a single
    # <html:a rel="type">color</html:a> or a
    # <html:a rel="type">colorList</html:a>.
    # In the latter case, the edge is drawn using parallel splines or lines,
    # one for each color in the list, in the order given.
    # The head arrow, if any, is drawn using the first color in the list,
    # and the tail arrow, if any, the second color. This supports the common
    # case of drawing opposing edges, but using parallel splines instead of
    # separately routed multiedges.
    color = color_trait

    # This attribute specifies a color scheme namespace. If defined, it
    # specifies the context for interpreting color names. In particular, if a
    # <html:a rel="type">color</html:a> value has form <html:code>xxx</html:code> or <html:code>//xxx</html:code>,
    # then the color <html:code>xxx</html:code> will be evaluated according to the current color scheme.
    # If no color scheme is set, the standard X11 naming is used.
    # For example, if <html:code>colorscheme=bugn9</html:code>, then <html:code>color=7</html:code>
    # is interpreted as <html:code>/bugn9/7</html:code>.
    colorscheme = color_scheme_trait

    # Color used to fill the background of a node or cluster
    # assuming <html:a rel="attr">style</html:a>=filled.
    # If <html:a rel="attr">fillcolor</html:a> is not defined, <html:a rel="attr">color</html:a> is
    # used. (For clusters, if <html:a rel="attr">color</html:a> is not defined,
    # <html:a rel="attr">bgcolor</html:a> is used.) If this is not defined,
    # the default is used, except for
    # <html:a rel="attr">shape</html:a>=point or when the output
    # format is MIF,
    # which use black by default.
    #
    # Note that a cluster inherits the root graph's attributes if defined.
    # Thus, if the root graph has defined a <html:a rel="attr">fillcolor</html:a>, this will override a
    # <html:a rel="attr">color</html:a> or <html:a rel="attr">bgcolor</html:a> attribute set for the cluster.
    fillcolor = Color("grey", desc="fill color for background of a node",
        graphviz=True)

    # If true, the node size is specified by the values of the
    # <html:a rel="attr">width</html:a>
    # and <html:a rel="attr">height</html:a> attributes only
    # and is not expanded to contain the text label.
    fixedsize = Bool(False, desc="node size to be specified by 'width' and "
        "'height'", label="Fixed size", graphviz=True)

    # Color used for text.
    fontcolor = fontcolor_trait

    # Font used for text. This very much depends on the output format and, for
    # non-bitmap output such as PostScript or SVG, the availability of the font
    # when the graph is displayed or printed. As such, it is best to rely on
    # font faces that are generally available, such as Times-Roman, Helvetica or
    # Courier.
    #
    # If Graphviz was built using the
    # <html:a href="http://pdx.freedesktop.org/~fontconfig/fontconfig-user.html">fontconfig library</html:a>, the latter library
    # will be used to search for the font. However, if the <html:a rel="attr">fontname</html:a> string
    # contains a slash character "/", it is treated as a pathname for the font
    # file, though font lookup will append the usual font suffixes.
    #
    # If Graphviz does not use fontconfig, <html:a rel="attr">fontname</html:a> will be
    # considered the name of a Type 1 or True Type font file.
    # If you specify <html:code>fontname=schlbk</html:code>, the tool will look for a
    # file named  <html:code>schlbk.ttf</html:code> or <html:code>schlbk.pfa</html:code> or <html:code>schlbk.pfb</html:code>
    # in one of the directories specified by
    # the <html:a rel="attr">fontpath</html:a> attribute.
    # The lookup does support various aliases for the common fonts.
    fontname = fontname_trait

    # Font size, in <html:a rel="note">points</html:a>, used for text.
    fontsize = fontsize_trait

    # Spring constant used in virtual physical model. It roughly corresponds to
    # an ideal edge length (in inches), in that increasing K tends to increase
    # the distance between nodes. Note that the edge attribute len can be used
    # to override this value for adjacent nodes.
    K = Float(0.3, desc="spring constant used in virtual physical model",
        graphviz=True)

    # Text label attached to objects.
    # If a node's <html:a rel="attr">shape</html:a> is record, then the label can
    # have a <html:a href="http://www.graphviz.org/doc/info/shapes.html#record">special format</html:a>
    # which describes the record layout.
    label = label_trait

    # Justification for cluster labels. If <html:span class="val">r</html:span>, the label
    # is right-justified within bounding rectangle; if <html:span class="val">l</html:span>, left-justified;
    # else the label is centered.
    # Note that a subgraph inherits attributes from its parent. Thus, if
    # the root graph sets <html:a rel="attr">labeljust</html:a> to <html:span class="val">l</html:span>, the subgraph inherits
    # this value.
    # NOTE(review): the default "c" is one of the map's values, not one of its
    # keys ("Centre", "Right", "Left"); confirm Trait treats it as intended.
    labeljust = Trait("c", {"Centre": "c", "Right": "r", "Left": "l"},
        desc="justification for cluster labels", label="Label justification",
        graphviz=True)

    # Top/bottom placement of graph and cluster labels.
    # If the attribute is <html:span class="val">t</html:span>, place label at the top;
    # if the attribute is <html:span class="val">b</html:span>, place label at the bottom.
    # By default, root
    # graph labels go on the bottom and cluster labels go on the top.
    # Note that a subgraph inherits attributes from its parent. Thus, if
    # the root graph sets <html:a rel="attr">labelloc</html:a> to <html:span class="val">b</html:span>, the subgraph inherits
    # this value.
    # The "b" default given here is overridden by _labelloc_default below,
    # which returns "Top".
    labelloc = Trait("b", {"Bottom": "b", "Top":"t"},
        desc="placement of graph and cluster labels",
        label="Label location", graphviz=True)

    # Label position, in points. The position indicates the center of the
    # label.
    lp = Str  # point_trait was the alternative; a plain Str is used instead.

    # By default, the justification of multi-line labels is done within the
    # largest context that makes sense. Thus, in the label of a polygonal node,
    # a left-justified line will align with the left side of the node (shifted
    # by the prescribed margin). In record nodes, left-justified line will line
    # up with the left side of the enclosing column of fields. If nojustify is
    # "true", multi-line labels will be justified in the context of itself. For
    # example, if the attribute is set, the first label line is long, and the
    # second is shorter and left-justified, the second will align with the
    # left-most character in the first line, regardless of how large the node
    # might be.
    nojustify = nojustify_trait


    # Color used to draw the bounding box around a cluster. If
    # <html:a rel="attr">pencolor</html:a> is not defined,
    # <html:a rel="attr">color</html:a> is used. If this is not defined,
    # <html:a rel="attr">bgcolor</html:a> is used. If this is not defined, the
    # default is used.
    # Note that a cluster inherits the root graph's attributes if defined.
    # Thus, if the root graph has defined a
    # <html:a rel="attr">pencolor</html:a>, this will override a
    # <html:a rel="attr">color</html:a> or <html:a rel="attr">bgcolor</html:a>
    # attribute set for the cluster.
    pencolor = Color("grey", desc="color for the cluster bounding box",
        graphviz=True)

    # Set style for node or edge. For cluster subgraph, if "filled", the
    # cluster box's background is filled.
#    style = ListStr(desc="style for node")
    style = Str(desc="style for node", graphviz=True)

    # If the object has a URL, this attribute determines which window
    # of the browser is used for the URL.
    # See <html:a href="http://www.w3.org/TR/html401/present/frames.html#adef-target">W3C documentation</html:a>.
    target = target_trait

    # Tooltip annotation attached to the node or edge. If unset, Graphviz
    # will use the object's <html:a rel="attr">label</html:a> if defined.
    # Note that if the label is a record specification or an HTML-like
    # label, the resulting tooltip may be unhelpful. In this case, if
    # tooltips will be generated, the user should set a <html:tt>tooltip</html:tt>
    # attribute explicitly.
    tooltip = tooltip_trait

    # Hyperlinks incorporated into device-dependent output.
    # At present, used in ps2, cmap, i*map and svg formats.
    # For all these formats, URLs can be attached to nodes, edges and
    # clusters. URL attributes can also be attached to the root graph in ps2,
    # cmap and i*map formats. This serves as the base URL for relative URLs in the
    # former, and as the default image map file in the latter.
    #
    # For svg, cmapx and imap output, the active area for a node is its
    # visible image.
    # For example, an unfilled node with no drawn boundary will only be active on its label.
    # For other output, the active area is its bounding box.
    # The active area for a cluster is its bounding box.
    # For edges, the active areas are small circles where the edge contacts its head
    # and tail nodes. In addition, for svg, cmapx and imap, the active area
    # includes a thin polygon approximating the edge. The circles may
    # overlap the related node, and the edge URL dominates.
    # If the edge has a label, this will also be active.
    # Finally, if the edge has a head or tail label, this will also be active.
    #
    # Note that, for edges, the attributes <html:a rel="attr">headURL</html:a>,
    # <html:a rel="attr">tailURL</html:a>, <html:a rel="attr">labelURL</html:a> and
    # <html:a rel="attr">edgeURL</html:a> allow control of various parts of an
    # edge. Also note that, if active areas of two edges overlap, it is unspecified
    # which area dominates.
    URL = url_trait

    #--------------------------------------------------------------------------
    #  Views:
    #--------------------------------------------------------------------------

    # Tabbed editor: cluster membership, label, color and miscellaneous
    # settings. nodes_item and edges_item are defined elsewhere in the module.
    traits_view = View(
        Tabbed(
            VGroup(
                Item("ID"),
                Tabbed(nodes_item, edges_item, dock="tab"),
                label="Cluster"
            ),
            Group(["label", "fontname", "fontsize", "nojustify", "labeljust",
                   "labelloc", "lp"],
                label="Label"
            ),
            Group(["bgcolor", "color", "colorscheme", "fillcolor", "fontcolor",
                   "pencolor"],
                label="Color"),
            Group(["fixedsize", "K", "style", "tooltip", "target", "URL"],
                Group(["_draw_", "_ldraw_"], label="Xdot", show_border=True),
                label="Misc"
            )
        ),
        title="Cluster", id="godot.cluster", buttons=["OK", "Cancel", "Help"],
        resizable=True
    )

    #--------------------------------------------------------------------------
    #  "object" interface:
    #--------------------------------------------------------------------------

    def __str__(self):
        """ Returns a string representation of the cluster in dot language.

        The result is "subgraph " followed by the string form produced by
        BaseGraph.__str__.
        """
        s = "subgraph"

        return "%s %s" % ( s, super(Cluster, self).__str__() )

    #--------------------------------------------------------------------------
    #  Trait initialisers:
    #--------------------------------------------------------------------------

    def _labelloc_default(self):
        """ Trait initialiser.

        Defaults cluster labels to "Top" (cluster labels go on the top by
        default, as described above), overriding the "b" given in the trait
        definition.
        """
        return "Top"

    #--------------------------------------------------------------------------
    #  Event handlers :
    #--------------------------------------------------------------------------

    def _ID_changed(self, new):
        """ Handles the ID changing by ensuring that it has a 'cluster' prefix.

        The check is case-insensitive and looks only at the first seven
        characters. If the prefix is missing, "cluster_" is prepended.
        Reassigning ID re-enters this handler, but the prefixed value passes
        the check, so it stops after one step.
        """
        if new[:7].lower() != "cluster":
            self.ID = "cluster_%s" % new
Example #7
0
class USelectorOnMeshElement(Base):
    """Record with a short name, an integer id and three float values.

    NOTE(review): Base, String and ELEMENT_LENGTH are defined elsewhere. The
    class name suggests a selector located on a mesh element; confirm.
    """
    # Short name, sized by ELEMENT_LENGTH.
    shortName = String(ELEMENT_LENGTH)
    # Integer identifier.
    id = Int()
    # Three float values. Their meaning is not visible here (presumably
    # coordinates or parameters on the element; confirm).
    v1 = Float()
    v2 = Float()
    v3 = Float()
Example #8
0
class Person(HasTraits):
    """Traits class with a float ``weight`` attribute (default 150.0)."""
    weight = Float(150.0)
Example #9
0
class AnimationDemo(HasTraits):
    """Rotating cloud of ``numdots`` 3D points morphing between shapes.

    Points are stored in homogeneous coordinates (4 x numdots, last row 1).
    On each timer tick the cloud is rotated, projected into ``box`` and moved
    4% of the way toward a target shape, with a little Gaussian noise added.
    A new random target is chosen every 300 frames.
    """
    numdots = Int(3000)
    box = Instance(DotComponent)
    points3d = Array
    target_points3d = Array
    angle_x = Float()
    angle_y = Float()
    angle_z = Float()
    frame_count = Int(0)
    view = View(
        Item("box", editor=ComponentEditor(),show_label=False),
        resizable=True,
        width = 600, 
        height = 600,
        title = u"变形",
        handler = AnimationHandler()
    )    
    
    def __init__(self, **traits):
        """Create the drawing component and initial/target point sets."""
        super(AnimationDemo, self).__init__(**traits)
        self.box = DotComponent()
        # Homogeneous coordinates: rows 0-2 are x, y, z and row 3 is 1.
        self.points3d = np.zeros((4, self.numdots))
        self.points3d[3,:] = 1
        self.target_points3d = np.zeros((4, self.numdots))
        self.target_points3d[3,:] = 1
        
        self.points3d[:3,:] = self.get_object_points()
        self.target_points3d[:3,:] = self.get_object_points()
        self.frame_count = 0
        
    def get_object_points(self):
        """Return a random shape as an ``(x, y, z)`` tuple of length-numdots arrays.

        Each coordinate is the product of two factors picked at random from
        cos/sin of a random angle, cos/sin of a uniformly spaced angle, and 1.
        """
        r = np.random.rand(self.numdots)
        t = np.linspace(0, np.pi*2, self.numdots)
        a = r * np.pi - np.pi/2
        l = [np.cos(a), np.cos(t), np.sin(a), np.sin(t), np.ones(self.numdots)]
        c = random.choice
        return c(l)*c(l), c(l)*c(l), c(l)*c(l)
            
    def make_matrix(self):
        """Return the 4x4 homogeneous rotation ``Rz @ Ry @ Rx``.

        ``angle_x`` rotates about the x axis, ``angle_y`` about y and
        ``angle_z`` about z.
        """
        # Index 0/1/2 of s and c corresponds to angle_x/angle_y/angle_z.
        s = np.sin([self.angle_x, self.angle_y, self.angle_z])
        c = np.cos([self.angle_x, self.angle_y, self.angle_z])
        # Rotation matrix for each axis, each built from its own angle.
        # (mx and mz previously used each other's angles.)
        mx = np.eye(4)
        mx[[1,1,2,2],[1,2,1,2]] = c[0], -s[0], s[0], c[0]
        my = np.eye(4)
        my[[0,0,2,2], [0,2,0,2]] = c[1], s[1], -s[1], c[1]
        mz = np.eye(4)
        mz[[0,0,1,1], [0,1,0,1]] = c[2], -s[2], s[2], c[2]
        # Compose: rotate about x first, then y, then z.
        return np.dot(mz, np.dot(my, mx))
        
    def projection(self):
        """Rotate the current points and hand them to ``box`` for projection."""
        m = self.make_matrix()
        result = np.dot(m, self.points3d)
        self.box.projection_points(result)
        
    def on_timer(self):
        """Advance the animation by one frame."""
        # angle_x and angle_z always advance together here, so the rendered
        # animation does not depend on which of them drives which axis.
        self.angle_z += 0.02
        self.angle_x += 0.02
        self.projection()
        self.box.request_redraw()
        # Ease 4% toward the target shape, then jitter slightly.
        self.points3d[:3,:] += 0.04 * (self.target_points3d[:3,:] - self.points3d[:3,:])
        self.points3d[:3,:] += np.random.normal(0, 0.004, (3, self.numdots))
        self.frame_count += 1
        # Pick a new target shape every 300 frames.
        if self.frame_count > 300:
            self.frame_count = 0
            self.target_points3d[:3,:] = self.get_object_points()
Example #10
0
class Beam(HasTraits):
    """Class that defines the data model for a scan.

    Each ``MeasurementDetails_*`` / ``BeamDetails_*`` trait mirrors one
    element of the BDML XML document held in ``self.beam``. The two
    directions of copying are:

    * XML -> traits: ``importXML`` / ``initialize_traits``;
    * traits -> XML: ``exportXML``.

    Both directions iterate over the module-level ``TRAITS_TO_XML``
    (trait name, XML path) pairs. Float traits use NaN and string traits
    use '' to mean "not set" (see ``is_null``).
    """

    #The coordinate system used is the gantry fixed

    MeasurementDetails_MeasuringDevice_Model = String()
    MeasurementDetails_MeasuringDevice_Type = String()
    MeasurementDetails_MeasuringDevice_Manufacturer = String()
    MeasurementDetails_Isocenter_x = Float(numpy.NaN)
    MeasurementDetails_Isocenter_y = Float(numpy.NaN)
    MeasurementDetails_Isocenter_z = Float(numpy.NaN)
    #    MeasurementDetails_CoordinateAxes_Inplane = Enum('', 'x_neg', 'y_neg',
    #                                        'z_neg', 'x_pos', 'y_pos', 'z_pos')
    #    MeasurementDetails_CoordinateAxes_Crossplane = Enum('', 'x_neg', 'y_neg',
    #                                        'z_neg', 'x_pos', 'y_pos', 'z_pos')
    #    MeasurementDetails_CoordinateAxes_Depth = Enum('','x_neg', 'y_neg',
    #                                        'z_neg', 'x_pos', 'y_pos', 'z_pos')
    MeasurementDetails_MeasuredDateTime = String()
    MeasurementDetails_ModificationHistory = List()
    MeasurementDetails_StartPosition_x = Float(numpy.NaN)
    MeasurementDetails_StartPosition_y = Float(numpy.NaN)
    MeasurementDetails_StartPosition_z = Float(numpy.NaN)
    MeasurementDetails_StopPosition_x = Float(numpy.NaN)
    MeasurementDetails_StopPosition_y = Float(numpy.NaN)
    MeasurementDetails_StopPosition_z = Float(numpy.NaN)
    MeasurementDetails_Physicist_Name = String()
    MeasurementDetails_Physicist_Institution = String()
    MeasurementDetails_Physicist_Telephone = String()
    MeasurementDetails_Physicist_EmailAddress = String()
    MeasurementDetails_Medium = String()
    MeasurementDetails_Servo_Model = String()
    MeasurementDetails_Servo_Vendor = String()
    MeasurementDetails_Electrometer_Model = String()
    MeasurementDetails_Electrometer_Vendor = String()
    MeasurementDetails_Electrometer_Voltage = Float(numpy.NaN)

    BeamDetails_Energy = Float(numpy.NaN)
    BeamDetails_Particle = String()
    BeamDetails_SAD = Float(numpy.NaN)
    BeamDetails_SSD = Float(numpy.NaN)
    BeamDetails_CollimatorAngle = Float(numpy.NaN)
    BeamDetails_GantryAngle = Float(numpy.NaN)
    BeamDetails_CrossplaneJawPositions_NegativeJaw = Float(numpy.NaN)
    BeamDetails_CrossplaneJawPositions_PositiveJaw = Float(numpy.NaN)
    BeamDetails_InplaneJawPositions_NegativeJaw = Float(numpy.NaN)
    BeamDetails_InplaneJawPositions_PositiveJaw = Float(numpy.NaN)
    BeamDetails_Wedge_Angle = Float(numpy.NaN)
    BeamDetails_Wedge_Type = String()
    BeamDetails_Applicator = String()
    BeamDetails_Accessory = String()
    BeamDetails_RadiationDevice_Vendor = String()
    BeamDetails_RadiationDevice_Model = String()
    BeamDetails_RadiationDevice_SerialNumber = String()
    BeamDetails_RadiationDevice_MachineScale = Enum('', 'IEC 1217',
                                                    'Varian IEC')

    Data_Abscissa = Array()
    Data_Ordinate = Array()
    Data_Quantity = String()

    #Traits that are not part of the XML data structure
    label = String()
    field_size = String()
    scan_type = String()
    depth = Float()

    traits_view = View(
        Tabbed(
            Group(
                Group(Heading('Beam Parameters'),
                      Item(name='BeamDetails_Energy',
                           label='Energy (MV)',
                           format_func=format_func),
                      Item(name='BeamDetails_Particle', label='Particle'),
                      Item(name='BeamDetails_SAD',
                           label='SAD (cm)',
                           format_func=format_func),
                      Item(name='BeamDetails_SSD',
                           label='SSD (cm)',
                           format_func=format_func),
                      Item(name='BeamDetails_CollimatorAngle',
                           label='Collimator Angle',
                           format_func=format_func),
                      Item(name='BeamDetails_GantryAngle',
                           label='Gantry Angle',
                           format_func=format_func),
                      show_border=True),
                Group(
                    Heading('Jaw Positions'),
                    Item(name='BeamDetails_CrossplaneJawPositions_NegativeJaw',
                         label='Crossplane Negative Jaw (cm)',
                         format_func=format_func),
                    Item(name='BeamDetails_CrossplaneJawPositions_PositiveJaw',
                         label='Crossplane Positive Jaw (cm)',
                         format_func=format_func),
                    Item(name='BeamDetails_InplaneJawPositions_NegativeJaw',
                         label='Inplane Negative Jaw (cm)',
                         format_func=format_func),
                    Item(name='BeamDetails_InplaneJawPositions_PositiveJaw',
                         label='Inplane Positive Jaw (cm)',
                         format_func=format_func),
                    show_border=True),
                Group(Heading('Accessories'),
                      Item(name='BeamDetails_Wedge_Type', label='Wedge Type'),
                      Item(name='BeamDetails_Wedge_Angle',
                           label='Wedge Angle',
                           format_func=format_func),
                      Item(name='BeamDetails_Applicator', label='Applicator'),
                      Item(name='BeamDetails_Accessory', label='Accessory'),
                      show_border=True),
                Group(
                    Heading('Radiation Device'),
                    Item(name='BeamDetails_RadiationDevice_Vendor',
                         label='Vendor'),
                    Item(name='BeamDetails_RadiationDevice_Model',
                         label='Model'),
                    Item(name='BeamDetails_RadiationDevice_SerialNumber',
                         label='Serial Number'),
                    Item(name='BeamDetails_RadiationDevice_MachineScale',
                         label='Machine Scale')),
                label='Beam Details',
                orientation='horizontal'),
            Group(
                Group(Group(
                    Heading('Measuring Device'),
                    Item(name='MeasurementDetails_MeasuringDevice_Model',
                         label='Model'),
                    Item(name='MeasurementDetails_MeasuringDevice_Type',
                         label='Type'),
                    Item(
                        name='MeasurementDetails_MeasuringDevice_Manufacturer',
                        label='Manufacturer'),
                    show_border=True),
                      Group(Heading('Servo'),
                            Item(name='MeasurementDetails_Servo_Model',
                                 label='Model'),
                            Item(name='MeasurementDetails_Servo_Vendor',
                                 label='Vendor'),
                            show_border=True),
                      Group(Heading('Electrometer'),
                            Item(name='MeasurementDetails_Electrometer_Model',
                                 label='Model'),
                            Item(name='MeasurementDetails_Electrometer_Vendor',
                                 label='Vendor'),
                            Item(
                                name='MeasurementDetails_Electrometer_Voltage',
                                label='Voltage (V)'),
                            show_border=True),
                      Item(name='MeasurementDetails_Medium', label='Medium'),
                      orientation='vertical'),
                Group(  #Group(Heading('Coordinate Axes'),
                    #                              Item(name='MeasurementDetails_CoordinateAxes_Inplane',
                    #                                   label='Inplane'),
                    #                              Item(name='MeasurementDetails_CoordinateAxes_Crossplane',
                    #                                   label='Crossplane'),
                    #                              Item(name='MeasurementDetails_CoordinateAxes_Depth',
                    #                                   label='Depth'),
                    #                              show_border=True),
                    Group(Heading('Isocenter'),
                          Item(name='MeasurementDetails_Isocenter_x',
                               label='x',
                               format_func=format_func),
                          Item(name='MeasurementDetails_Isocenter_y',
                               label='y',
                               format_func=format_func),
                          Item(name='MeasurementDetails_Isocenter_z',
                               label='z',
                               format_func=format_func),
                          show_border=True),
                    Group(Heading('Start Position'),
                          Item(name='MeasurementDetails_StartPosition_x',
                               label='x',
                               format_func=format_func),
                          Item(name='MeasurementDetails_StartPosition_y',
                               label='y',
                               format_func=format_func),
                          Item(name='MeasurementDetails_StartPosition_z',
                               label='z',
                               format_func=format_func),
                          show_border=True),
                    Group(Heading('Stop Position'),
                          Item(name='MeasurementDetails_StopPosition_x',
                               label='x',
                               format_func=format_func),
                          Item(name='MeasurementDetails_StopPosition_y',
                               label='y',
                               format_func=format_func),
                          Item(name='MeasurementDetails_StopPosition_z',
                               label='z',
                               format_func=format_func),
                          show_border=True),
                    orientation='vertical'),
                Group(Group(
                    Heading('Physicist'),
                    Item(name='MeasurementDetails_Physicist_Name',
                         label='Name'),
                    Item(name='MeasurementDetails_Physicist_Institution',
                         label='Institution'),
                    Item(name='MeasurementDetails_Physicist_Telephone',
                         label='Telephone'),
                    Item(name='MeasurementDetails_Physicist_EmailAddress',
                         label='Email Address'),
                    show_border=True),
                      Item(name='MeasurementDetails_MeasuredDateTime',
                           label='Date Measured'),
                      Item(name='MeasurementDetails_ModificationHistory',
                           label='Modification History'),
                      orientation='vertical'),
                orientation='horizontal',
                label='Measurement Details')),
        buttons=['Undo', 'OK', 'Cancel'],
        resizable=True,
        kind='modal')

    def __init__(self):
        """Parse the BDML XML file into ``self.tree`` / ``self.beam``.

        The path is relative, so it is resolved against the current
        working directory.
        """
        super(Beam, self).__init__()
        xmltree = 'radpy/plugins/BeamAnalysis/BDML/bdml.xml'
        #xmltree = 'i:/radpy/src/radpy/plugins/BeamAnalysis/BDML/bdml.xml'
        # ``with`` guarantees the file is closed even if parsing raises.
        with open(xmltree, 'r') as xml_file:
            self.tree = objectify.parse(xml_file)
        self.beam = self.tree.getroot()

#        schema_file = open('radpy/plugins/BeamAnalysis/BDML/bdml.xsd','r')
#        bdml_schema = etree.parse(schema_file)
#        self.xmlschema = etree.XMLSchema(bdml_schema)
#        schema_file.close()

    def does_it_match(self, args):
        """Given a dict with beam parameters, returns True if it has those."""
        #Each type of beam object must be a subclass of Beam, and must
        #implement this method so that RadPy can determine if the object
        #can provide data that matches a certain set of beam parameters.
        #For example, a 1D crossplane profile would only return True if
        #the given depth matches the depth of measurement.  However, a 3D Dicom
        #dose dataset would match any depth as long as the other beam
        #parameters (energy, field size, etc.) match.
        #The dictionary keys are the names of traits of the beam object, and
        #the values are the values that trait must match.
        #Note that it is entirely up to the Beam object to determine if there
        #is a match.  RadPy will accept any data as long as the Beam object
        #claims to match.
        raise NotImplementedError

    @on_trait_change('BeamDetails_Energy', 'BeamDetails_RadiationDevice_Model',
                     'BeamDetails_RadiationDevice_SerialNumber',
                     'BeamDetails_CrossplaneJawPositions_PositiveJaw',
                     'BeamDetails_CrossplaneJawPositions_NegativeJaw',
                     'BeamDetails_InplaneJawPositions_PositiveJaw',
                     'BeamDetails_InplaneJawPositions_NegativeJaw',
                     'MeasurementDetails_StartPosition_x',
                     'MeasurementDetails_StopPosition_x',
                     'MeasurementDetails_StartPosition_y',
                     'MeasurementDetails_StopPosition_y',
                     'MeasurementDetails_StartPosition_z',
                     'MeasurementDetails_StopPosition_z')
    def set_label(self):
        """Recompute ``label`` from the tree path when a keyed trait changes."""
        self.label = self.get_tree_path()

    #If the isocenter depth coordinate is changed, recalculate SSD.
    @on_trait_change('MeasurementDetails_Isocenter_z')
    def recalc_SSD(self):
        """Set SSD to SAD plus the isocenter z coordinate."""
        self.BeamDetails_SSD = self.BeamDetails_SAD + self.MeasurementDetails_Isocenter_z

    def is_null(self, key):
        """Test to see if a given key has a null value (NAN or '').

        Only String and Float traits can be null; every other trait type
        reports False.
        """
        trait = self.trait(key)

        if trait.is_trait_type(String) and getattr(self, key) == '':
            return True
        elif trait.is_trait_type(Float) and numpy.isnan(getattr(self, key)):
            return True
        else:
            return False

    def get_field_size(self):
        """Return a string with field size information ('<inplane>x<crossplane>')."""

        inplane = self.get_collimator("inplane")
        crossplane = self.get_collimator("crossplane")
        return str(inplane) + 'x' + str(crossplane)

    def get_machine(self):
        """Return a string with the machine beam data is from"""

        return self.BeamDetails_RadiationDevice_Model + ' ' +\
                self.BeamDetails_RadiationDevice_SerialNumber

    def get_energy(self):
        """Return a string with the energy and particle type"""
        #Returns a string with the usual energy/particle specification,
        #e.g. 6X, 18E.

        energy = '%g' % self.BeamDetails_Energy
        if self.BeamDetails_Particle.lower() == 'photon':
            particle = 'X'
        elif self.BeamDetails_Particle.lower() == 'electron':
            particle = 'E'
        else:
            particle = ''
        return energy + particle

    def get_accessory(self):
        """Return 'Open' when there is no wedge, else '<type>_<angle>'."""

        if self.BeamDetails_Wedge_Type in ['', 'Open']:
            return 'Open'
        else:
            return self.BeamDetails_Wedge_Type + '_' + str(
                self.BeamDetails_Wedge_Angle)

    def get_tree_path(self):
        """Returns a string with parameters used to populate the GUI tree"""
        #Separator is |.  Uses machine name, energy, accessory, field size,
        #scan type and depth to tell the GUI tree view where on the tree
        #this beam belongs.
        return '|'.join([
            self.get_machine(),
            self.get_energy(),
            self.get_accessory(),
            self.get_field_size(),
            self.get_scan_type(),
            self.get_scan_depth()
        ])

    def get_scan_type(self):
        """Determine the type of scan by comparing start and end positions.

        Returns 'Dicom 3D Dose' when ``self.Data`` is an RTDose. Otherwise it
        returns 'Crossplane Profile', 'Inplane Profile' or 'Depth Dose' when
        exactly one of the x/y/z start-stop differences is non-zero, and
        'Point to Point' in every other case.
        """
        # Only some Beam subclasses define ``Data``. A missing attribute (or
        # an unavailable RTDose name) simply means "not a 3D dose". Before
        # this fix, a ``Data`` attribute that was not an RTDose made the
        # method return None.
        try:
            is_dicom_dose = isinstance(self.Data, RTDose)
        except (AttributeError, NameError):
            is_dicom_dose = False
        if is_dicom_dose:
            return 'Dicom 3D Dose'

        scan_range = [self.MeasurementDetails_StartPosition_x - \
                        self.MeasurementDetails_StopPosition_x,
                      self.MeasurementDetails_StartPosition_y - \
                        self.MeasurementDetails_StopPosition_y,
                      self.MeasurementDetails_StartPosition_z - \
                        self.MeasurementDetails_StopPosition_z]
        scan_types = [
            "Crossplane Profile", "Inplane Profile", "Depth Dose"
        ]
        if scan_range.count(0.0) != 2:
            return "Point to Point"
        # Exactly one axis moved; its index selects the scan type.
        moving_axis = [i for i, j in enumerate(scan_range) if j != 0][0]
        return scan_types[moving_axis]

    def get_scan_depth(self):
        """Return the profile depth (negated stop z) as a string, or '-'."""

        if self.get_scan_type() in ("Crossplane Profile", "Inplane Profile"):
            return str(-self.MeasurementDetails_StopPosition_z)
        return "-"

    def get_scan_descriptor(self):
        """Return a (scan type, depth string) tuple for the scan."""

        scan_type = self.get_scan_type()
        if scan_type == "Crossplane Profile":
            return ("Crossplane", str(-self.MeasurementDetails_StopPosition_z))
        elif scan_type == "Inplane Profile":
            return ("Inplane", str(-self.MeasurementDetails_StopPosition_z))
        else:
            return ("Depth_Dose", "")

    def get_collimator(self, direction="crossplane"):
        """Return the total jaw opening for *direction* as a '%g' string.

        The opening is the sum of the positive and negative jaw positions.
        Returns None for a direction other than 'crossplane' or 'inplane'.
        """

        if direction == "crossplane":
            return '%g' % (
                self.BeamDetails_CrossplaneJawPositions_PositiveJaw +
                self.BeamDetails_CrossplaneJawPositions_NegativeJaw)
        elif direction == "inplane":
            return '%g' % (self.BeamDetails_InplaneJawPositions_PositiveJaw +
                           self.BeamDetails_InplaneJawPositions_NegativeJaw)
        return None

    def get_equiv_square(self):
        """Return the equivalent square side 4xy / (2x + 2y) as a float.

        Raises ZeroDivisionError when both openings are zero.
        """
        # get_collimator returns '%g'-formatted strings, so convert them
        # before doing arithmetic (multiplying two strings raises TypeError).
        x = float(self.get_collimator("crossplane"))
        y = float(self.get_collimator("inplane"))
        return 4 * x * y / (2 * x + 2 * y)

    def importXML(self, xml_tree):
        """Load traits and data arrays from the BDML element *xml_tree*.

        Any path listed in TRAITS_TO_XML that is missing from the document
        is created as an empty element first, so that initialize_traits can
        read every path.
        """

        self.tree = etree.ElementTree(xml_tree)
        self.beam = self.tree.getroot()

        for trait, xml in TRAITS_TO_XML:
            try:
                # Probe the dotted path (same lookup as initialize_traits)
                # instead of building code for exec().
                reduce(getattr, xml.split('.'), self)
            except AttributeError:
                path = objectify.ObjectPath(xml.replace('beam.', '.'))
                path.setattr(self.beam, '')

        try:
            self.beam.MeasurementDetails.ModificationHistory
        except AttributeError:
            path = objectify.ObjectPath(
                '.MeasurementDetails.ModificationHistory.Record')
            path.setattr(self.beam, '')
        self.initialize_traits()

        abscissa = [float(i.text)
                    for i in self.beam.Data.Abscissa.iterchildren()]
        ordinate = [float(i.text)
                    for i in self.beam.Data.Ordinate.iterchildren()]
        mod_history = [
            str(i.text) for i in
            self.beam.MeasurementDetails.ModificationHistory.iterchildren()
        ]
        self.Data_Abscissa = numpy.array(abscissa)
        self.Data_Ordinate = numpy.array(ordinate)
        self.Data_Quantity = str(self.beam.Data.Quantity)
        self.MeasurementDetails_ModificationHistory = mod_history

    def exportXML(self):
        """Write the traits back into ``self.beam`` and return it."""

        for trait, xml in TRAITS_TO_XML:
            value = getattr(self, trait)
            # Skip unset values (NaN / '') and other empty values. Numeric
            # zeros, such as a 0 degree gantry angle, are still exported;
            # a plain truthiness test would silently drop them.
            if self.is_null(trait):
                continue
            if not value and not isinstance(value, (int, float)):
                continue
            ext_setattr(self, xml, value)

        self.beam.Data.Abscissa.clear()
        for i in self.Data_Abscissa:
            etree.SubElement(self.beam.Data.Abscissa,
                             "{http://www.radpy.org}Value")
            # Fill the Value element that was just appended.
            self.beam.Data.Abscissa.Value[-1] = i

        self.beam.Data.Ordinate.clear()
        for i in self.Data_Ordinate:
            etree.SubElement(self.beam.Data.Ordinate,
                             "{http://www.radpy.org}Value")
            self.beam.Data.Ordinate.Value[-1] = i

        self.beam.Data.Quantity = self.Data_Quantity

        self.beam.MeasurementDetails.ModificationHistory.clear()
        for i in self.MeasurementDetails_ModificationHistory:
            etree.SubElement(
                self.beam.MeasurementDetails.ModificationHistory,
                "{http://www.radpy.org}Record")
            # Fill the Record element just appended. Assigning to
            # ModificationHistory[-1] would overwrite the ModificationHistory
            # element itself and leave the Records empty.
            self.beam.MeasurementDetails.ModificationHistory.Record[-1] = i

        objectify.deannotate(self.tree)
        etree.cleanup_namespaces(self.tree)
        return self.beam

    def initialize_traits(self):
        """Copy every TRAITS_TO_XML element value into its trait.

        Afterwards, derive ``field_size``, ``scan_type`` and ``depth``
        (NaN when the scan has no depth).
        """

        for trait, xml in TRAITS_TO_XML:

            value = reduce(getattr, xml.split('.'), self)
            test = value.text
            # Zero is handled explicitly, presumably because an objectified
            # numeric 0 element is falsy in the check below.
            if test == "0.0" or test == "0":
                setattr(self, trait, float(value))
            elif value:
                is_float = self.trait(trait).is_trait_type(Float)
                if is_float:
                    setattr(self, trait, float(value))
                else:
                    setattr(self, trait, str(value))

        self.field_size = self.get_field_size()
        self.scan_type = self.get_scan_type()
        depth = self.get_scan_depth()
        if depth == '-':
            self.depth = numpy.NaN
        else:
            self.depth = float(depth)
Example #11
0
class FieldViewer(HasTraits):
    """Interactive viewer for a 3D scalar field (三维标量场观察器)."""

    # Ranges of the three axes
    x0, x1 = Float(-5), Float(5)
    y0, y1 = Float(-5), Float(5)
    z0, z1 = Float(-5), Float(5)
    points = Int(50)  # number of grid samples along each axis
    autocontour = Bool(True)  # let mayavi choose the isosurface values
    v0, v1 = Float(0.0), Float(1.0)  # bounds of `contour` (field min/max after plot)
    contour = Range("v0", "v1", 0.5)  # isosurface value used when autocontour is off
    function = Str("x*x*0.5 + y*y + z*z*2.0")  # scalar field expression in x, y, z
    function_list = [
        "x*x*0.5 + y*y + z*z*2.0", "x*y*0.5 + sin(2*x)*y +y*z*2.0", "x*y*z",
        "np.sin((x*x+y*y)/z)"
    ]
    plotbutton = Button("描画")
    scene = Instance(MlabSceneModel, ())  # mayavi scene

    view = View(
        HSplit(
            VGroup(
                "x0",
                "x1",
                "y0",
                "y1",
                "z0",
                "z1",
                Item('points', label="点数"),
                Item('autocontour', label="自动等值"),
                Item('plotbutton', show_label=False),
            ),
            VGroup(
                Item(
                    'scene',
                    editor=SceneEditor(
                        scene_class=MayaviScene),  # use mayavi's scene editor
                    resizable=True,
                    height=300,
                    width=350),
                Item('function',
                     editor=EnumEditor(name='function_list',
                                       evaluate=lambda x: x)),
                Item('contour',
                     editor=RangeEditor(format="%1.2f",
                                        low_name="v0",
                                        high_name="v1")),
                show_labels=False)),
        width=500,
        resizable=True,
        title="三维标量场观察器")

    def _plotbutton_fired(self):
        """Redraw the scene when the plot button is pressed."""
        self.plot()

    def _autocontour_changed(self):
        """Propagate the automatic-isosurface setting to the current plot."""
        if hasattr(self, "g"):
            self.g.contour.auto_contours = self.autocontour
            if not self.autocontour:
                self._contour_changed()

    def _contour_changed(self):
        """Show the isosurface at `contour` when automatic contours are off."""
        if hasattr(self, "g"):
            if not self.g.contour.auto_contours:
                self.g.contour.contours = [self.contour]

    def plot(self):
        """Draw the scene for the current function and axis ranges."""
        # Build the 3D sampling grid (``points`` samples along each axis).
        x, y, z = np.mgrid[self.x0:self.x1:1j * self.points,
                           self.y0:self.y1:1j * self.points,
                           self.z0:self.z1:1j * self.points]

        # Evaluate the scalar field expression.
        # NOTE(review): the EnumEditor accepts free text, and eval() runs
        # arbitrary code. Only use this viewer with trusted input.
        # numpy's public names are exposed so that unqualified calls such as
        # ``sin(2*x)`` from function_list resolve. Module globals are added
        # on top so that names the expression found before still win.
        namespace = {name: value for name, value in vars(np).items()
                     if not name.startswith('_')}
        namespace.update(globals())
        scalars = eval(self.function, namespace, {'x': x, 'y': y, 'z': z})
        mlab.clf()  # clear the current scene

        # Draw the isosurfaces
        g = mlab.contour3d(x, y, z, scalars, contours=8, transparent=True)
        g.contour.auto_contours = self.autocontour
        mlab.axes()  # add coordinate axes

        # Add an X-Y cut plane through the centre of the domain
        s = mlab.pipeline.scalar_cut_plane(g)
        cutpoint = (self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2, (
            self.z0 + self.z1) / 2
        s.implicit_plane.normal = (0, 0, 1)  # normal along z -> X-Y plane
        s.implicit_plane.origin = cutpoint

        self.g = g
        self.scalars = scalars
        # Range of the field values; these are also the bounds of `contour`.
        self.v0 = np.min(scalars)
        self.v1 = np.max(scalars)
        # With automatic contours off, show the user's isosurface on the new
        # plot as well, not just after the next slider change.
        self._contour_changed()
Example #12
0
class ClampConst(HasTraits):
    """Parameters of a yarn clamping setup and the resulting force profile.

    ``get_values`` samples the profile along the specimen and stores it in
    ``values``. ``strains`` evaluates ``MFnLineArray.get_diff`` on that
    profile.
    """

    # distance between the Einspannklemme and Kraftschlussklemme
    l1 = Float(0.137, modified=True, auto_set=False, enter_set=True)

    #tested length
    lt = Float(0.1, modified=True, auto_set=False, enter_set=True)

    # length Kraftschlussklemme
    Lk = Float(0.05, modified=True, auto_set=False, enter_set=True)

    # yarn module of elasticity
    E = Float(72e9, modified=True, auto_set=False, enter_set=True)

    # yarn cross-sectional area
    A = Float(0.89e-6, modified=True, auto_set=False, enter_set=True)

    # current tensile force
    Ft = Range(0., 1000., modified=True, auto_set=False, enter_set=True)

    # compression dependent shear flow per length - Kraftschlussklemme
    qk = Float(2200., modified=True, auto_set=False, enter_set=True)

    # pressure Kraftschlussklemme
    Pk = Float(6.0, modified=True, auto_set=False, enter_set=True)

    # additional force on the Kraftschlussklemme
    Fa = Range(10.0, 700.0, modified=True, auto_set=False, enter_set=True)

    # prestressing force
    P0 = Float(100., modified=True, auto_set=False, enter_set=True)

    # (xdata, ydata) computed by get_values
    values = Tuple(Array, Array)

    def strains(self):
        """Evaluate ``get_diff`` of the ``values`` curve at every abscissa.

        ``get_diff`` presumably gives the slope of the piecewise-linear
        curve (confirm). The values at the two end points are replaced by
        those of their inner neighbours.
        """
        mfn = MFnLineArray()
        mfn.xdata, mfn.ydata = self.values
        strains_fn = frompyfunc(mfn.get_diff, 1, 1)
        strains = strains_fn(mfn.xdata)
        # Copy each inner neighbour outward onto its end point, so the end
        # treatment is symmetric. (The last line used to read
        # ``strains[-2] = strains[-1]``, which overwrote an interior value.)
        strains[0] = strains[1]
        strains[-1] = strains[-2]
        return strains

    def get_values(self):
        """Compute the force profile, store it in ``values`` and return it.

        The profile is sampled at 200 points on ``[0, l1 + Lk + lt]``. It is
        the pointwise maximum of two piecewise-linear branches, each
        multiplied by ``H`` of itself. ``H`` is presumably the Heaviside step
        function, which would zero out negative parts; confirm.
        """
        l1 = self.l1
        lt = self.lt
        Ft = self.Ft
        Fa = self.Fa
        Pk = self.Pk
        qk = self.qk
        Lk = self.Lk
        P0 = self.P0

        # before the add-clamp moves axially
        xdata = linspace(0, l1 + Lk + lt, 200)

        def ydata1(x):
            y0 = min((l1 + Lk - x) * qk * Pk, P0, Lk * qk * Pk)
            y1 = min((Ft - Fa), (Ft - Fa) - qk * Pk * (x - l1))
            return max(y0, y1)

        py_y1 = frompyfunc(ydata1, 1, 1)
        y1 = py_y1(xdata)

        def ydata2(x):
            y2 = Ft
            y3 = (x - Lk - l1) * qk * Pk + Ft
            return min(y2, y3)

        py_y2 = frompyfunc(ydata2, 1, 1)
        y2 = py_y2(xdata)

        y1 = y1 * H(y1)
        y2 = y2 * H(y2)

        f = frompyfunc(lambda x, y: max(x, y), 2, 1)
        y = f(y1, y2)

        self.values = xdata, y

        return self.values

    traits_view = View(
        Item('l1', label='fixation to clamping distance [m]'),
        Item('lt', label='tested length [m]'),
        Item('Lk', label='clamp length [m]'),
        Item('E', label='yarn elasticity modulus [N/m2]'),
        Item('A', label='yarn cross-section [m2]'),
        Item('Ft', label='current tensile force [N]'),
        Item('qk', label='frictional shear force per pressure [m-1]'),
        Item('Pk', label='applied force to the clamp [N]'),
        Item('Fa', label='additional axial force [N]'),
        Item('P0', label='prestressing force [N]'),
    )
Example #13
0
class OverlayMap(HasTraits):
    """
    Use mayavi to plot three image cut planes through an fMRI volume
    and stat-map overlay.

    The underlay planes are interactive; moving one copies its placement
    onto the matching overlay plane.  Two sliders threshold the overlay's
    lookup-table range.
    """

    # Main scene
    scene = Instance(MlabSceneModel, ())

    # the image planes and lookup tables
    overlays = List(Instance(ImagePlaneWidget))
    underlays = List(Instance(ImagePlaneWidget))
    over_lut = Instance(HasTraits)
    under_lut = Instance(HasTraits)

    # lower range of the overlay lookup table
    _over_low_min = Float(0.0)
    _over_low_max = Float(0.1)
    over_low = Range(low='_over_low_min', high='_over_low_max',
                     value=0.0, mode='slider')

    # upper range of the overlay lookup table
    _over_hi_min = Float(0.0)
    _over_hi_max = Float(0.1)
    over_hi = Range(low='_over_hi_min', high='_over_hi_max',
                     value=0.01, mode='slider')

    # Whether to see x,y,z planes
    x_visible = Bool(True)
    y_visible = Bool(True)
    z_visible = Bool(True)

    # Which colormap to use
    colormap = Enum("hot",
                    "jet",
                    "autumn")

    def __init__(self, under_image, over_image):
        """
        Provide the underlay and overlay images and open the viewer.

        Parameters
        ----------
        under_image : NiftiImage or str
            Anatomical underlay, or a filename to load it from.
        over_image : NiftiImage, numpy.ndarray or str
            Statistical overlay, or a filename to load it from.  A
            NiftiImage's data is transposed; an ndarray is used as-is and
            is assumed to match the underlay's dimensions and extent.
            Non-finite voxels are masked.

        Raises
        ------
        ValueError
            If either argument has an unsupported type.

        The UI is opened with ``configure_traits()`` at the end.

        Example:

        stat = OverlayMap('anat.nii.gz','stat.nii.gz')
        """
        # we've got traits
        HasTraits.__init__(self)

        # load in the image
        if isinstance(under_image, NiftiImage):
            self.__under_image = under_image
        elif isinstance(under_image, str):
            self.__under_image = NiftiImage(under_image)
        else:
            raise ValueError("under_image must be a NiftiImage or a file.")

        # TODO: set the extent and spacing of the under image

        # set the over data
        if isinstance(over_image, str):
            over_image = NiftiImage(over_image)

        if isinstance(over_image, NiftiImage):
            # TODO: make sure it matches the dims of under image
            # TODO: set the extent
            # save just the data
            self.__over_image = over_image.data.T
        elif isinstance(over_image, np.ndarray):
            # assumes it matches the dims and extent of the under image
            self.__over_image = over_image
        else:
            raise ValueError("over_image must be a NiftiImage, ndarray, or file.")

        # From here on the overlay is always a masked array with
        # non-finite voxels masked.
        self.__over_image = np.ma.masked_invalid(self.__over_image)

        self.configure_traits()

    def _plane_callback1(self, widget, event):
        self._update_planes(0)

    def _plane_callback2(self, widget, event):
        self._update_planes(1)

    def _plane_callback3(self, widget, event):
        self._update_planes(2)

    def _update_planes(self, num):
        """Copy the placement of underlay plane *num* onto overlay plane *num*."""
        # TODO: it may make more sense to do this in the callback for
        # each individual plane when it is called instead of all at
        # once
        under_ipw = self.underlays[num].ipw
        over_ipw = self.overlays[num].ipw

        # from what I can tell, all these are necessary
        under_ipw.update_traits()
        over_ipw.origin = under_ipw.origin
        over_ipw.point1 = under_ipw.point1
        over_ipw.point2 = under_ipw.point2
        over_ipw.update_traits()
        over_ipw.update_placement()

    @on_trait_change('scene.activated')
    def _create_plot(self):
        """Build the image planes once the scene exists and fit the sliders.

        For each of the x, y and z axes this creates an interactive
        gray-scale underlay plane and a non-interactive overlay plane.  The
        first plane of each kind defines the lookup table, and later planes
        copy its table.
        """
        # shorten things a bit
        mlab = self.scene.mlab

        # generate the scalar fields; masked (non-finite) voxels become 0
        over_src = mlab.pipeline.scalar_field(self.__over_image.filled(0))
        under_src = mlab.pipeline.scalar_field(
            np.ma.masked_invalid(self.__under_image.data.T).filled(0))

        # interaction callback for the underlay plane of each orientation
        callbacks = {'x_axes': self._plane_callback1,
                     'y_axes': self._plane_callback2,
                     'z_axes': self._plane_callback3}

        # create the planes for the x,y,z axes
        self.underlays = []
        self.overlays = []
        for orient in ['x_axes', 'y_axes', 'z_axes']:
            # First the underlay.  Every plane is attached to the scalar
            # field itself, not to the widget made in the previous iteration.
            # TODO: fix the slice_index, which is a hack
            under = mlab.pipeline.image_plane_widget(under_src, colormap='gray',
                                                     slice_index=92,
                                                     plane_opacity=0,
                                                     plane_orientation=orient)
            # set up the lookup table
            under.ipw.user_controlled_lookup_table = True
            if self.under_lut is None:
                # first plane: remember its lookup table
                self.under_lut = under.module_manager.scalar_lut_manager.lut
            else:
                # later planes: copy the first plane's table
                under.module_manager.scalar_lut_manager.lut.table = self.under_lut.table

            # add the interaction event
            under.ipw.add_observer("InteractionEvent", callbacks[orient])
            self.underlays.append(under)

            # set up the overlay
            # TODO: fix the slice_index, which is a hack
            over = mlab.pipeline.image_plane_widget(over_src,
                                                    colormap=self.colormap,
                                                    slice_index=92,
                                                    plane_opacity=0,
                                                    plane_orientation=orient)
            # set the lookup table
            over.ipw.user_controlled_lookup_table = True
            if self.over_lut is None:
                # first one: ramp the alpha channel of the lowest 40 entries
                lut = over.module_manager.scalar_lut_manager.lut.table.to_array()
                lut[:40, -1] = np.linspace(0, 255, 40)
                over.module_manager.scalar_lut_manager.lut.table = lut
                self.over_lut = over.module_manager.scalar_lut_manager.lut
            else:
                over.module_manager.scalar_lut_manager.lut.table = self.over_lut.table

            # overlays follow their underlay; they are not dragged directly
            over.ipw.interaction = False
            self.overlays.append(over)

        # fit the threshold sliders to the overlay's (masked) data range
        over_min = self.__over_image.min()
        over_max = self.__over_image.max()
        over_mean = self.__over_image.mean()
        print("mmm: %s %s %s" % (over_min, over_max, over_mean))

        # Range bounds and values are converted alike, so numpy scalars
        # (e.g. float32) are never handed to the float Range traits.
        self._over_hi_min = float(over_min)
        self._over_hi_max = float(over_max)
        self.over_hi = float(over_max)

        self._over_low_min = float(over_min)
        self._over_low_max = float(over_max)
        self.over_low = float(over_mean)

    # The two handlers below are wired up by the Traits `_<name>_changed`
    # naming convention and keep over_low <= over_hi.
    def _over_hi_changed(self):
        if self.over_hi < self.over_low:
            # set low to be hi (its handler then updates the range)
            self.over_low = self.over_hi
        else:
            self._update_overlay_range()

    def _over_low_changed(self):
        if self.over_low > self.over_hi:
            # set hi to be low (its handler then updates the range)
            self.over_hi = self.over_low
        else:
            self._update_overlay_range()

    def _update_overlay_range(self):
        """Push [over_low, over_hi] into every overlay's lookup-table range."""
        if not self.overlays:
            # planes are only created once the scene has been activated
            return
        # XXX: a fresh array is assigned rather than mutating data_range in
        # place -- presumably needed for the change to be picked up; confirm.
        new_range = self.overlays[0].module_manager.scalar_lut_manager.data_range.copy()
        new_range[0] = self.over_low
        new_range[1] = self.over_hi
        for overlay in self.overlays:
            overlay.module_manager.scalar_lut_manager.data_range = new_range

    def _x_visible_changed(self):
        # toggle the proper plane on or off
        self._update_plane_visible(0, self.x_visible)

    def _y_visible_changed(self):
        # toggle the proper plane on or off
        self._update_plane_visible(1, self.y_visible)

    def _z_visible_changed(self):
        # toggle the proper plane on or off
        self._update_plane_visible(2, self.z_visible)

    def _update_plane_visible(self, plane_id, bool_val):
        """Show or hide both the underlay and overlay plane *plane_id*."""
        if plane_id >= len(self.underlays) or plane_id >= len(self.overlays):
            # planes not created yet (scene not activated)
            return
        self.underlays[plane_id].visible = bool_val
        self.overlays[plane_id].visible = bool_val

    def _colormap_changed(self):
        print(self.colormap)
        # TODO: apply the new colormap to each overlay in self.overlays;
        # it is not yet known how to do this.

    # define the view
    view = View(
        VSplit(
            Group(Item('scene', editor=SceneEditor(scene_class=MayaviScene),
                       height=500, width=500, show_label=False)),
            Group(
                Group(Item('over_low', label="Lower Thresh"),
                      Item('over_hi', label="Upper Thresh"),
                      label="Overlay Properties",
                      show_border=True),
            ),
            Group(
                HGroup(Item('x_visible'),
                       Item('y_visible'),
                       Item('z_visible'),
                       Item('colormap'),
                       label="Plane visibility + colormap",
                       show_border=True),
            ),
        ),
        resizable=True,
        title='Overlay Viewer')
Example #14
0
class Generator(HasTraits):
    """Generate synthetic localisation data for a visualisation frame.

    Point positions come from the selected ``source``.  Events are then
    simulated from them with ``locify.eventify`` and loaded into
    ``visFr.pipeline`` as a new data source.
    """
    meanIntensity = Float(500)
    meanDuration = Float(3)
    backgroundIntensity = Float(300)
    meanEventNumber = Float(2)
    scaleFactor = Float(2)
    meanTime = Float(2000)

    # NOTE(review): these source objects are created once, when the class is
    # defined; presumably they are shared by all Generator instances.
    sources = List([WormlikeSource(), ImageSource(), FileSource()])

    source = Instance(PointSource)

    traits_view = View( Item( 'source',
                            label= 'Point source',
                            editor =
                            InstanceEditor(name = 'sources',
                                editable = True),
                                ),
                        Item('_'),
                        Item('meanIntensity'),
                        Item('meanDuration'),
                        Item('meanEventNumber'),
                        Item('meanTime'),
                        Item('_'),
                        Item('backgroundIntensity'),

                        buttons = ['OK'])

    def __init__(self, visFr = None):
        """Create the generator and, if *visFr* is given, add its menu.

        Parameters
        ----------
        visFr : optional
            Visualisation frame.  When truthy, a "Synthetic Data" submenu
            with Configure / Generate entries is appended to its
            ``extras_menu`` and bound to this object's handlers.
        """
        # HasTraits subclasses that override __init__ must chain up.
        super(Generator, self).__init__()
        self.visFr = visFr
        self.source = self.sources[0]

        # positions from the last OnGenPoints; None until generated
        self.xp = None
        self.yp = None

        if visFr:
            ID_GEN_POINTS = wx.NewId()
            ID_CONF_SIMUL = wx.NewId()
            ID_GEN_EVENTS = wx.NewId()

            mSimul = wx.Menu()

            mSimul.Append(ID_CONF_SIMUL, "Configure")
            mSimul.Append(ID_GEN_POINTS, "Generate fluorophore positions and events")
            mSimul.Append(ID_GEN_EVENTS, "Generate events")

            visFr.extras_menu.AppendSubMenu(mSimul, 'Synthetic Data')

            visFr.Bind(wx.EVT_MENU, self.OnGenPoints, id=ID_GEN_POINTS)
            visFr.Bind(wx.EVT_MENU, self.OnGenEvents, id=ID_GEN_EVENTS)
            visFr.Bind(wx.EVT_MENU, self.OnConfigure, id=ID_CONF_SIMUL)

    def OnConfigure(self, event):
        """Refresh the source's choices and show the settings dialog."""
        self.source.refresh_choices()
        self.edit_traits()

    def OnGenPoints(self, event):
        """Draw new positions from the source, then generate events."""
        self.xp, self.yp = self.source.getPoints()
        self.OnGenEvents(None)

    def OnGenEvents(self, event):
        """Simulate events from the current positions and load them.

        If no positions exist yet, they are first drawn from the selected
        source.  Results are plotted with pylab and installed as the
        pipeline's selected data source, with minimal camera metadata.
        """
        from PYMEnf.Simulation import locify
        from PYME.Analysis.LMVis import inpFilt
        from PYME.Analysis.LMVis.visHelpers import ImageBounds
        from PYME.Acquire.MetaDataHandler import NestedClassMDHandler
        import pylab

        if self.xp is None or self.yp is None:
            # "Generate events" chosen before any positions were generated
            self.xp, self.yp = self.source.getPoints()

        pipeline = self.visFr.pipeline
        pipeline.filename = 'Simulation'

        # quick-look plot of the positions (joined up for worm-like chains)
        pylab.figure()
        pylab.plot(self.xp, self.yp, 'x')
        if isinstance(self.source, WormlikeSource):
            pylab.plot(self.xp, self.yp, lw=2)

        res = locify.eventify(self.xp, self.yp, self.meanIntensity, self.meanDuration, self.backgroundIntensity, self.meanEventNumber, self.scaleFactor, self.meanTime)
        pylab.plot(res['fitResults']['x0'], res['fitResults']['y0'], '+')

        pipeline.selectedDataSource = inpFilt.mappingFilter(inpFilt.fitResultsSource(res))
        pipeline.imageBounds = ImageBounds.estimateFromSource(pipeline.selectedDataSource)
        pipeline.dataSources.append(pipeline.selectedDataSource)

        pipeline.mdh = NestedClassMDHandler()
        pipeline.mdh['Camera.ElectronsPerCount'] = 1
        pipeline.mdh['Camera.TrueEMGain'] = 1
        pipeline.mdh['Camera.CycleTime'] = 1
        pipeline.mdh['voxelsize.x'] = .110

        # drop the 'sig' filter if present (filterKeys is a dict in PYME)
        pipeline.filterKeys.pop('sig', None)
        self.visFr.RegenFilter()
        self.visFr.SetFit()
Example #15
0
    class SegmentorWSiter(SegmentorBase):
        name = "Supervoxel Segmentation"
        description = "Segmentation plug-in using sparse basin graph"
        author = "C. N. Straehle, HCI - University of Heidelberg"
        homepage = "http://hci.iwr.uni-heidelberg.de"

        dontUseSuperVoxels = Bool(False)
        edgeWeights = Enum("Average", "Difference")
        algorithm = Enum("Watershed", "Graphcut", "Randomwalk")        
        bias = Float(0.95)
        biasThreshold = Float(128)
        biasedLabel = Int(1)
        sigma = Float(0.2)
        lis_options = String("-i bicgstab -tol 1.0e-9")
        
        viewWS = Group(Item('bias'),Item('biasThreshold'),  Item('biasedLabel'), visible_when = 'algorithm=="Watershed"')
        viewRW = Group(Item('sigma'), Item('lis_options'), visible_when = 'algorithm=="Randomwalk"')
        viewGC = Group(Item('sigma'), visible_when = 'algorithm=="Graphcut"')

        view = View( Item('edgeWeights'), Item('dontUseSuperVoxels'), Item('algorithm'), buttons = ['OK', 'Cancel'],  )        

        inlineConfig = View(Item('algorithm'), Group(viewWS, viewRW, viewGC))
        default = View(Item('bias'))
        
        lastBorderState = False        
        
                
#*******************************************************************************
# I n d e x e d A c c e s s o r                                                *
#*******************************************************************************

        class IndexedAccessor:
            """
            Helper class that behaves like an ndarray, but does a Lookuptable access
            """

            def __init__(self, volumeBasins, basinLabels):
                self.volumeBasins = volumeBasins
                self.basinLabels = basinLabels
                self.dtype = basinLabels.dtype
                self.shape = volumeBasins.shape
                self.flat = None
                print "Indexaccessor:", volumeBasins.shape
            def __getitem__(self, key):
                return self.basinLabels[self.volumeBasins[tuple(key)]]

            def __setitem__(self, key, data):
                #self.data[tuple(key)] = data
                print "##########ERROR ######### : IndexAccessor setitem should not be called"

#*******************************************************************************
# I n d e x e d A c c e s s o r W i t h C h a n n e l                          *
#*******************************************************************************

        class IndexedAccessorWithChannel:
            """
            Helper class that behaves like an ndarray, but does a Lookuptable access
            """

            def __init__(self, volumeBasins, basinLabels):
                self.volumeBasins = volumeBasins
                self.basinLabels = basinLabels
                self.dtype = basinLabels.dtype
                self.shape = volumeBasins.shape[:-1] + (basinLabels.shape[1],)
                
            def __getitem__(self, key):
                return self.basinLabels[:,key[-1]][self.volumeBasins[tuple(key[:-1])]]

            def __setitem__(self, key, data):
                #self.data[tuple(key)] = data
                print "##########ERROR ######### : IndexAccessor setitem should not be called"



        def segment3D(self, labelVolume, labelValues, labelIndices):
            print "setting seeds"
            self.segmentor.setSeeds(labelValues, labelIndices)
            if self.algorithm == "Graphcut":
                print "Executing Graphcut with sigma = %d"  % (self.sigma,)
                self.basinLabels = self.segmentor.doGC(self.sigma)
            elif self.algorithm == "Watershed":
                print "Executing Watershed with bias %d and biasedLabel %d" % (self.bias,  self.biasedLabel,)
                self.basinLabels = self.segmentor.doWS(self.bias,  self.biasThreshold, self.biasedLabel)
            elif self.algorithm == "Randomwalk":
                print "Executing Random Walk with sigma %f, and lis options %s" % (self.sigma,  self.lis_options,)
                self.basinLabels = self.segmentor.doRW(self.sigma,  self.lis_options)
                
                self.basinPotentials = self.segmentor.getBasinPotentials()
                
                self.potentials = SegmentorWSiter.IndexedAccessorWithChannel(self.volumeBasins,self.basinPotentials)
                
            self.getBasins()
            
            self.segmentation = SegmentorWSiter.IndexedAccessor(self.volumeBasins, self.basinLabels)
            return self.segmentation

        def segment2D(self, labels):
            pass

        @on_trait_change('dontUseSuperVoxels')
        def recalculateWeights(self):
            self.setupWeights(self.weights)

        def setupWeights(self, weights):
            self.weights = weights
            #self.weights = numpy.average(weights, axis = 3).astype(numpy.uint8)#.swapaxes(0,2).view(vigra.ScalarVolume)#
            if weights.dtype != numpy.uint8:
                print "converting weights to uint8"
                self.weights = weights.astype(numpy.uint8)
                
#            self.weights = numpy.zeros(weights.shape[0:-1], 'uint8')
#            self.weights[:] = 3
#            self.weights[:,:,0::4] = 10
#            self.weights[:,0::4,:] = 10
#            self.weights[0::4,:,:] = 10
#            self.weights = self.weights

            #self.ws = vigra.tws.IncrementalWS(self.weights, 0)
            print "Incoming weights :", self.weights.dtype, self.weights.shape

            if hasattr(self, "segmentor"):
                del self.segmentor
                del self.volumeBasins
            

            if self.edgeWeights == "Difference":
                useDifference = True
            else:                
                useDifference = False
            
            #print self.dontUseSuperVoxels
            self.segmentor = vigra.svs.segmentor(self.weights, useDifference, 0, 255, 2048, self.dontUseSuperVoxels)
    
            
            self.getBasins()
            self.volumeBasins.shape = self.volumeBasins.shape + (1,)

            self.borders = self.segmentor.getBorderVolume()   
            self.borders.shape = self.borders.shape + (1,)
            #self.borders = self.volumeBasins
            

        def getBasins(self):
            self.volumeBasins = self.segmentor.getVolumeBasins()            
                
            
Example #16
0
class SegmentPlot(BaseXYPlot):
    """ A plot consisting of disconnected line segments.

    Consecutive pairs of (index, value) points are the two ends of one
    segment: points 0 and 1 form the first segment, 2 and 3 the second,
    and so on.
    """

    # The color of the line.
    color = black_color_trait

    # The color to use to highlight the line when selected.
    selected_color = ColorTrait("lightyellow")

    # The style of the selected line.
    selected_line_style = LineStyle("solid")

    # The name of the key in self.metadata that holds the selection mask
    metadata_name = Str("selections")

    # The thickness of the line.
    line_width = Float(1.0)

    # The line dash style.
    line_style = LineStyle

    # Traits UI View for customizing the plot.
    traits_view = tui.View(tui.Item("color", style="custom"),
                           "line_width",
                           "line_style",
                           buttons=["OK", "Cancel"])

    #------------------------------------------------------------------------
    # Private traits
    #------------------------------------------------------------------------

    # Cached list of non-NaN arrays of (x,y) data-space points; regardless of
    # self.orientation, this is always stored as (index_pt, value_pt).  This is
    # different from the default BaseXYPlot definition.
    _cached_data_pts = List

    # Cached list of non-NaN arrays of (x,y) screen-space points.
    _cached_screen_pts = List

    def hittest(self, screen_pt, threshold=7.0):
        """Hit-testing is not implemented for segment plots; always None."""
        return None

    def get_screen_points(self):
        """Return the cached data-space arrays mapped to screen space."""
        self._gather_points()
        return [self.map_screen(ary) for ary in self._cached_data_pts]

    #------------------------------------------------------------------------
    # Private methods; implements the BaseXYPlot stub methods
    #------------------------------------------------------------------------

    def _gather_points(self):
        """
        Collects the data points that are within the bounds of the plot and
        caches them.

        Afterwards ``_cached_data_pts`` is either empty or holds one (N, 2)
        array of (index, value) rows with N even.  Pairs with a non-finite
        end are dropped entirely.
        """
        if self._cache_valid or not self.index or not self.value:
            return

        index = self.index.get_data()
        value = self.value.get_data()

        # If the data lies completely outside the view region there is
        # nothing to draw.  Cache that result so that segments from earlier
        # data are not rendered and the check is not repeated on every draw.
        for ds, rng in ((self.index, self.index_range),
                        (self.value, self.value_range)):
            low, high = ds.get_bounds()
            if low > rng.high or high < rng.low:
                self._cached_data_pts = []
                self._cache_valid = True
                return

        if len(index) == 0 or len(value) == 0:
            self._cached_data_pts = []
            self._cache_valid = True
            return

        # Mismatched lengths: warn and use only the common prefix.
        size_diff = len(value) - len(index)
        if size_diff != 0:
            warnings.warn('len(value) %d - len(index) %d = %d' \
                          % (len(value), len(index), size_diff))
            n_common = min(len(index), len(value))
            index = index[:n_common]
            value = value[:n_common]

        n_pts = len(index)
        if n_pts % 2:
            # We need an even number of points. Exclude the final one and
            # continue.
            warnings.warn('need an even number of points; got %d' % n_pts)
            index = index[:n_pts - 1]
            value = value[:n_pts - 1]

        # TODO: restore the functionality of rendering highlighted portions
        # of the line
        #selection = self.index.metadata.get(self.metadata_name, None)
        #if selection is not None and type(selection) in (ndarray, list) and \
        #        len(selection) > 0:

        # Exclude NaNs and Infs.
        finite_mask = np.isfinite(value) & np.isfinite(index)
        # Since the line segment ends are paired, we need to exclude the whole pair if
        # one is not finite.
        finite_mask[::2] &= finite_mask[1::2]
        finite_mask[1::2] &= finite_mask[::2]
        self._cached_data_pts = [
            np.column_stack([index[finite_mask], value[finite_mask]])
        ]
        self._cache_valid = True

    def _render(self, gc, points, selected_points=None):
        """Draw the segments, with selected ones drawn wider underneath first."""
        if len(points) == 0:
            return

        gc.save_state()
        try:
            gc.set_antialias(True)
            gc.clip_to_rect(self.x, self.y, self.width, self.height)

            if selected_points is not None:
                self._render_segments(gc, selected_points,
                                      self.selected_color_,
                                      self.line_width + 10.0,
                                      self.selected_line_style_)

            # Render using the normal style
            self._render_segments(gc, points, self.color_, self.line_width,
                                  self.line_style_)
        finally:
            gc.restore_state()

    def _render_segments(self, gc, points, color, line_width, line_style):
        """Stroke every even/odd row pair of each array as one segment."""
        gc.set_stroke_color(color)
        gc.set_line_width(line_width)
        gc.set_line_dash(line_style)
        gc.begin_path()
        for ary in points:
            if len(ary) > 0:
                gc.line_set(ary[::2], ary[1::2])
        gc.stroke_path()

    @on_trait_change('color,line_style,line_width')
    def _redraw(self):
        self.invalidate_draw()
        self.request_redraw()
Example #17
0
class Text(Module):
    # The version of this class.  Used for persistence.
    __version__ = 0

    # The tvtk TextActor.
    actor = Instance(tvtk.TextActor, allow_none=False, record=True)

    # The property of the axes (color etc.).
    property = Property(record=True)

    # The text to be displayed.  Note that this should really be `Str`
    # but wxGTK only returns unicode.
    text = Str('Text', desc='the text to be displayed')

    # The x-position of this actor.
    x_position = Float(0.0, desc='the x-coordinate of the text')

    # The y-position of this actor.
    y_position = Float(0.0, desc='the y-coordinate of the text')

    # The z-position of this actor.
    z_position = Float(0.0, desc='the z-coordinate of the text')

    # Shadow the positions as ranges for 2D. Simply using a RangeEditor
    # does not work as it resets the 3D positions to 1 when the dialog is
    # loaded.
    _x_position_2d = Range(0.,
                           1.,
                           0.,
                           enter_set=True,
                           auto_set=False,
                           desc='the x-coordinate of the text')
    _y_position_2d = Range(0.,
                           1.,
                           0.,
                           enter_set=True,
                           auto_set=False,
                           desc='the y-coordinate of the text')

    # 3D position
    position_in_3d = Bool(
        False,
        desc='whether the position of the object is given in 2D or in 3D')

    # The width of the text.
    width = Range(0.0,
                  1.0,
                  0.4,
                  enter_set=True,
                  auto_set=False,
                  desc='the width of the text as a fraction of the viewport')

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    ########################################
    # The view of this object.

    if VTK_VER > 5.1:
        _text_actor_group = Group(Item(name='visibility'),
                                  Item(name='text_scale_mode'),
                                  Item(name='alignment_point'),
                                  Item(name='minimum_size'),
                                  Item(name='maximum_line_height'),
                                  show_border=True,
                                  label='Text Actor')
    else:
        _text_actor_group = Group(Item(name='visibility'),
                                  Item(name='scaled_text'),
                                  Item(name='alignment_point'),
                                  Item(name='minimum_size'),
                                  Item(name='maximum_line_height'),
                                  show_border=True,
                                  label='Text Actor')

    _position_group_2d = Group(Item(name='_x_position_2d', label='X position'),
                               Item(name='_y_position_2d', label='Y position'),
                               visible_when='not position_in_3d')

    _position_group_3d = Group(Item(name='x_position', label='X',
                                    springy=True),
                               Item(name='y_position', label='Y',
                                    springy=True),
                               Item(name='z_position', label='Z',
                                    springy=True),
                               show_border=True,
                               label='Position',
                               orientation='horizontal',
                               visible_when='position_in_3d')

    view = View(Group(Group(Item(name='text'),
                            Item(name='position_in_3d'),
                            _position_group_2d,
                            _position_group_3d,
                            Item(name='width',
                                 enabled_when='object.actor.scaled_text'),
                            ),
                      Group(Item(name='actor', style='custom',
                                 editor=\
                                 InstanceEditor(view=View(_text_actor_group))
                                 ),
                            show_labels=False),
                      label='TextActor',
                      show_labels=False
                      ),
                Group(Item(name='_property', style='custom', resizable=True),
                      label='TextProperty',
                      show_labels=False),
                )

    ########################################
    # Private traits.
    _updating = Bool(False)
    _property = Instance(tvtk.TextProperty)

    ######################################################################
    # `object` interface
    ######################################################################
    def __set_pure_state__(self, state):
        """Restore this module from pickled *state*, applying ``actor`` first.

        ``_updating`` is set while the state is applied so that
        ``_text_changed`` does not push the text to the actor.  It is always
        cleared again, even if restoring raises; otherwise every later text
        edit would be silently ignored.
        """
        self._updating = True
        try:
            state_pickler.set_state(self,
                                    state,
                                    first=['actor'],
                                    ignore=['_updating'])
        finally:
            self._updating = False

    ######################################################################
    # `Module` interface
    ######################################################################
    def setup_pipeline(self):
        """Create the tvtk text actor and configure its position coordinates.

        Per the Module interface this runs during ``__init__``, before any
        upstream data is available, so it only builds objects that do not
        depend on upstream sources or filters.  Both position coordinates
        use the normalized viewport; the first is placed at
        (x_position, y_position).  The current text and width are then
        pushed to the actor, and the 2D position sliders are linked.
        """
        text_actor = tvtk.TextActor(input=str(self.text))
        self.actor = text_actor

        if VTK_VER > 5.1:
            text_actor.set(text_scale_mode='prop', width=0.4, height=1.0)
        else:
            text_actor.set(scaled_text=True, width=0.4, height=1.0)

        first_coord = text_actor.position_coordinate
        first_coord.set(coordinate_system='normalized_viewport',
                        value=(self.x_position, self.y_position, 0.0))
        second_coord = text_actor.position2_coordinate
        second_coord.set(coordinate_system='normalized_viewport')

        self._property.opacity = 1.0

        self._text_changed(self.text)
        self._width_changed(self.width)
        self._shadow_positions(True)

    def update_pipeline(self):
        """React to upstream data changing by re-announcing the pipeline.

        Invoked automatically when an input fires ``pipeline_changed``.  The
        text actor does not depend on upstream data, so this only raises
        this module's own ``pipeline_changed`` event.
        """
        # Forward the notification; there is nothing to rebuild here.
        self.pipeline_changed = True

    def update_data(self):
        """React to an upstream `data_changed` event.

        Called automatically when any input signals `data_changed`.  No VTK
        flushing is done here: the module only raises `data_changed` on
        itself and leaves the rest to the component.
        """
        self.data_changed = True

    ######################################################################
    # Non-public interface
    ######################################################################
    def _text_changed(self, value):
        # Push the new text to the actor and re-render.  Skipped when there
        # is no actor yet, or while traits are being synced from the actor.
        actor = self.actor
        if actor is None or self._updating:
            return
        actor.input = str(value)
        self.render()

    def _shadow_positions(self, value):
        # When `value` is true, keep x/y_position synced with the private
        # _x/_y_position_2d copies.  Otherwise remove that sync and snapshot
        # the current positions into the private copies.
        unsync = not value
        pairs = (('x_position', '_x_position_2d'),
                 ('y_position', '_y_position_2d'))
        for public_name, shadow_name in pairs:
            self.sync_trait(public_name, self, shadow_name, remove=unsync)
        if unsync:
            self._x_position_2d = self.x_position
            self._y_position_2d = self.y_position

    def _position_in_3d_changed(self, value):
        """Switch the actor between world and normalized-viewport coordinates.

        Trait handler for `position_in_3d`.  When switching back to 2D, the
        current x/y positions are clamped into [0, 1] (the normalized
        viewport range) without firing change notifications.  The actor is
        then repositioned, its width re-applied and `pipeline_changed`
        raised.
        """
        if value:
            # 3D: both actor corners are expressed in world coordinates.
            self.actor.position_coordinate.coordinate_system = 'world'
            self.actor.position2_coordinate.coordinate_system = 'world'
        else:
            # 2D: both corners go back to normalized viewport coordinates.
            self.actor.position2_coordinate.coordinate_system=\
                                            'normalized_viewport'
            self.actor.position_coordinate.coordinate_system=\
                                            'normalized_viewport'
            # Clamp the (possibly world-space) positions into [0, 1].
            x = self.x_position
            y = self.y_position
            if x < 0:
                x = 0
            elif x > 1:
                x = 1
            if y < 0:
                y = 0
            elif y > 1:
                y = 1
            # Set silently so _change_position is not triggered twice; it is
            # called explicitly below.
            self.set(x_position=x, y_position=y, trait_change_notify=False)
        # Sync x/y with the private 2D copies only while in 2D mode.
        self._shadow_positions(not value)
        self._change_position()
        # Invoke the tvtk actor's own width change handler with old == new.
        # NOTE(review): presumably this re-pushes the width to the underlying
        # VTK object after the coordinate-system switch; confirm.
        self.actor._width_changed(self.width, self.width)
        self.pipeline_changed = True

    def _change_position(self):
        """Callback for x_position, y_position and z_position changes.

        In 3D mode all three coordinates are written to the actor's position
        coordinate.  In 2D mode only x and y are written to `actor.position`.
        Skipped when there is no actor yet, or while traits are being synced
        from the actor.
        """
        actor = self.actor
        if actor is None or self._updating:
            return
        x, y, z = self.x_position, self.y_position, self.z_position
        if not self.position_in_3d:
            actor.position = x, y
        else:
            actor.position_coordinate.value = x, y, z
        self.render()

    # One handler serves all three coordinate traits.
    _x_position_changed = _y_position_changed = _z_position_changed = \
        _change_position

    def _width_changed(self, value):
        # Push the new width to the actor and re-render.  Skipped when there
        # is no actor yet, or while traits are being synced from the actor.
        actor = self.actor
        if actor is None or self._updating:
            return
        actor.width = value
        self.render()

    def _update_traits(self):
        # Copy text, position and width back from the actor into this
        # module's traits.  `_updating` is held True so the change handlers
        # do not echo the values straight back to the actor.
        self._updating = True
        try:
            text_actor = self.actor
            self.text = text_actor.input
            self.x_position, self.y_position = text_actor.position
            self.width = text_actor.width
        finally:
            self._updating = False

    def _get_property(self):
        """Getter for the `property` trait: the cached text property."""
        cached = self._property
        return cached

    def _actor_changed(self, old, new):
        # Move render/sync listeners from the old actor (and its text
        # property) onto the new one, then make the new actor this module's
        # only actor.
        if old is not None:
            old.on_trait_change(self.render, remove=True)
            self._property.on_trait_change(self.render, remove=True)
            old.on_trait_change(self._update_traits, remove=True)

        self._property = new.text_property
        new.on_trait_change(self.render)
        self._property.on_trait_change(self.render)
        new.on_trait_change(self._update_traits)

        self.actors = [new]
        self.render()

    def _foreground_changed_for_scene(self, old, new):
        # The text colour follows the scene's foreground colour.
        text_property = self.property
        text_property.color = new
        self.render()

    def _scene_changed(self, old, new):
        # Let the base class handle the scene switch, then adopt the new
        # scene's foreground colour for the text.
        super(Text, self)._scene_changed(old, new)
        foreground = new.foreground
        self._foreground_changed_for_scene(None, foreground)
Example #18
0
class Uncertain(HasStrictTraits):
    """
    Represents a numeric value with a known small uncertainty
    (error, standard deviation...).

    Numeric operators are overloaded to work with other Uncertain or
    numeric objects.  Errors are propagated to first order (linearization),
    and the two operands of a binary operation are treated as independent:
    for example ``x + x`` yields an error of ``sqrt(2) * x.error``, not
    ``2 * x.error``.
    The uncertainty (error) must be small; otherwise the linearization
    employed here becomes wrong.
    The dependency on traits is not essential and could easily be dropped.
    """
    value = Float
    error = Float(0.)

    def __init__(self, value=0., error=0., *a, **t):
        self.value = value
        # Only the magnitude of the error is meaningful.  The derivative
        # formulas below may produce negative values, which end up here.
        self.error = abs(error)
        super(Uncertain, self).__init__(*a, **t)

    def __str__(self):
        return "%g+-%g" % (self.value, self.error)

    def __repr__(self):
        return "Uncertain(%s, %s)" % (self.value, self.error)

    def __float__(self):
        return self.value

    def assign(self, other):
        """Copy value and error from `other` in place.

        A plain number is taken as exact (zero error).
        """
        if isinstance(other, Uncertain):
            self.value = other.value
            self.error = other.error
        else:
            self.value = other
            self.error = 0.

    def __abs__(self):
        return Uncertain(abs(self.value), self.error)

    def __add__(self, other):
        if isinstance(other, Uncertain):
            v = self.value + other.value
            # Independent errors add in quadrature.
            e = (self.error**2 + other.error**2)**.5
            return Uncertain(v, e)
        else:
            return Uncertain(self.value + other, self.error)

    def __radd__(self, other):
        return self + other  # __add__

    def __sub__(self, other):
        return self + (-other)  # other.__neg__ and __add__

    def __rsub__(self, other):
        return -self + other  # __neg__ and __add__

    def __mul__(self, other):
        if isinstance(other, Uncertain):
            v = self.value * other.value
            e = ((self.error * other.value)**2 +
                 (other.error * self.value)**2)**.5
            return Uncertain(v, e)
        else:
            return Uncertain(self.value * other, self.error * other)

    def __rmul__(self, other):
        return self * other  # __mul__

    def __neg__(self):
        return self * -1  # __mul__

    def __pos__(self):
        return self

    def __div__(self, other):
        # 1. / other dispatches to other.__rdiv__/__rtruediv__ when other is
        # Uncertain.
        return self * (1. / other)  # __mul__

    # Python 3 (and `from __future__ import division`) use __truediv__.
    __truediv__ = __div__

    def __rdiv__(self, other):
        """Return ``other / self`` for a plain number `other`.

        The first-order error is ``|other| * error / value**2``, from
        d(a/x)/dx = -a/x**2.  It is computed directly rather than as
        ``(self / other)**-1`` so that ``0 / x`` gives 0 instead of raising
        ZeroDivisionError.
        """
        return Uncertain(other / self.value,
                         other * self.error / self.value**2)

    __rtruediv__ = __rdiv__

    def __pow__(self, other):
        if isinstance(other, Uncertain):
            v = self.value**other.value
            # d(x**y)/dx = y*x**(y-1);  d(x**y)/dy = x**y * ln(x)
            e = ((self.error * other.value * self.value**(other.value - 1.))**2
                 + (other.error * np.log(self.value) * self.value**other.value)
                 **2)**.5
            return Uncertain(v, e)
        else:
            return Uncertain(self.value**other,
                             self.error * other * self.value**(other - 1))

    def __rpow__(self, other):
        assert not isinstance(other, Uncertain)
        # otherwise other.__pow__ would have been called
        # d(a**x)/dx = a**x * ln(a)
        return Uncertain(other**self.value,
                         self.error * np.log(other) * other**self.value)

    def exp(self):
        return np.e**self

    def log(self):
        # d(ln x)/dx = 1/x
        return Uncertain(np.log(self.value), self.error / self.value)
Example #19
0
class Threshold(Filter):
    """Threshold the input dataset on its scalar values.

    Depending on `filter_type`, either cells (via `tvtk.Threshold`) or points
    (via `tvtk.ThresholdPoints`) whose scalars lie between `lower_threshold`
    and `upper_threshold` are kept.  The allowed threshold range follows the
    scalar range of the first input, with point scalars preferred over cell
    scalars.
    """

    # The version of this class.  Used for persistence.
    __version__ = 0

    # The threshold filter used.
    threshold_filter = Property(Instance(tvtk.Object, allow_none=False),
                                record=True)

    # The filter type to use: whether thresholding is applied to cells or
    # to points.
    filter_type = Enum('cells',
                       'points',
                       desc='if thresholding is done on cells or points')

    # Lower threshold (this is a dynamic trait that is changed when
    # input data changes).
    lower_threshold = Range(value=-1.0e20,
                            low='_data_min',
                            high='_data_max',
                            enter_set=True,
                            auto_set=False,
                            desc='the lower threshold of the filter')

    # Upper threshold (this is a dynamic trait that is changed when
    # input data changes).
    upper_threshold = Range(value=1.0e20,
                            low='_data_min',
                            high='_data_max',
                            enter_set=True,
                            auto_set=False,
                            desc='the upper threshold of the filter')

    # Automatically reset the lower threshold when the upstream data
    # changes.
    auto_reset_lower = Bool(True,
                            desc='if the lower threshold is '
                            'automatically reset when upstream '
                            'data changes')

    # Automatically reset the upper threshold when the upstream data
    # changes.
    auto_reset_upper = Bool(True,
                            desc='if the upper threshold is '
                            'automatically reset when upstream '
                            'data changes')

    input_info = PipelineInfo(datasets=['any'],
                              attribute_types=['any'],
                              attributes=['any'])

    output_info = PipelineInfo(datasets=['poly_data', 'unstructured_grid'],
                               attribute_types=['any'],
                               attributes=['any'])

    # Our view.
    view = View(Group(
        Group(Item(name='filter_type'), Item(name='lower_threshold'),
              Item(name='auto_reset_lower'), Item(name='upper_threshold'),
              Item(name='auto_reset_upper')),
        Item(name='_'),
        Group(
            Item(name='threshold_filter',
                 show_label=False,
                 visible_when='object.filter_type == "cells"',
                 style='custom',
                 resizable=True)),
    ),
                resizable=True)

    ########################################
    # Private traits.

    # These traits are used to set the limits for the thresholding.
    # They store the minimum and maximum values of the input data.
    _data_min = Float(-1e20)
    _data_max = Float(1e20)

    # The threshold filter for cell based filtering
    _threshold = Instance(tvtk.Threshold, args=(), allow_none=False)

    # The threshold filter for points based filtering.
    _threshold_points = Instance(tvtk.ThresholdPoints,
                                 args=(),
                                 allow_none=False)

    # True until the first range update that finds scalar data.  On that
    # update both thresholds are initialised from the data range regardless
    # of the auto-reset flags.
    _first = Bool(True)

    ######################################################################
    # `object` interface.
    ######################################################################
    def __get_pure_state__(self):
        d = super(Threshold, self).__get_pure_state__()
        # These traits are dynamically created.
        for name in ('_first', '_data_min', '_data_max'):
            d.pop(name, None)

        return d

    ######################################################################
    # `Filter` interface.
    ######################################################################
    def setup_pipeline(self):
        # Re-run the cell filter whenever one of its scalar-selection
        # settings is edited.
        attrs = [
            'all_scalars', 'attribute_mode', 'component_mode',
            'selected_component'
        ]
        self._threshold.on_trait_change(self._threshold_filter_edited, attrs)

    def update_pipeline(self):
        """Override this method so that it *updates* the tvtk pipeline
        when data upstream is known to have changed.

        This method is invoked (automatically) when the input fires a
        `pipeline_changed` event.
        """
        if len(self.inputs) == 0:
            return

        # By default we set the input to the first output of the first
        # input.
        fil = self.threshold_filter
        fil.input = self.inputs[0].outputs[0]

        self._update_ranges()
        self._set_outputs([self.threshold_filter.output])

    def update_data(self):
        """Override this method to do what is necessary when upstream
        data changes.

        This method is invoked (automatically) when any of the inputs
        sends a `data_changed` event.
        """
        if len(self.inputs) == 0:
            return

        self._update_ranges()

        # Propagate the data_changed event.
        self.data_changed = True

    ######################################################################
    # Non-public interface
    ######################################################################
    def _lower_threshold_changed(self, new_value):
        fil = self.threshold_filter
        fil.threshold_between(new_value, self.upper_threshold)
        fil.update()
        self.data_changed = True

    def _upper_threshold_changed(self, new_value):
        fil = self.threshold_filter
        fil.threshold_between(self.lower_threshold, new_value)
        fil.update()
        self.data_changed = True

    def _update_ranges(self):
        """Update the threshold limits from the input's scalar range.

        The first time scalar data is seen, both limits and both thresholds
        are taken from the range.  After that, each side is reset only when
        its `auto_reset_*` flag is on.  When both sides are reset, the lower
        threshold is set silently, so the filter is re-thresholded once, by
        the upper-threshold handler.
        """
        dr = self._get_data_range()
        if len(dr) == 0:
            return
        if self._first:
            self._data_min, self._data_max = dr
            self.set(lower_threshold=dr[0], trait_change_notify=False)
            self.upper_threshold = dr[1]
            self._first = False
        else:
            if self.auto_reset_lower:
                self._data_min = dr[0]
                notify = not self.auto_reset_upper
                self.set(lower_threshold=dr[0], trait_change_notify=notify)
            if self.auto_reset_upper:
                self._data_max = dr[1]
                self.upper_threshold = dr[1]

    def _get_data_range(self):
        """Return ``[min, max]`` of the first input's scalar data.

        Point scalars take precedence over cell scalars.  Returns an empty
        list when the input has neither.
        """
        dataset = self.inputs[0].outputs[0]
        # FIXME: need to be able to handle cell and point data
        # together.
        scalars = dataset.point_data.scalars
        if scalars is None:
            scalars = dataset.cell_data.scalars
        if scalars is None:
            return []
        return self._finite_range(scalars)

    @staticmethod
    def _finite_range(scalars):
        """Return ``scalars.range`` as a list with NaN endpoints replaced.

        A NaN endpoint is replaced by the NaN-ignoring min/max of the
        array.  The range is converted to a list first so that endpoints
        can be replaced even when tvtk hands back a tuple.
        """
        data_range = list(scalars.range)
        if np.isnan(data_range[0]) or np.isnan(data_range[1]):
            values = scalars.to_array()
            if np.isnan(data_range[0]):
                data_range[0] = float(np.nanmin(values))
            if np.isnan(data_range[1]):
                data_range[1] = float(np.nanmax(values))
        return data_range

    def _auto_reset_lower_changed(self, value):
        if len(self.inputs) == 0 or not value:
            return
        dr = self._get_data_range()
        if len(dr) == 0:
            # No scalar data to take a range from.
            return
        self._data_min = dr[0]
        self.lower_threshold = dr[0]

    def _auto_reset_upper_changed(self, value):
        if len(self.inputs) == 0 or not value:
            return
        dr = self._get_data_range()
        if len(dr) == 0:
            # No scalar data to take a range from.
            return
        self._data_max = dr[1]
        self.upper_threshold = dr[1]

    def _get_threshold_filter(self):
        # Getter for the `threshold_filter` property.
        if self.filter_type == 'cells':
            return self._threshold
        else:
            return self._threshold_points

    def _filter_type_changed(self, value):
        # The property's value depends on filter_type, so announce the
        # switch explicitly.
        if value == 'cells':
            old = self._threshold_points
            new = self._threshold
        else:
            old = self._threshold
            new = self._threshold_points
        self.trait_property_changed('threshold_filter', old, new)

    def _threshold_filter_changed(self, old, new):
        if len(self.inputs) == 0:
            return
        fil = new
        fil.input = self.inputs[0].outputs[0]
        fil.threshold_between(self.lower_threshold, self.upper_threshold)
        fil.update()
        self._set_outputs([fil.output])

    def _threshold_filter_edited(self):
        self.threshold_filter.update()
        self.data_changed = True
Example #20
0
class TCPDriver(IODriver):
    """
    TCP input driver.

    Accepts a single TCP connection on ``ip``:``port`` and returns the raw
    bytes received from it, at most ``1024 * buffer_size`` bytes per call to
    `receive`.
    """

    name = Str('TCP Driver')
    view = View(Item(name='port', label='Port'),
                Item(name='show_debug_msgs', label='Show debug messages'),
                Item(name='buffer_size', label='Buffer size / kb'),
                Item(name='timeout', label='Timeout / s'),
                title='TCP input driver')

    # The connected peer socket once a client has been accepted, otherwise
    # None.  Sockets are created lazily in listen(), so none is opened at
    # import time or shared between instances.
    _sock = None

    port = Range(1024, 65535, 34443)
    buffer_size = Range(
        1, 4096,
        10)  # no reason not to go above 4MB but there should be some limit.
    timeout = Float(1.0)
    ip = Str('0.0.0.0')
    show_debug_msgs = Bool(False)

    is_open = Bool(False)

    def open(self):
        print("Opening (one time)")

    def _close_socket(self):
        """Close the current socket, if any, and forget it."""
        if self._sock is not None:
            self._sock.close()
            self._sock = None

    def listen(self):
        """Wait up to `timeout` seconds for one client on ``ip``:``port``.

        On success the accepted connection becomes the active socket and
        `is_open` is set.  If the address cannot be bound, a message is
        printed and the method returns, so the caller retries later.
        `socket.timeout` from `accept` propagates to the caller.  The
        listening socket is always closed, because only one client is
        served.
        """
        print("Listening...")
        self.is_open = False
        self._close_socket()
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            listener.bind((self.ip, self.port))
        except socket.error:
            listener.close()
            print("Error, address probably already bound. Will try again")
            return
        try:
            listener.settimeout(self.timeout)  # seconds
            listener.listen(1)
            conn, (addr, _) = listener.accept()
        finally:
            listener.close()
        self._sock = conn
        print("%s  connected!" % (addr,))
        self.is_open = True

    def close(self):
        print("Closing")
        self.is_open = False
        self._close_socket()

    def receive(self):
        """Return the next chunk of bytes from the peer, or None.

        While no client is connected, this waits in `listen` for one and
        returns None.  It also returns None on socket errors.  When the peer
        has closed the connection (an empty read), the socket is closed and
        None is returned, so the next call listens for a new client.
        """
        if not self.is_open:
            try:
                self.listen()
            except socket.timeout:
                print("Socket timed out")
                self.is_open = False
            return None

        try:
            data = self._sock.recv(1024 * self.buffer_size)
        except socket.error:
            return None

        if not data:
            # An empty read on a stream socket means the peer disconnected.
            self.close()
            return None

        if self.show_debug_msgs:
            print("TCP driver: packet size %u bytes" % len(data))
        return data

    def rebind_socket(self):
        # Closing forces receive() to listen again with the new settings.
        self.close()

    @on_trait_change('port')
    def change_port(self):
        self.rebind_socket()

    @on_trait_change('ip')
    def change_ip(self):
        self.rebind_socket()

    @on_trait_change('address')
    def change_address(self):
        self.rebind_socket()

    @on_trait_change('timeout')
    def change_timeout(self):
        self.rebind_socket()
Example #21
0
class CameraControl(HasTraits):
    """Editor for recording, replaying and exporting camera paths.

    `positions` holds the key frames.  `step_through` interpolates (via
    `_interpolate`) between consecutive key frames, optionally wrapping back
    to the first one when `periodic` is set.  It renders each step and can
    either pace playback at `fps` or save every frame to disk.
    """
    # Traits
    positions = List(CameraPosition)
    yt_scene = Instance('YTScene')
    center = Delegate('yt_scene')
    scene = Delegate('yt_scene')
    camera = Instance(tvtk.OpenGLCamera)
    reset_position = Instance(CameraPosition)
    fps = Float(25.0)
    export_filename = 'frames'
    periodic = Bool

    # UI elements
    snapshot = Button()
    play = Button()
    export_frames = Button()
    reset_path = Button()
    recenter = Button()
    save_path = Button()
    load_path = Button()
    export_path = Button()

    table_def = TableEditor(columns=[
        ObjectColumn(name='position'),
        ObjectColumn(name='focal_point'),
        ObjectColumn(name='view_up'),
        ObjectColumn(name='clipping_range'),
        ObjectColumn(name='num_steps')
    ],
                            reorderable=True,
                            deletable=True,
                            sortable=True,
                            sort_model=True,
                            show_toolbar=True,
                            selection_mode='row',
                            selected='reset_position')

    default_view = View(
        VGroup(
            HGroup(Item('camera', show_label=False),
                   Item('recenter', show_label=False),
                   label='Camera'),
            HGroup(Item('snapshot', show_label=False),
                   Item('play', show_label=False),
                   Item('export_frames', show_label=False),
                   Item('reset_path', show_label=False),
                   Item('save_path', show_label=False),
                   Item('load_path', show_label=False),
                   Item('export_path', show_label=False),
                   Item('export_filename'),
                   Item('periodic'),
                   Item('fps'),
                   label='Playback'),
            VGroup(Item('positions', show_label=False, editor=table_def),
                   label='Camera Path'),
        ),
        resizable=True,
        title="Camera Path Editor",
    )

    def _reset_position_changed(self, old, new):
        # Selecting a row in the table moves the camera to that key frame.
        if new is None:
            return
        cam = self.scene.camera
        cam.position = new.position
        cam.focal_point = new.focal_point
        cam.view_up = new.view_up
        cam.clipping_range = new.clipping_range
        self.scene.render()

    def __init__(self, **traits):
        HasTraits.__init__(self, **traits)

    def take_snapshot(self):
        """Append the scene camera's current state as a new key frame."""
        cam = self.scene.camera
        self.positions.append(
            CameraPosition(position=cam.position,
                           focal_point=cam.focal_point,
                           view_up=cam.view_up,
                           clipping_range=cam.clipping_range,
                           distance=cam.distance,
                           orientation_wxyz=cam.orientation_wxyz))

    def _choose_path(self, action, verb):
        """Show a ``*.cpath`` file dialog and return the chosen path or None.

        `verb` is printed before the chosen path.
        """
        dlg = pyface.FileDialog(
            action=action,
            wildcard="*.cpath",
        )
        if dlg.open() != pyface.OK:
            return None
        print("%s %s" % (verb, dlg.path))
        return dlg.path

    def _export_path_fired(self):
        path = self._choose_path('save as', "Saving:")
        if path is not None:
            self.export_camera_path(path)

    def export_camera_path(self, fn):
        """Pickle every interpolated camera state along the path to `fn`.

        The path is stepped through with no pause, recording each rendered
        frame's camera.
        """
        to_dump = dict(positions=[],
                       focal_points=[],
                       view_ups=[],
                       clipping_ranges=[],
                       distances=[],
                       orientation_wxyzs=[])

        def _write(cam):
            to_dump['positions'].append(cam.position)
            to_dump['focal_points'].append(cam.focal_point)
            to_dump['view_ups'].append(cam.view_up)
            to_dump['clipping_ranges'].append(cam.clipping_range)
            to_dump['distances'].append(cam.distance)
            to_dump['orientation_wxyzs'].append(cam.orientation_wxyz)

        self.step_through(0.0, callback=_write)
        with open(fn, "wb") as fh:
            pickle.dump(to_dump, fh)

    def _save_path_fired(self):
        path = self._choose_path('save as', "Saving:")
        if path is not None:
            self.dump_camera_path(path)

    def dump_camera_path(self, fn):
        """Pickle the key frames to `fn` as a dict of per-attribute lists.

        Each key is the CameraPosition attribute name with an "s" appended;
        `load_camera_path` strips it again.
        """
        to_dump = dict(positions=[],
                       focal_points=[],
                       view_ups=[],
                       clipping_ranges=[],
                       distances=[],
                       orientation_wxyzs=[],
                       num_stepss=[])
        for p in self.positions:
            to_dump['positions'].append(p.position)
            to_dump['focal_points'].append(p.focal_point)
            to_dump['view_ups'].append(p.view_up)
            to_dump['clipping_ranges'].append(p.clipping_range)
            to_dump['distances'].append(p.distance)
            to_dump['num_stepss'].append(p.num_steps)  # stupid s
            to_dump['orientation_wxyzs'].append(p.orientation_wxyz)
        with open(fn, "wb") as fh:
            pickle.dump(to_dump, fh)

    def _load_path_fired(self):
        path = self._choose_path('open', "Loading:")
        if path is not None:
            self.load_camera_path(path)

    def load_camera_path(self, fn):
        """Replace the key frames with those pickled in `fn`."""
        # SECURITY: pickle.load can execute arbitrary code.  Only load path
        # files from trusted sources.
        with open(fn, "rb") as fh:
            to_use = pickle.load(fh)
        self.positions = []
        for i in range(len(to_use['positions'])):
            # Keys are pluralised ("positions", "num_stepss").  Strip the "s"
            # to recover the CameraPosition trait name.
            dd = dict((kw[:-1], values[i]) for kw, values in to_use.items())
            self.positions.append(CameraPosition(**dd))

    def _recenter_fired(self):
        self.camera.focal_point = self.center
        self.scene.render()

    def _snapshot_fired(self):
        self.take_snapshot()

    def _play_fired(self):
        self.step_through()

    def _export_frames_fired(self):
        self.step_through(save_frames=True)

    def _reset_path_fired(self):
        self.positions = []

    def step_through(self, pause=1.0, callback=None, save_frames=False):
        """Render the path frame by frame.

        Between each pair of consecutive key frames, ``num_steps`` (of the
        first key frame of the pair) interpolated frames are rendered.  When
        `periodic` is set, the path wraps from the last key frame back to the
        first.  `callback`, if given, is called with the camera after every
        frame.  When `save_frames` is set, each frame is saved as
        ``<export_filename>_NNNNN.png``.  Otherwise playback sleeps
        ``pause / fps`` seconds per frame.  An empty path renders nothing.
        """
        cam = self.scene.camera
        keyframes = list(self.positions)
        if self.periodic and keyframes:
            keyframes.append(keyframes[0])
        frame_counter = 0
        for pos1, pos2 in zip(keyframes[:-1], keyframes[1:]):
            r = pos1.num_steps
            for p in range(r):
                po = _interpolate(pos1.position, pos2.position, p, r)
                fp = _interpolate(pos1.focal_point, pos2.focal_point, p, r)
                vu = _interpolate(pos1.view_up, pos2.view_up, p, r)
                cr = _interpolate(pos1.clipping_range, pos2.clipping_range, p,
                                  r)
                _set_cpos(cam, po, fp, vu, cr)
                self.scene.render()
                if callback is not None:
                    callback(cam)
                if save_frames:
                    self.scene.save("%s_%0.5d.png" %
                                    (self.export_filename, frame_counter))
                else:
                    time.sleep(pause * 1.0 / self.fps)
                frame_counter += 1
class Generator(HasTraits):
    """Synthetic localisation-data generator for the PYME visualiser.

    Fluorophore positions are drawn from the selected point `source` and
    turned into simulated events (STORM or PAINT mode).  The events are then
    loaded into the visualiser's pipeline as the 'Generated Points' data
    source.
    """
    meanIntensity = Float(500)
    meanDuration = Float(3)
    backgroundIntensity = Float(300)
    meanEventNumber = Float(2)
    # Note: we don't expose the scale factor in the view
    scaleFactor = Float(2)
    meanTime = Float(2000)
    mode = Enum(['STORM', 'PAINT'])

    sources = List([WormlikeSource(), ImageSource(), FileSource()])

    source = Instance(PointSource)

    helpInfo = {
        'source':
        '''
Select the type of point source to generate points.
A wormlike source, an image based source and a file based source are supported.
''',
        'meanIntensity':
        '''
This parameter specifies the mean number of photons in an event.
Typically values are in the range from 100 to several 10000.
''',
        'meanDuration':
        '''
The mean duration of events which is specified in units of frames.
''',
        'meanTime':
        '''
This parameter, the mean time of the series in frame units, is the average time at which you expect to get events
(i.e. the value of np.mean(pipeline['t']) for the simulated set of events). Since STORM mode draws event times from
an exponential distribution, PAINT from a uniform one, it can also be related to the resulting apparent series duration,
which may be more familiar to experimentally minded users. For PAINT mode it works out as half the duration of the series,
for STORM simulation mode the relationship is a little more complex,
you can work it out from the decay time of an exponential distribution.
''',
        'meanEventNumber':
        '''
This parameter specifies the mean number of times an event occurs at a single marker location.
''',
        'backgroundIntensity':
        '''
The background intensity per pixel in units of photons, typically in the range from a few tens to hundreds of photons.
''',
        'mode':
        '''
With the simulation mode you can choose between STORM or PAINT mode.
This parameter affects how event rate changes with time (it stays constant in PAINT mode).
''',
        'scaleFactor':
        '''
This parameter is related to the size of the PSF for purposes of thresholding
(used in combination with the background intensity, which is per pixel).
There should be no need to modify this from the default and it is accordingly not exposed in the view.
 
''',
    }

    def helpStr(self, name):
        """Return the help text for trait `name`, collapsed onto one line."""
        text = self.helpInfo[name]
        return text.strip().replace('\n', ' ').replace('\r', '')

    def default_traits_view(self):
        from traitsui.api import View, Item, InstanceEditor

        traits_view = View(
            Item(
                'source',
                label='Point source',
                editor=InstanceEditor(name='sources', editable=True),
                help=self.helpStr('source'),
            ),
            Item('_'),
            Item('meanIntensity',
                 tooltip='mean photon number of events',
                 help=self.helpStr('meanIntensity')),
            Item('meanDuration',
                 tooltip='mean duration of events in units of frames',
                 help=self.helpStr('meanDuration')),
            Item(
                'meanEventNumber',
                tooltip=
                'mean number of times events occurs at a single marker location',
                help=self.helpStr('meanEventNumber')),
            Item(
                'meanTime',
                tooltip=
                'mean time of the series, roughly related to series duration, in frame units',
                help=self.helpStr('meanTime')),
            Item('_'),
            Item('backgroundIntensity',
                 tooltip='background intensity in units of photons',
                 help=self.helpStr('backgroundIntensity')),
            Item('_'),
            Item(
                'mode',
                tooltip=
                'STORM or PAINT mode, affects how event rate changes with time',
                help=self.helpStr('mode')),
            buttons=['OK', 'Help'])

        return traits_view

    def __init__(self, visFr=None):
        """Set up the generator and, if `visFr` is given, register its
        'Extras>Synthetic Data' menu entries."""
        super(Generator, self).__init__()
        self.visFr = visFr
        self.source = self.sources[0]

        if visFr:
            visFr.AddMenuItem('Extras>Synthetic Data', "Configure",
                              self.OnConfigure)
            visFr.AddMenuItem('Extras>Synthetic Data',
                              'Generate fluorophore positions and events',
                              self.OnGenPoints)
            visFr.AddMenuItem('Extras>Synthetic Data', 'Generate events',
                              self.OnGenEvents)

    def OnConfigure(self, event):
        self.source.refresh_choices()
        self.edit_traits()

    def OnGenPoints(self, event):
        # Draw new fluorophore positions, then simulate events for them.
        self.xp, self.yp, self.zp = self.source.getPoints()
        self.OnGenEvents(None)

    def OnGenEvents(self, event):
        """Simulate events for the current positions and load them.

        Requires `OnGenPoints` to have been run first, since it sets
        `xp`/`yp`/`zp`.
        """
        from PYME.simulation import locify
        from PYME.IO import image
        from PYME.IO import tabular
        from PYME.IO.image import ImageBounds
        import matplotlib.pyplot as plt

        pipeline = self.visFr.pipeline
        pipeline.filename = 'Simulation'

        # Diagnostic plot of the marker positions (and the chain itself for
        # wormlike sources).
        plt.figure()
        plt.plot(self.xp, self.yp, 'x')
        if isinstance(self.source, WormlikeSource):
            plt.plot(self.xp, self.yp, lw=2)

        # STORM and PAINT use different event-time models; both simulators
        # take the same arguments.
        if self.mode == 'STORM':
            eventify = locify.eventify
        else:
            eventify = locify.eventify2
        res = eventify(self.xp,
                       self.yp,
                       self.meanIntensity,
                       self.meanDuration,
                       self.backgroundIntensity,
                       self.meanEventNumber,
                       self.scaleFactor,
                       self.meanTime,
                       z=self.zp)

        plt.plot(res['fitResults']['x0'], res['fitResults']['y0'], '+')

        ds = tabular.MappingFilter(tabular.FitResultsSource(res))

        if isinstance(self.source, ImageSource):
            pipeline.imageBounds = image.openImages[
                self.source.image].imgBounds
        else:
            pipeline.imageBounds = ImageBounds.estimateFromSource(ds)

        pipeline.addDataSource('Generated Points', ds)
        pipeline.selectDataSource('Generated Points')

        from PYME.IO.MetaDataHandler import NestedClassMDHandler
        pipeline.mdh = NestedClassMDHandler()
        pipeline.mdh['Camera.ElectronsPerCount'] = 1
        pipeline.mdh['Camera.TrueEMGain'] = 1
        pipeline.mdh['Camera.CycleTime'] = 1
        pipeline.mdh['voxelsize.x'] = .110

        try:
            pipeline.filterKeys.pop('sig')
        except (KeyError, AttributeError):
            # Best effort: there may be no 'sig' filter to remove.
            pass

        pipeline.Rebuild()
        if len(self.visFr.layers) < 1:
            #TODO - move this logic so that layer added automatically when datasource is added?
            self.visFr.add_pointcloud_layer()
        self.visFr.SetFit()
Example #23
0
class MainWindow(SplitApplicationWindow):
    """ The main application window.

    The left/top pane is a folder-only file tree rooted at the current
    working directory. The right/bottom pane is split horizontally into a
    file table and an interactive Python shell. Selecting an item in the
    tree makes the first selected item the file table's input.
    """

    #### 'SplitApplicationWindow' interface ###################################

    # The ratio of the size of the left/top pane to the right/bottom pane.
    ratio = Float(0.3)

    # The direction in which the panel is split.
    direction = Str('vertical')

    ###########################################################################
    # 'object' interface.
    ###########################################################################

    def __init__(self, **traits):
        """ Creates a new window.

        Keyword arguments are forwarded unchanged to the base class
        constructor; afterwards the menu, tool and status bars are built.
        """

        super(MainWindow, self).__init__(**traits)

        # Create the window's menu, tool and status bars.
        self._create_action_bars()

    ###########################################################################
    # Protected 'SplitApplicationWindow' interface.
    ###########################################################################

    def _create_lhs(self, parent):
        """ Creates the left hand side or top depending on the style.

        Returns the control of a folder tree rooted at the absolute path of
        the current working directory.
        """

        return self._create_file_tree(parent, os.path.abspath(os.curdir))

    def _create_rhs(self, parent):
        """ Creates the right hand side or bottom pane.

        The pane is a horizontal split of the file table and the Python
        shell. The split panel is kept on ``self._rhs`` and its control is
        returned.
        """

        self._rhs = SplitPanel(parent=parent,
                               lhs=self._create_file_table,
                               rhs=self._create_python_shell,
                               direction='horizontal')

        return self._rhs.control

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _create_action_bars(self):
        """ Creates the window's menu, tool and status bars.

        The four radio actions (Highest/Higher/Lower/Lowest) are shared by
        the File menu and the tool bar, and are also stored on
        ``self._actions`` so the Python shell can expose them.
        """

        # Common actions.
        highest = Action(name='Highest', style='radio')
        higher = Action(name='Higher', style='radio', checked=True)
        lower = Action(name='Lower', style='radio')
        lowest = Action(name='Lowest', style='radio')

        self._actions = [highest, higher, lower, lowest]

        # Menu bar.
        self.menu_bar_manager = MenuBarManager(
            MenuManager(
                ExampleAction(name='Foogle'),
                Separator(),
                highest,
                higher,
                lower,
                lowest,
                Separator(),
                Action(name='E&xit', on_perform=self.close),
                name='&File',
            ))

        # Tool bar.
        self.tool_bar_manager = ToolBarManager(ExampleAction(name='Foo'),
                                               Separator(),
                                               ExampleAction(name='Bar'),
                                               Separator(),
                                               ExampleAction(name='Baz'),
                                               Separator(), highest, higher,
                                               lower, lowest)

        # Status bar.
        self.status_bar_manager = StatusBarManager()

    def _create_file_tree(self, parent, dirname):
        """ Creates the file tree.

        Args:
            parent: the parent widget of the tree.
            dirname: the directory used as the tree's input (its root).

        Returns:
            The tree viewer's control. The viewer itself is stored on
            ``self._tree_viewer``.
        """

        # Use the caller-supplied directory rather than hard-coding the
        # current working directory, so that the parameter is honoured.
        self._tree_viewer = tree_viewer = FileTreeViewer(
            parent,
            input=dirname,
            filters=[AllowOnlyFolders()])

        tree_viewer.on_trait_change(self._on_selection_changed, 'selection')

        return tree_viewer.control

    def _create_file_table(self, parent):
        """ Creates the file table.

        The viewer is stored on ``self._table_viewer`` and its control is
        returned.
        """

        self._table_viewer = table_viewer = FileTableViewer(
            parent, sorter=FileSorter(), odd_row_background="white")

        return table_viewer.control

    def _create_python_shell(self, parent):
        """ Creates the Python shell.

        The shell exposes the tree viewer (as ``widget`` and ``w``), this
        window (as ``window``) and the shared radio actions (as
        ``actions``).
        """

        # NOTE(review): relies on ``self._tree_viewer`` already existing,
        # i.e. on the left-hand pane being created before this one — confirm
        # against SplitApplicationWindow's creation order.
        self._python_shell = python_shell = PythonShell(parent)
        python_shell.bind('widget', self._tree_viewer)
        python_shell.bind('w', self._tree_viewer)
        python_shell.bind('window', self)
        python_shell.bind('actions', self._actions)

        return python_shell.control

    #### Trait event handlers #################################################

    def _on_selection_changed(self, selection):
        """ Called when the selection in the tree is changed.

        When anything is selected, the first selected item becomes the file
        table's input. An empty selection leaves the table unchanged.
        """

        if selection:
            self._table_viewer.input = selection[0]
Example #24
0
class PointDisplaySettings(HasTraits):
    """Trait-based container for point display settings.

    Holds a point size, the key used to colour points, and an alpha value.
    NOTE(review): the code that consumes these settings is not visible
    here.
    """
    # Point size; defaults to 5.0. NOTE(review): units not established here
    # (presumably pixels) — confirm.
    pointSize = Float(5.0)
    # Key used to colour points, coerced to a string (CStr); defaults to 't'.
    # NOTE(review): presumably a column name of the displayed data — confirm.
    colourDataKey = CStr('t')
    # Alpha (opacity) value; defaults to 1.0. NOTE(review): presumably on a
    # 0-1 scale — confirm.
    alpha = Float(1.0)
Example #25
0
class Parent(HasTraits):
    """HasTraits class holding a parent's name fields and a child allowance."""
    # String trait.
    first_name = Str
    # NOTE(review): unlike its siblings this is assigned a plain string, not
    # a trait type such as Str — confirm whether an ordinary class attribute
    # is intended here.
    family_name = ''
    # String trait.
    favorite_first_name = Str
    # Float trait defaulting to 1.00.
    child_allowance = Float(1.00)
Example #26
0
class Path(Artist):
    """
    An interface class between the higher level artists and the path
    primitive that needs to talk to the renderers.

    ``pathdata`` is a ``(codes, xy)`` pair. Whenever it or ``model``
    changes, ``xy`` is passed through ``model`` (when set) and the result
    is pushed to the backend path primitive.
    """
    _path = traits.Instance(PathPrimitive, ())
    antialiased = mtraits.AntiAliased()
    color = mtraits.Color('blue')
    facecolor = mtraits.Color('yellow')
    linestyle = mtraits.LineStyle('-')
    linewidth = mtraits.LineWidth(1.0)
    model = mtraits.Model
    pathdata = traits.Tuple(Array('b'), VertexArray)
    sequence = 'paths'
    zorder = Float(1.0)

    # Why have an extra layer separating the PathPrimitive from the
    # Path artist?  The reasons are severalfold, but it is still not
    # clear if this is the better solution.  Doing it this way enables
    # the backends to create their own derived primitives (e.g.
    # RendererAgg creates PathPrimitiveAgg, and in that class sets up
    # trait listeners to create agg colors and agg paths when the
    # PathPrimitive traits change).  Another reason is that it allows
    # us to handle nonlinear transformation (the "model") at the top
    # layer w/o making the backends understand them.  The current
    # design is to create a mapping between backend primitives and
    # primitive artists (Path, Text, Image, etc...) and all of the
    # higher level Artists (Line, Polygon, Axis) will use the
    # primitive artists. So only a few artists will need to know how
    # to talk to the backend.  The alternative is to make the backends
    # track and understand the primitive artists themselves.

    def __init__(self):
        """
        The model is a function taking Nx2->Nx2.  This is where the
        nonlinear transformation can be used.
        """
        Artist.__init__(self)
        # Key under which the primitive is registered in renderer.pathd.
        self._pathid = primitiveID()

    def _pathdata_default(self):
        """Return an empty-ish default: two uint8 codes and a 2x2 zero array."""
        # float64 is what the former ``npy.float_`` alias meant; the alias
        # was removed in NumPy 2.0.
        return (npy.array([0, 0], dtype=npy.uint8),
                npy.array([[0, 0], [0, 0]], npy.float64))

    def _update_path(self):
        'sync the Path traits with the path primitive'
        # One-way syncs: changes on this artist propagate to the primitive.
        self.sync_trait('linewidth', self._path, mutual=False)
        self.sync_trait('color', self._path, mutual=False)
        self.sync_trait('facecolor', self._path, mutual=False)
        self.sync_trait('antialiased', self._path, mutual=False)

        # sync up the path affine
        self._path.affine.follow(self.affine.vec6)
        self.affine.on_trait_change(self._path.affine.follow, 'vec6')
        self._update_pathdata()

    def _update_pathdata(self):
        """Push ``pathdata`` to the primitive, applying ``model`` to xy if set."""
        codes, xy = self.pathdata

        if self.model is not None:
            xy = self.model(xy)

        self._path.pathdata = codes, xy

    def draw(self):
        """Render this path, unless there is no renderer or it is invisible."""
        if self.renderer is None or not self.visible:
            return
        Artist.draw(self)
        self.renderer.render_path(self._pathid)

    def _renderer_changed(self, old, new):
        """Move the path primitive registration from ``old`` to ``new``.

        The entry is removed from the old renderer's ``pathd``. When a new
        renderer is set, a fresh backend-specific primitive is obtained from
        it, registered, and synced with this artist's traits.
        """
        if old is not None:
            del old.pathd[self._pathid]

        if new is None:
            return

        # Ask the newly attached renderer for its primitive; the bare name
        # ``renderer`` is not defined in this scope.
        self._path = new.new_path_primitive()
        new.pathd[self._pathid] = self._path
        self._update_path()

    def _model_changed(self, old, new):
        """Re-apply the model to the current pathdata."""
        self._update_pathdata()

    def _pathdata_changed(self, old, new):
        """Push the new pathdata to the primitive."""
        self._update_pathdata()
Example #27
0
class Edge(HasTraits):
    """ Defines a graph edge. """

    #--------------------------------------------------------------------------
    #  Trait definitions:
    #--------------------------------------------------------------------------

    # Tail/from/source/start node.
    tail_node = Instance(Node, allow_none=False)

    # Head/to/target/end node.
    head_node = Instance(Node, allow_none=False)

    # String identifier (TreeNode label).
    name = Property(
        Str,
        depends_on=["tail_node", "tail_node.ID", "head_node", "head_node.ID"])

    # Connection string used in string output.
    conn = Enum("->", "--")

    # Nodes from which the tail and head nodes may be selected.
    _nodes = List(Instance(Node))  # GUI specific.

    #--------------------------------------------------------------------------
    #  Xdot trait definitions:
    #--------------------------------------------------------------------------

    # For a given graph object, one will typically a draw directive before the
    # label directive. For example, for a node, one would first use the
    # commands in _draw_ followed by the commands in _ldraw_.
    _draw_ = Str(desc="xdot drawing directive", label="draw")
    _ldraw_ = Str(desc="xdot label drawing directive", label="ldraw")

    _hdraw_ = Str(desc="edge head arrowhead drawing directive.", label="hdraw")
    _tdraw_ = Str(desc="edge tail arrowhead drawing directive.", label="tdraw")
    _hldraw_ = Str(desc="edge head label drawing directive.", label="hldraw")
    _tldraw_ = Str(desc="edge tail label drawing directive.", label="tldraw")

    #--------------------------------------------------------------------------
    #  Enable trait definitions:
    #--------------------------------------------------------------------------

    # Container of drawing components, typically the edge spline.
    drawing = Instance(Container)

    # Container of label components.
    label_drawing = Instance(Container)

    # Container of head arrow components.
    arrowhead_drawing = Instance(Container)

    # Container of tail arrow components.
    arrowtail_drawing = Instance(Container)

    # Container of head arrow label components.
    arrowhead_label_drawing = Instance(Container)

    # Container of tail arrow label components.
    arrowtail_label_drawing = Instance(Container)

    # Container for the drawing, label, arrow and arrow label components.
    component = Instance(Container, desc="container of graph components.")

    # A view into a sub-region of the canvas.
    vp = Instance(Viewport, desc="a view of a sub-region of the canvas")

    # Use Graphviz to arrange all graph components.
    arrange = Button("Arrange All")

    #--------------------------------------------------------------------------
    #  Dot trait definitions:
    #--------------------------------------------------------------------------

    # Style of arrowhead on the head node of an edge.
    # See also the <html:a rel="attr">dir</html:a> attribute,
    # and the <html:a rel="note">undirected</html:a> note.
    arrowhead = arrow_trait

    # Multiplicative scale factor for arrowheads.
    arrowsize = Float(1.0,
                      desc="multiplicative scale factor for arrowheads",
                      label="Arrow size",
                      graphviz=True)

    # Style of arrowhead on the tail node of an edge.
    # See also the <html:a rel="attr">dir</html:a> attribute,
    # and the <html:a rel="note">undirected</html:a> note.
    arrowtail = arrow_trait

    # Basic drawing color for graphics, not text. For the latter, use the
    # <html:a rel="attr">fontcolor</html:a> attribute.
    #
    # For edges, the value
    # can either be a single <html:a rel="type">color</html:a> or a <html:a rel="type">colorList</html:a>.
    # In the latter case, the edge is drawn using parallel splines or lines,
    # one for each color in the list, in the order given.
    # The head arrow, if any, is drawn using the first color in the list,
    # and the tail arrow, if any, the second color. This supports the common
    # case of drawing opposing edges, but using parallel splines instead of
    # separately routed multiedges.
    color = color_trait

    # This attribute specifies a color scheme namespace. If defined, it specifies
    # the context for interpreting color names. In particular, if a
    # <html:a rel="type">color</html:a> value has form <html:code>xxx</html:code> or <html:code>//xxx</html:code>,
    # then the color <html:code>xxx</html:code> will be evaluated according to the current color scheme.
    # If no color scheme is set, the standard X11 naming is used.
    # For example, if <html:code>colorscheme=bugn9</html:code>, then <html:code>color=7</html:code>
    # is interpreted as <html:code>/bugn9/7</html:code>.
    colorscheme = color_scheme_trait

    # Comments are inserted into output. Device-dependent.
    comment = comment_trait

    # If <html:span class="val">false</html:span>, the edge is not used in
    # ranking the nodes.
    constraint = Bool(True,
                      desc="if edge is used in ranking the nodes",
                      graphviz=True)

    # If <html:span class="val">true</html:span>, attach edge label to edge by a 2-segment
    # polyline, underlining the label, then going to the closest point of spline.
    decorate = Bool(
        False,
        desc="to attach edge label to edge by a 2-segment "
        "polyline, underlining the label, then going to the closest point of "
        "spline",
        graphviz=True)

    # Set edge type for drawing arrowheads. This indicates which ends of the
    # edge should be decorated with an arrowhead. The actual style of the
    # arrowhead can be specified using the <html:a rel="attr">arrowhead</html:a>
    # and <html:a rel="attr">arrowtail</html:a> attributes.
    # See <html:a rel="note">undirected</html:a>.
    dir = Enum("forward",
               "back",
               "both",
               "none",
               label="Direction",
               desc="edge type for drawing arrowheads",
               graphviz=True)

    # Synonym for <html:a rel="attr">edgeURL</html:a>.
    #    edgehref = Alias("edgeURL", desc="synonym for edgeURL")
    edgehref = Synced(sync_to="edgeURL", graphviz=True)

    # If the edge has a URL or edgeURL  attribute, this attribute determines
    # which window of the browser is used for the URL attached to the non-label
    # part of the edge. Setting it to "_graphviz" will open a new window if it
    # doesn't already exist, or reuse it if it does. If undefined, the value of
    # the target is used.
    edgetarget = Str("",
                     desc="which window of the browser is used for the "
                     "URL attached to the non-label part of the edge",
                     label="Edge target",
                     graphviz=True)

    # Tooltip annotation attached to the non-label part of an edge.
    # This is used only if the edge has a <html:a rel="attr">URL</html:a>
    # or <html:a rel="attr">edgeURL</html:a> attribute.
    edgetooltip = Str("",
                      desc="annotation attached to the non-label part of "
                      "an edge",
                      label="Edge tooltip",
                      graphviz=True)
    #    edgetooltip = EscString

    # If <html:a rel="attr">edgeURL</html:a> is defined, this is the link used for the non-label
    # parts of an edge. This value overrides any <html:a rel="attr">URL</html:a>
    # defined for the edge.
    # Also, this value is used near the head or tail node unless overridden
    # by a <html:a rel="attr">headURL</html:a> or <html:a rel="attr">tailURL</html:a> value,
    # respectively.
    # See <html:a rel="note">undirected</html:a>.
    edgeURL = Str("",
                  desc="link used for the non-label parts of an edge",
                  label="Edge URL",
                  graphviz=True)  #LabelStr

    # Color used for text.
    fontcolor = fontcolor_trait

    # Font used for text. This very much depends on the output format and, for
    # non-bitmap output such as PostScript or SVG, the availability of the font
    # when the graph is displayed or printed. As such, it is best to rely on
    # font faces that are generally available, such as Times-Roman, Helvetica or
    # Courier.
    #
    # If Graphviz was built using the
    # <html:a href="http://pdx.freedesktop.org/~fontconfig/fontconfig-user.html">fontconfig library</html:a>, the latter library
    # will be used to search for the font. However, if the <html:a rel="attr">fontname</html:a> string
    # contains a slash character "/", it is treated as a pathname for the font
    # file, though font lookup will append the usual font suffixes.
    #
    # If Graphviz does not use fontconfig, <html:a rel="attr">fontname</html:a> will be
    # considered the name of a Type 1 or True Type font file.
    # If you specify <html:code>fontname=schlbk</html:code>, the tool will look for a
    # file named  <html:code>schlbk.ttf</html:code> or <html:code>schlbk.pfa</html:code> or <html:code>schlbk.pfb</html:code>
    # in one of the directories specified by
    # the <html:a rel="attr">fontpath</html:a> attribute.
    # The lookup does support various aliases for the common fonts.
    fontname = fontname_trait

    # Font size, in <html:a rel="note">points</html:a>, used for text.
    fontsize = fontsize_trait

    # If <html:span class="val">true</html:span>, the head of an edge is clipped to the boundary of the head node;
    # otherwise, the end of the edge goes to the center of the node, or the
    # center of a port, if applicable.
    headclip = Bool(True,
                    desc="head of an edge to be clipped to the boundary "
                    "of the head node",
                    label="Head clip",
                    graphviz=True)

    # Synonym for <html:a rel="attr">headURL</html:a>.
    headhref = Alias("headURL", desc="synonym for headURL", graphviz=True)

    # Text label to be placed near head of edge.
    # See <html:a rel="note">undirected</html:a>.
    headlabel = Str("",
                    desc="text label to be placed near head of edge",
                    label="Head label",
                    graphviz=True)

    headport = port_pos_trait

    # If the edge has a headURL, this attribute determines which window of the
    # browser is used for the URL. Setting it to "_graphviz" will open a new
    # window if it doesn't already exist, or reuse it if it does. If undefined,
    # the value of the target is used.
    headtarget = Str(desc="which window of the browser is used for the URL",
                     label="Head target",
                     graphviz=True)

    # Tooltip annotation attached to the head of an edge. This is used only
    # if the edge has a <html:a rel="attr">headURL</html:a> attribute.
    headtooltip = Str("",
                      desc="tooltip annotation attached to the head of an "
                      "edge",
                      label="Head tooltip",
                      graphviz=True)

    # If <html:a rel="attr">headURL</html:a> is defined, it is
    # output as part of the head label of the edge.
    # Also, this value is used near the head node, overriding any
    # <html:a rel="attr">URL</html:a> value.
    # See <html:a rel="note">undirected</html:a>.
    headURL = Str("",
                  desc="output as part of the head label of the edge",
                  label="Head URL",
                  graphviz=True)

    # Synonym for <html:a rel="attr">URL</html:a>.
    href = Alias("URL", desc="synonym for URL", graphviz=True)

    # Text label attached to objects.
    # If a node's <html:a rel="attr">shape</html:a> is record, then the label can
    # have a <html:a href="http://www.graphviz.org/doc/info/shapes.html#record">special format</html:a>
    # which describes the record layout.
    label = label_trait

    # This, along with <html:a rel="attr">labeldistance</html:a>, determine
    # where the
    # headlabel (taillabel) are placed with respect to the head (tail)
    # in polar coordinates. The origin in the coordinate system is
    # the point where the edge touches the node. The ray of 0 degrees
    # goes from the origin back along the edge, parallel to the edge
    # at the origin.
    #
    # The angle, in degrees, specifies the rotation from the 0 degree ray,
    # with positive angles moving counterclockwise and negative angles
    # moving clockwise.
    labelangle = Float(
        -25.0,
        desc=", along with labeldistance, where the "
        "headlabel (taillabel) are placed with respect to the head (tail)",
        label="Label angle",
        graphviz=True)

    # Multiplicative scaling factor adjusting the distance that
    # the headlabel (taillabel) is from the head (tail) node.
    # The default distance is 10 points. See <html:a rel="attr">labelangle</html:a>
    # for more details.
    labeldistance = Float(
        1.0,
        desc="multiplicative scaling factor adjusting "
        "the distance that the headlabel (taillabel) is from the head (tail) "
        "node",
        label="Label distance",
        graphviz=True)

    # If true, allows edge labels to be less constrained in position. In
    # particular, it may appear on top of other edges.
    labelfloat = Bool(False,
                      desc="edge labels to be less constrained in "
                      "position",
                      label="Label float",
                      graphviz=True)

    # Color used for headlabel and taillabel.
    # If not set, defaults to edge's fontcolor.
    labelfontcolor = Color("black",
                           desc="color used for headlabel and "
                           "taillabel",
                           label="Label font color",
                           graphviz=True)

    # Font used for headlabel and taillabel.
    # If not set, defaults to edge's fontname.
    labelfontname = Font("Times-Roman",
                         desc="Font used for headlabel and "
                         "taillabel",
                         label="Label font name",
                         graphviz=True)

    # Font size, in <html:a rel="note">points</html:a>, used for headlabel and taillabel.
    # If not set, defaults to edge's fontsize.
    labelfontsize = Float(14.0,
                          desc="Font size, in points, used for "
                          "headlabel and taillabel",
                          label="label_font_size",
                          graphviz=True)

    # Synonym for <html:a rel="attr">labelURL</html:a>.
    labelhref = Alias("labelURL", desc="synonym for labelURL", graphviz=True)

    # If the edge has a URL or labelURL  attribute, this attribute determines
    # which window of the browser is used for the URL attached to the label.
    # Setting it to "_graphviz" will open a new window if it doesn't already
    # exist, or reuse it if it does. If undefined, the value of the target is
    # used.
    labeltarget = Str("",
                      desc="which window of the browser is used for the "
                      "URL attached to the label",
                      label="Label target",
                      graphviz=True)

    # Tooltip annotation attached to label of an edge.
    # This is used only if the edge has a <html:a rel="attr">URL</html:a>
    # or <html:a rel="attr">labelURL</html:a> attribute.
    labeltooltip = Str("",
                       desc="tooltip annotation attached to label of an "
                       "edge",
                       label="Label tooltip",
                       graphviz=True)

    # If <html:a rel="attr">labelURL</html:a> is defined, this is the link used for the label
    # of an edge. This value overrides any <html:a rel="attr">URL</html:a>
    # defined for the edge.
    labelURL = Str(desc="link used for the label of an edge", graphviz=True)

    # Specifies layers in which the node or edge is present.
    layer = layer_trait

    # Preferred edge length, in inches.
    len = Float(1.0, desc="preferred edge length, in inches",
                graphviz=True)  #0.3(fdp)

    # Logical head of an edge. When compound is true, if lhead is defined and
    # is the name of a cluster containing the real head, the edge is clipped to
    # the boundary of the cluster.
    lhead = Str(desc="Logical head of an edge", graphviz=True)

    # Label position, in points. The position indicates the center of the label.
    lp = point_trait

    # Logical tail of an edge. When compound is true, if ltail is defined and
    # is the name of a cluster containing the real tail, the edge is clipped to
    # the boundary of the cluster.
    ltail = Str(desc="logical tail of an edge", graphviz=True)

    # Minimum edge length (rank difference between head and tail).
    minlen = Int(1, desc="minimum edge length", graphviz=True)

    # By default, the justification of multi-line labels is done within the
    # largest context that makes sense. Thus, in the label of a polygonal node,
    # a left-justified line will align with the left side of the node (shifted
    # by the prescribed margin). In record nodes, left-justified line will line
    # up with the left side of the enclosing column of fields. If nojustify is
    # "true", multi-line labels will be justified in the context of itself. For
    # example, if the attribute is set, the first label line is long, and the
    # second is shorter and left-justified, the second will align with the
    # left-most character in the first line, regardless of how large the node
    # might be.
    nojustify = nojustify_trait

    # Position of node, or spline control points.
    # For nodes, the position indicates the center of the node.
    # On output, the coordinates are in <html:a href="#points">points</html:a>.
    #
    # In neato and fdp, pos can be used to set the initial position of a node.
    # By default, the coordinates are assumed to be in inches. However, the
    # <html:a href="http://www.graphviz.org/doc/info/command.html#d:s">-s</html:a> command line flag can be used to specify
    # different units.
    #
    # When the <html:a href="http://www.graphviz.org/doc/info/command.html#d:n">-n</html:a> command line flag is used with
    # neato, it is assumed the positions have been set by one of the layout
    # programs, and are therefore in points. Thus, <html:code>neato -n</html:code> can accept
    # input correctly without requiring a <html:code>-s</html:code> flag and, in fact,
    # ignores any such flag.
    pos = List(Tuple(Float, Float), desc="spline control points")

    # Edges with the same head and the same <html:a rel="attr">samehead</html:a> value are aimed
    # at the same point on the head.
    # See <html:a rel="note">undirected</html:a>.
    samehead = Str("",
                   desc="dges with the same head and the same samehead "
                   "value are aimed at the same point on the head",
                   graphviz=True)

    # Edges with the same tail and the same <html:a rel="attr">sametail</html:a> value are aimed
    # at the same point on the tail.
    # See <html:a rel="note">undirected</html:a>.
    sametail = Str("",
                   desc="edges with the same tail and the same sametail "
                   "value are aimed at the same point on the tail",
                   graphviz=True)

    # Print guide boxes in PostScript at the beginning of
    # routesplines if 1, or at the end if 2. (Debugging)
    showboxes = showboxes_trait

    # Set style for node or edge. For cluster subgraph, if "filled", the
    # cluster box's background is filled.
    style = ListStr(desc="style for node or edge", graphviz=True)

    # If <html:span class="val">true</html:span>, the tail of an edge is clipped to the boundary of the tail node;
    # otherwise, the end of the edge goes to the center of the node, or the
    # center of a port, if applicable.
    tailclip = Bool(True,
                    desc="tail of an edge to be clipped to the boundary "
                    "of the tail node",
                    graphviz=True)

    # Synonym for <html:a rel="attr">tailURL</html:a>.
    tailhref = Alias("tailURL", desc="synonym for tailURL", graphviz=True)

    # Text label to be placed near tail of edge.
    # See <html:a rel="note">undirected</html:a>.
    taillabel = Str(desc="text label to be placed near tail of edge",
                    graphviz=True)

    # Indicates where on the tail node to attach the tail of the edge.
    tailport = port_pos_trait

    # If the edge has a tailURL, this attribute determines which window of the
    # browser is used for the URL. Setting it to "_graphviz" will open a new
    # window if it doesn't already exist, or reuse it if it does. If undefined,
    # the value of the target is used.
    tailtarget = Str(desc="which window of the browser is used for the URL",
                     graphviz=True)

    # Tooltip annotation attached to the tail of an edge. This is used only
    # if the edge has a <html:a rel="attr">tailURL</html:a> attribute.
    tailtooltip = Str("",
                      desc="tooltip annotation attached to the tail of an "
                      "edge",
                      label="Tail tooltip",
                      graphviz=True)

    # If <html:a rel="attr">tailURL</html:a> is defined, it is
    # output as part of the tail label of the edge.
    # Also, this value is used near the tail node, overriding any
    # <html:a rel="attr">URL</html:a> value.
    # See <html:a rel="note">undirected</html:a>.
    tailURL = Str("",
                  desc="output as part of the tail label of the edge",
                  label="Tail URL",
                  graphviz=True)

    # If the object has a URL, this attribute determines which window
    # of the browser is used for the URL.
    # See <html:a href="http://www.w3.org/TR/html401/present/frames.html#adef-target">W3C documentation</html:a>.
    target = target_trait

    # Tooltip annotation attached to the node or edge. If unset, Graphviz
    # will use the object's <html:a rel="attr">label</html:a> if defined.
    # Note that if the label is a record specification or an HTML-like
    # label, the resulting tooltip may be unhelpful. In this case, if
    # tooltips will be generated, the user should set a <html:tt>tooltip</html:tt>
    # attribute explicitly.
    tooltip = tooltip_trait

    # Hyperlinks incorporated into device-dependent output.
    # At present, used in ps2, cmap, i*map and svg formats.
    # For all these formats, URLs can be attached to nodes, edges and
    # clusters. URL attributes can also be attached to the root graph in ps2,
    # cmap and i*map formats. This serves as the base URL for relative URLs in the
    # former, and as the default image map file in the latter.
    #
    # For svg, cmapx and imap output, the active area for a node is its
    # visible image.
    # For example, an unfilled node with no drawn boundary will only be active on its label.
    # For other output, the active area is its bounding box.
    # The active area for a cluster is its bounding box.
    # For edges, the active areas are small circles where the edge contacts its head
    # and tail nodes. In addition, for svg, cmapx and imap, the active area
    # includes a thin polygon approximating the edge. The circles may
    # overlap the related node, and the edge URL dominates.
    # If the edge has a label, this will also be active.
    # Finally, if the edge has a head or tail label, this will also be active.
    #
    # Note that, for edges, the attributes <html:a rel="attr">headURL</html:a>,
    # <html:a rel="attr">tailURL</html:a>, <html:a rel="attr">labelURL</html:a> and
    # <html:a rel="attr">edgeURL</html:a> allow control of various parts of an
    # edge. Also note that, if active areas of two edges overlap, it is unspecified
    # which area dominates.
    URL = url_trait

    # Weight of edge. In dot, the heavier the weight, the shorter, straighter
    # and more vertical the edge is.
    weight = Float(1.0, desc="weight of edge", graphviz=True)

    #--------------------------------------------------------------------------
    #  Views:
    #--------------------------------------------------------------------------

    traits_view = View(VGroup(
        Group(
            Item(name="vp",
                 editor=ComponentEditor(height=100),
                 show_label=False,
                 id=".component"), Item("arrange", show_label=False)),
        Tabbed(
            Group(Item(name="tail_node",
                       editor=InstanceEditor(name="_nodes", editable=False)),
                  Item(name="head_node",
                       editor=InstanceEditor(name="_nodes", editable=False)), [
                           "style", "layer", "color", "colorscheme", "dir",
                           "arrowsize", "constraint", "decorate", "showboxes",
                           "tooltip", "edgetooltip", "edgetarget", "target",
                           "comment"
                       ],
                  label="Edge"),
            Group([
                "label", "fontname", "fontsize", "fontcolor", "nojustify",
                "labeltarget", "labelfloat", "labelfontsize", "labeltooltip",
                "labelangle", "lp", "labelURL", "labelfontname",
                "labeldistance", "labelfontcolor", "labelhref"
            ],
                  label="Label"),
            Group(["minlen", "weight", "len", "pos"], label="Dimension"),
            Group([
                "arrowhead", "samehead", "headURL", "headtooltip", "headclip",
                "headport", "headlabel", "headtarget", "lhead", "headhref"
            ],
                  label="Head"),
            Group([
                "arrowtail", "tailtarget", "tailhref", "ltail", "sametail",
                "tailport", "taillabel", "tailtooltip", "tailURL", "tailclip"
            ],
                  label="Tail"),
            Group(["URL", "href", "edgeURL", "edgehref"], label="URL"),
            Group([
                "_draw_", "_ldraw_", "_hdraw_", "_tdraw_", "_hldraw_",
                "_tldraw_"
            ],
                  label="Xdot"),
            dock="tab"),
        layout="split",
        id=".splitter"),
                       title="Edge",
                       id="godot.edge",
                       buttons=["OK", "Cancel", "Help"],
                       resizable=True)

    #--------------------------------------------------------------------------
    #  "object" interface:
    #--------------------------------------------------------------------------

    def __init__(self,
                 tailnode_or_ID,
                 headnode_or_ID,
                 directed=False,
                 **traits):
        """ Initialises a new Edge instance.

        Each endpoint may be passed either as a Node, which is used as-is,
        or as any other value, which is converted with str() and used as
        the ID of a newly created Node.  The connector is "->" when
        'directed' is true and "--" otherwise.  Any remaining keyword
        arguments are forwarded to the base class initialiser.
        """
        def as_node(node_or_id):
            # Pass Node instances through; wrap anything else in a new Node.
            if isinstance(node_or_id, Node):
                return node_or_id
            return Node(str(node_or_id))

        tail, head = as_node(tailnode_or_ID), as_node(headnode_or_ID)
        self.tail_node, self.head_node = tail, head

        self.conn = "->" if directed else "--"

        super(Edge, self).__init__(**traits)

    def __str__(self):
        """ Returns a string representation of the edge.

        The result is a DOT edge statement of the form
        ``<tail ID><tailport> <conn> <head ID><headport> [attr=value, ...];``.
        Only traits carrying 'graphviz' metadata whose current value differs
        from their default are written as attributes.  String values are
        double-quoted.  Any double quote in the value that is not already
        preceded by a backslash is escaped as \\", so the statement remains
        valid DOT.
        """
        # Local import: only needed here, keeps the module imports unchanged.
        import re

        attrs = []
        # Traits to be included in string output have 'graphviz' metadata.
        for trait_name, trait in self.traits(graphviz=True).iteritems():
            # Get the value of the trait for comparison with the default.
            value = getattr(self, trait_name)

            # Only print attribute value pairs if not defaulted.
            # FIXME: Alias/Synced traits default to None.
            if (value != trait.default) and (trait.default is not None):
                if isinstance(value, basestring):
                    # Inside DOT quoted strings the only escape is \" ; a
                    # bare quote would terminate the string early.
                    # Quotes the caller already escaped are left alone.
                    escaped = re.sub(r'(?<!\\)"', r'\\"', value)
                    valstr = '"%s"' % escaped
                else:
                    valstr = str(value)

                attrs.append('%s=%s' % (trait_name, valstr))

        if attrs:
            attrstr = " [%s]" % ", ".join(attrs)
        else:
            attrstr = ""

        return "%s%s %s %s%s%s;" % (self.tail_node.ID, self.tailport,
                                    self.conn, self.head_node.ID,
                                    self.headport, attrstr)

    #--------------------------------------------------------------------------
    #  Trait initialisers:
    #--------------------------------------------------------------------------

    def _component_default(self):
        """ Trait initialiser.

        Supplies an auto-sizing Container with a green background that
        holds the edge's drawing components.
        """
        return Container(bgcolor="green", auto_size=True)

    def _vp_default(self):
        """ Trait initialiser.

        Builds a zoom-enabled Viewport onto the edge's component and
        attaches a ViewportPanTool to it.
        """
        viewport = Viewport(component=self.component)
        viewport.enable_zoom = True
        pan_tool = ViewportPanTool(viewport)
        viewport.tools.append(pan_tool)
        return viewport

    #--------------------------------------------------------------------------
    #  Property getters:
    #--------------------------------------------------------------------------

    def _get_name(self):
        """ Property getter.

        Returns "<tail ID> <conn> <head ID>" once both endpoints are set,
        otherwise the placeholder "Edge".
        """
        tail, head = self.tail_node, self.head_node
        if tail is None or head is None:
            return "Edge"
        return "%s %s %s" % (tail.ID, self.conn, head.ID)

    #--------------------------------------------------------------------------
    #  Event handlers:
    #--------------------------------------------------------------------------

    @on_trait_change("arrange")
    def arrange_all(self):
        """ Arrange the components of the edge using Graphviz.

        The edge is placed in a temporary directed graph, which is rendered
        to xdot format.  The output is parsed, and the options of every
        "add_edge" entry are applied to this edge with set().

        The temporary graph is directed, so the edge connector has to be
        "->" while the graph is rendered.  The original connector is
        restored afterwards, so arranging an undirected edge does not
        silently turn it into a directed one.
        """
        # FIXME: Circular reference avoidance.
        import godot.dot_data_parser
        import godot.graph

        graph = godot.graph.Graph(ID="g", directed=True)
        original_conn = self.conn
        self.conn = "->"
        try:
            graph.edges.append(self)
            xdot_data = graph.create(format="xdot")
        finally:
            self.conn = original_conn

        parser = godot.dot_data_parser.GodotDataParser()
        # Join lines that Graphviz wrapped with a trailing backslash.
        ndata = xdot_data.replace('\\\n', '')
        tokens = parser.dotparser.parseString(ndata)[0]

        # NOTE(review): tokens[3] is presumably the graph's statement list as
        # produced by GodotDataParser -- confirm against the parser.
        for element in tokens[3]:
            if element[0] == "add_edge":
                _, src, dest, opts = element
                self.set(**opts)

#    @on_trait_change("_draw_,_hdraw_")

    def _parse_xdot_directive(self, name, new):
        """ Handles parsing Xdot drawing directives.

        Parses the xdot drawing operations in 'new' into drawing components.
        Each component is shifted so that the bottom-left corner of the
        whole drawing sits at the origin of a new Container.  That Container
        is placed at the drawing's original offset and stored in 'drawing'
        (for "_draw_") or 'arrowhead_drawing' (for "_hdraw_").

        Raises:
            ValueError: if 'name' is neither "_draw_" nor "_hdraw_".
        """
        if name == "_draw_":
            target = "drawing"
        elif name == "_hdraw_":
            target = "arrowhead_drawing"
        else:
            raise ValueError("Unsupported xdot directive: %r" % (name, ))

        parser = XdotAttrParser()
        components = parser.parse_xdot_data(new)

        # The absolute coordinate of the drawing container wrt graph origin.
        # An empty drawing has no extent, so anchor it at the origin.
        if components:
            x1 = min([c.x for c in components])
            y1 = min([c.y for c in components])
        else:
            x1, y1 = 0, 0

        # Components are positioned relative to their container, so shift
        # each one by (-x1, -y1) to put the drawing's bottom-left corner at
        # the container origin.
        for c in components:
            if isinstance(c, Ellipse):
                # Ellipses are positioned through their origin attributes.
                c.x_origin -= x1
                c.y_origin -= y1

            elif isinstance(c, (Polygon, BSpline)):
                c.points = [(t[0] - x1, t[1] - y1) for t in c.points]

            elif isinstance(c, Text):
                c.text_x, c.text_y = c.x - x1, c.y - y1

        container = Container(auto_size=True,
                              position=[x1, y1],
                              bgcolor="yellow")

        container.add(*components)

        setattr(self, target, container)

    @on_trait_change("drawing,arrowhead_drawing")
    def _on_drawing(self, object, name, old, new):
        """ Handles the containers of drawing components being set.

        Replaces 'old' with 'new' in the edge's component.  A new drawing is
        assumed to be positioned relative to the graph origin.  The
        component is moved to the bottom-left-most absolute position among
        the new drawing and the other existing drawings, and the new drawing
        is re-expressed relative to that position.  When the drawing is
        cleared ('new' is None), the old container is simply removed.
        """
        if new is None:
            # Nothing to position: just drop the previous drawing.
            if old is not None:
                self.component.remove(old)
            self.component.request_redraw()
            return

        attrs = ["drawing", "arrowhead_drawing"]

        others = [getattr(self, a) for a in attrs
                  if (a != name) and (getattr(self, a) is not None)]

        x, y = self.component.position

        # Absolute positions of the other drawings wrt the graph origin.
        abs_x = [d.x + x for d in others]
        abs_y = [d.y + y for d in others]

        # Assume that the new drawing is positioned relative to graph origin.
        x1 = min(abs_x + [new.x])
        y1 = min(abs_y + [new.y])

        # Re-express the new drawing relative to the component's new origin.
        new.position = [new.x - x1, new.y - y1]

        if old is not None:
            self.component.remove(old)
        self.component.add(new)

        # NOTE(review): the other drawings are not shifted to the new
        # component origin, so they move on screen whenever (x1, y1) differs
        # from the previous component position -- confirm this is intended.
        self.component.position = [x1, y1]
        self.component.request_redraw()
Example #28
0
class Marker(Artist):
    """
    An interface class between the higher level artists and the marker
    primitive that needs to talk to the renderers.

    The optional 'model' is a function taking an Nx2 array of locations to
    an Nx2 array; this is where a nonlinear transformation can be applied
    before the locations are handed to the marker primitive.
    """
    _marker = traits.Instance(MarkerPrimitive, ())
    locs = Array('d')
    path = Instance(Path, ())
    model = mtraits.Model
    sequence = 'markers'
    size = Float(1.0)  # size of the marker in points

    def __init__(self):
        """
        Initialise the artist and allocate the unique primitive ID under
        which the marker primitive is registered with a renderer.
        """
        Artist.__init__(self)
        self._markerid = primitiveID()

    def _locs_default(self):
        return npy.array([[0, 1], [0, 1]], npy.float_)

    def _path_default(self):
        # A unit square centred on the origin, scaled by the marker size.
        bounds = npy.array([-0.5, -0.5, 1, 1]) * self.size
        return Rectangle().set(bounds=bounds)

    def _path_changed(self, old, new):
        if self.renderer is None:
            # we can't sync up to the underlying path yet
            return
        # Move the one-way sync from the old path (if any) to the new one.
        if old is not None:
            old.sync_trait('_path', self._marker, 'path', remove=True)
        if new is not None:
            new.sync_trait('_path', self._marker, 'path', mutual=False)

    def _update_marker(self):
        'sync the Marker traits with the marker primitive'
        if self.renderer is None:
            # we can't sync up to the underlying path yet
            return

        self.path.sync_trait('_path', self._marker, 'path', mutual=False)
        # sync up the marker affine, now and on every later vec6 change
        self._marker.affine.follow(self.affine.vec6)
        self.affine.on_trait_change(self._marker.affine.follow, 'vec6')
        self._update_locs()

    def _update_locs(self):
        'push the locations, transformed by model if set, to the primitive'
        xy = self.locs
        if self.model is not None:
            xy = self.model(xy)

        self._marker.locs = xy

    def draw(self):
        if self.renderer is None or not self.visible:
            return
        Artist.draw(self)
        self.renderer.render_marker(self._markerid)

    def _renderer_changed(self, old, new):
        # we must make sure the contained artist gets the callback
        # first so we can update the path primitives properly
        self.path._renderer_changed(old, new)
        if old is not None:
            del old.markerd[self._markerid]

        if new is None:
            return

        # Create the primitive from the new renderer and register it there.
        self._marker = new.new_marker_primitive()
        new.markerd[self._markerid] = self._marker
        self._update_marker()

    def _model_changed(self, old, new):
        self._update_locs()

    def _locs_changed(self, old, new):
        if new.ndim != 2 or new.shape[1] != 2:
            raise ValueError('new must be nx2 array')
        self._update_locs()
Example #29
0
class Variables(HasTraits):
    """ Pool of decoded variables and their recorded history.

    Every call to update_variables() appends one sample to vars_list.  A
    sample is a new dict that combines the previous sample, the freshly
    decoded values and the bookkeeping entries 'sample_num', 'system_time'
    and 'time' (seconds since start_time).  vars_pool always refers to the
    newest sample, and at most max_samples samples are kept.
    """
    # Newest sample.  It is replaced (never mutated in place) by
    # update_vars_list; __init__ gives each instance its own empty dict.
    vars_pool = {}
    vars_list = List()
    # a list version of vars_pool maintained for the TabularEditor
    vars_table_list = List()
    vars_table_list_update_time = Float(0)

    sample_number = Int(0)
    sample_count = Int(0)
    max_samples = Int(20000)

    # Time base for the 'time' variable.  This class-level value is taken at
    # import time; __init__ resets it when an instance is created.
    start_time = time.time()

    add_var_event = Event()

    expressions = List()

    vars_table_update = Bool(True)

    clear_button = Button('Clear')
    view = View(
        HSplit(
            Item(name='clear_button', show_label=False),
            Item(name='max_samples', label='Max samples'),
            Item(name='sample_count', label='Samples'),
            Item(name='vars_table_update', label='Update variables view')),
        Item(name='vars_table_list',
             editor=TabularEditor(adapter=VariableTableAdapter(),
                                  editable=False,
                                  dclicked="add_var_event"),
             resizable=True,
             show_label=False),
        title='Variable view',
        resizable=True,
        width=.7,
        height=.2)

    def __init__(self, *args, **traits):
        """ Initialise the pool with its time base set to 'now'.

        Unless given explicitly, start_time and vars_pool are set per
        instance.  The class-level values are evaluated once at import time,
        and the dict would otherwise be shared by all instances.
        """
        super(Variables, self).__init__(*args, **traits)
        if 'start_time' not in traits:
            self.start_time = time.time()
        if 'vars_pool' not in traits:
            self.vars_pool = {}

    def new_expression(self, expr):
        """ Create an Expression for 'expr', register it and return it. """
        new_expression = Expression(self, expr)
        self.expressions.append(new_expression)
        return new_expression

    def update_variables(self, data_dict):
        """
        Receive a dict of variables from a decoder and integrate them
        into our global variable pool.
        """
        self.sample_number += 1

        # Build a fresh dict rather than updating vars_pool in place: the
        # previous pool object is already stored in vars_list.
        new_vars_pool = {}
        new_vars_pool.update(self.vars_pool)
        new_vars_pool.update(data_dict)
        # Read the clock once so that 'time' == 'system_time' - start_time.
        now = time.time()
        new_vars_pool.update({
            'sample_num': self.sample_number,
            'system_time': now,
            'time': now - self.start_time
        })
        # weed out undesirables (a variable with an empty name)
        new_vars_pool.pop('', None)

        self.vars_list.append(new_vars_pool)
        self.update_vars_list()

    def update_vars_list(self):
        """ Make the newest sample current and enforce max_samples.

        The variables table is refreshed at most once every 0.2 s.
        """
        self.vars_pool = self.vars_list[-1]

        if time.time() - self.vars_table_list_update_time > 0.2:
            self.vars_table_list_update_time = time.time()
            self.update_vars_table()

        self.sample_count = len(self.vars_list)
        if self.sample_count > self.max_samples:
            self.vars_list = self.vars_list[-self.max_samples:]
            self.sample_count = self.max_samples

    @on_trait_change('clear_button')
    def clear(self):
        """ Clear all recorded data. """
        self.sample_number = 0
        self.vars_list = [{}]
        self.update_vars_list()
        self.update_vars_table()
        self.start_time = time.time()

        for expression in self.expressions:
            expression.clear_cache()

    def save_data_set(self, filename):
        """ Pickle the recorded samples (vars_list) to 'filename'. """
        with open(filename, 'wb') as fp:
            pickle.dump(self.vars_list, fp, True)

    def open_data_set(self, filename):
        """ Load samples previously written by save_data_set().

        NOTE: pickle.load can execute arbitrary code; only open trusted
        files.
        """
        with open(filename, 'rb') as fp:
            self.vars_list = pickle.load(fp)

        self.update_vars_list()
        self.update_vars_table()
        self.sample_number = self.sample_count
        # spoof start time so that we start where we left off; a data set
        # saved right after clear() ends with an empty sample with no 'time'.
        self.start_time = time.time() - self.vars_list[-1].get('time', 0)

    def update_vars_table(self):
        """ Rebuild the (name, repr(value)) rows for the TabularEditor.

        Rows are sorted case-insensitively by name.  Nothing is done while
        vars_table_update is False.
        """
        if self.vars_table_update:
            vars_list_unsorted = [
                (name, repr(val))
                for (name, val) in self.vars_pool.iteritems()
            ]
            self.vars_table_list = sorted(vars_list_unsorted,
                                          key=(lambda x: x[0].lower()))

    def test_expr(self, expr):
        """ Return (True, '') if 'expr' evaluates against the current pool,
        otherwise (False, repr(exception)).

        NOTE: uses eval(); only pass trusted expressions.
        """
        is_ok = (True, '')
        try:
            eval(expr, expression_context, self.vars_pool)
        except Exception as e:
            is_ok = (False, repr(e))
        return is_ok

    def _eval_expr(self, expr, vars_pool=None):
        """
        Returns the value of a python expression evaluated with
        the variables in the pool in scope, or None if evaluation
        fails. Used internally by Expression. Users should use
        Expression instead as it has caching etc.
        """
        if vars_pool is None:
            vars_pool = self.vars_pool

        try:
            data = eval(expr, expression_context, vars_pool)
        except Exception:
            data = None
        return data

    def bound_array(self, first, last):
        """ Resolve negative / missing bounds against sample_number.

        A negative 'first' counts back from sample_number and is clamped at
        0; a negative 'last' counts back from sample_number; a 'last' of
        None means sample_number.
        """
        if first < 0:
            first += self.sample_number
            if first < 0:
                first = 0
        if last and last < 0:
            last += self.sample_number
        if last is None:
            last = self.sample_number

        return (first, last)

    def _get_array(self, expr, first=0, last=None):
        """
        Returns a numpy array with the value of the supplied expression
        for each recorded sample in [first, last) (bounds resolved by
        bound_array). Samples where the value is missing or cannot be
        evaluated contribute 0.0. Used internally by Expression, users
        should use Expression directly as it has caching etc.
        """
        first, last = self.bound_array(first, last)
        if expr in self.vars_pool:
            # Plain variable name: look it up directly in each sample.
            data = [vs.get(expr) for vs in self.vars_list[first:last]]
        else:
            data = [
                self._eval_expr(expr, vs) for vs in self.vars_list[first:last]
            ]
        data = [0.0 if d is None else d for d in data]

        data_array = numpy.array(data)
        return data_array
Example #30
0
class TemplatePicker(HasTraits):
    template = Array
    CC = Array
    peaks = List
    zero = Int(0)
    tmp_size = Range(low=2, high=512, value=64, cols=4)
    max_pos_x = Int(1023)
    max_pos_y = Int(1023)
    top = Range(low='zero', high='max_pos_x', value=20, cols=4)
    left = Range(low='zero', high='max_pos_y', value=20, cols=4)
    is_square = Bool
    img_plot = Instance(Plot)
    tmp_plot = Instance(Plot)
    findpeaks = Button
    peak_width = Range(low=2, high=200, value=10)
    tab_selected = Event
    ShowCC = Bool
    img_container = Instance(Component)
    container = Instance(Component)
    colorbar = Instance(Component)
    numpeaks_total = Int(0)
    numpeaks_img = Int(0)
    OK_custom = OK_custom_handler
    cbar_selection = Instance(RangeSelection)
    cbar_selected = Event
    thresh = Trait(None, None, List, Tuple, Array)
    thresh_upper = Float(1.0)
    thresh_lower = Float(0.0)
    numfiles = Int(1)
    img_idx = Int(0)
    tmp_img_idx = Int(0)

    csr = Instance(BaseCursorTool)

    traits_view = View(HFlow(
        VGroup(Item("img_container",
                    editor=ComponentEditor(),
                    show_label=False),
               Group(
                   Spring(),
                   Item("ShowCC",
                        editor=BooleanEditor(),
                        label="Show cross correlation image")),
               label="Original image",
               show_border=True,
               trait_modified="tab_selected"),
        VGroup(
            Group(HGroup(
                Item("left", label="Left coordinate", style="custom"),
                Item("top", label="Top coordinate", style="custom"),
            ),
                  Item("tmp_size", label="Template size", style="custom"),
                  Item("tmp_plot",
                       editor=ComponentEditor(height=256, width=256),
                       show_label=False,
                       resizable=True),
                  label="Template",
                  show_border=True),
            Group(Item("peak_width", label="Peak width", style="custom"),
                  Group(
                      Spring(),
                      Item("findpeaks",
                           editor=ButtonEditor(label="Find Peaks"),
                           show_label=False),
                      Spring(),
                  ),
                  HGroup(
                      Item("thresh_lower",
                           label="Threshold Lower Value",
                           editor=TextEditor(evaluate=float,
                                             format_str='%1.4f')),
                      Item("thresh_upper",
                           label="Threshold Upper Value",
                           editor=TextEditor(evaluate=float,
                                             format_str='%1.4f')),
                  ),
                  HGroup(
                      Item("numpeaks_img",
                           label="Number of Cells selected (this image)",
                           style='readonly'),
                      Spring(),
                      Item("numpeaks_total", label="Total", style='readonly'),
                      Spring(),
                  ),
                  label="Peak parameters",
                  show_border=True),
        )),
                       buttons=[
                           Action(name='OK',
                                  enabled_when='numpeaks_total > 0'),
                           CancelButton
                       ],
                       title="Template Picker",
                       handler=OK_custom,
                       kind='livemodal',
                       key_bindings=key_bindings,
                       width=960,
                       height=600)

    def __init__(self, signal_instance):
        super(TemplatePicker, self).__init__()
        try:
            import cv
        except:
            print "OpenCV unavailable.  Can't do cross correlation without it.  Aborting."
            return None
        self.OK_custom = OK_custom_handler()
        self.sig = signal_instance
        if not hasattr(self.sig.mapped_parameters, "original_files"):
            self.sig.data = np.atleast_3d(self.sig.data)
            self.titles = [self.sig.mapped_parameters.name]
        else:
            self.numfiles = len(
                self.sig.mapped_parameters.original_files.keys())
            self.titles = self.sig.mapped_parameters.original_files.keys()
        tmp_plot_data = ArrayPlotData(
            imagedata=self.sig.data[self.top:self.top + self.tmp_size,
                                    self.left:self.left + self.tmp_size,
                                    self.img_idx])
        tmp_plot = Plot(tmp_plot_data, default_origin="top left")
        tmp_plot.img_plot("imagedata", colormap=jet)
        tmp_plot.aspect_ratio = 1.0
        self.tmp_plot = tmp_plot
        self.tmp_plotdata = tmp_plot_data
        self.img_plotdata = ArrayPlotData(
            imagedata=self.sig.data[:, :, self.img_idx])
        self.img_container = self._image_plot_container()

        self.crop_sig = None

    def render_image(self):
        """ Build the main image Plot and return it.

        The current image is shown with a gray colormap and a top-left
        origin.  A white CursorTool, stored in self.csr, is placed at
        (left, top).  A right-button PanTool and a box ZoomTool are also
        attached.  The plot is stored in self.img_plot.
        """
        rows, cols = self.sig.data.shape[0], self.sig.data.shape[1]
        plot = Plot(self.img_plotdata, default_origin="top left")
        img = plot.img_plot("imagedata", colormap=gray)[0]
        counter = "%s of %s: " % (self.img_idx + 1, self.numfiles)
        plot.title = counter + self.titles[self.img_idx]
        plot.aspect_ratio = float(cols) / float(rows)

        cursor = CursorTool(img, drag_button='left', color='white',
                            line_width=2.0)
        self.csr = cursor
        cursor.current_position = self.left, self.top
        img.overlays.append(cursor)

        plot.tools.append(PanTool(plot, drag_button="right"))
        box_zoom = ZoomTool(plot, tool_mode="box", always_on=False,
                            aspect_ratio=plot.aspect_ratio)
        plot.overlays.append(box_zoom)
        self.img_plot = plot
        return plot

    def render_scatplot(self):
        """ Build a colour-mapped scatter Plot of the current image's peaks.

        Columns 0, 1 and 2 of self.peaks[self.img_idx] supply the index,
        value and colour data.  The plot shares aspect ratio and range2d
        with self.img_plot and hides its grids.  The plot and its data are
        stored in self.scatplot and self.peakdata.
        """
        peakdata = ArrayPlotData()
        for column, key in enumerate(("index", "value", "color")):
            peakdata.set_data(key, self.peaks[self.img_idx][:, column])

        scatplot = Plot(peakdata,
                        aspect_ratio=self.img_plot.aspect_ratio,
                        default_origin="top left")
        mapper = jet(DataRange1D(low=0.0, high=1.0))
        scatplot.plot(("index", "value", "color"),
                      type="cmap_scatter",
                      name="my_plot",
                      color_mapper=mapper,
                      marker="circle",
                      fill_alpha=0.5,
                      marker_size=6)
        for grid in (scatplot.x_grid, scatplot.y_grid):
            grid.visible = False
        scatplot.range2d = self.img_plot.range2d

        self.scatplot = scatplot
        self.peakdata = peakdata
        return scatplot

    def _image_plot_container(self):
        """ Assemble and return the image container.

        The image plot is placed in an overlay container (self.container),
        which goes into a white horizontal container (self.img_container).
        When peaks are selected in this image, the peak scatter plot is
        overlaid and a colour bar is added alongside.
        """
        plot = self.render_image()

        overlay = OverlayPlotContainer()
        self.container = overlay
        overlay.add(plot)

        outer = HPlotContainer(use_backbuffer=False)
        self.img_container = outer
        outer.add(overlay)
        outer.bgcolor = "white"

        if self.numpeaks_img > 0:
            overlay.add(self.render_scatplot())
            outer.add(self.draw_colorbar())
        return outer

    def draw_colorbar(self):
        """ Create the colour bar for the peak scatter plot and return it.

        Adds a range-selection overlay to the scatter renderer, seeded with
        self.thresh when set.  Builds a vertical ColorBar over [0, 1] with a
        RangeSelection tool and overlay.  The tool, renderer and bar are
        stored in self.cbar_selection, self.cmap_renderer and self.colorbar.
        """
        scatplot = self.scatplot
        cmap_renderer = scatplot.plots["my_plot"][0]
        selection = ColormappedSelectionOverlay(cmap_renderer,
                                                fade_alpha=0.35,
                                                selection_type="range")
        cmap_renderer.overlays.append(selection)
        if self.thresh is not None:
            cmap_renderer.color_data.metadata['selections'] = self.thresh
            cmap_renderer.color_data.metadata_changed = {
                'selections': self.thresh
            }
        # Create the colorbar over the same fixed [0, 1] range that
        # render_scatplot uses for its colour mapper.
        # NOTE(review): no color_mapper is passed to ColorBar; confirm it
        # picks up the mapper from colorbar.plot as intended.
        colorbar = ColorBar(
            index_mapper=LinearMapper(range=DataRange1D(low=0.0, high=1.0)),
            orientation='v',
            resizable='v',
            width=30,
            padding=20)
        colorbar_selection = RangeSelection(component=colorbar)
        colorbar.tools.append(colorbar_selection)
        colorbar.overlays.append(
            RangeSelectionOverlay(component=colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray",
                                  metadata_name='selections'))
        self.cbar_selection = colorbar_selection
        self.cmap_renderer = cmap_renderer
        colorbar.plot = cmap_renderer
        colorbar.padding_top = scatplot.padding_top
        colorbar.padding_bottom = scatplot.padding_bottom
        self.colorbar = colorbar
        return colorbar

    @on_trait_change('ShowCC')
    def toggle_cc_view(self):
        """ Display either the cross-correlation map or the raw image.

        With ShowCC set, the template cut from the current image at
        (top, left) is cross-correlated with that image via cv_funcs.xcorr.
        The result is stored in self.CC and shown.  Otherwise the raw image
        is shown.  The plots are redrawn in both cases.
        """
        image = self.sig.data[:, :, self.img_idx]
        if not self.ShowCC:
            self.img_plotdata.set_data("imagedata", image)
        else:
            rows = slice(self.top, self.top + self.tmp_size)
            cols = slice(self.left, self.left + self.tmp_size)
            template = self.sig.data[rows, cols, self.img_idx]
            self.CC = cv_funcs.xcorr(template, image)
            self.img_plotdata.set_data("imagedata", self.CC)
        self.redraw_plots()

    @on_trait_change("img_idx")
    def update_img_depth(self):
        """ Switch the displayed image after img_idx changes.

        Shows the new image, or its cross-correlation with the template
        when ShowCC is set.  Updates the plot title to "<n> of <total>:
        <title>" and redraws.
        """
        image = self.sig.data[:, :, self.img_idx]
        if self.ShowCC:
            template = self.sig.data[self.top:self.top + self.tmp_size,
                                     self.left:self.left + self.tmp_size,
                                     self.img_idx]
            self.CC = cv_funcs.xcorr(template, image)
            shown = self.CC
        else:
            shown = image
        self.img_plotdata.set_data("imagedata", shown)

        counter = "%s of %s: " % (self.img_idx + 1, self.numfiles)
        self.img_plot.title = counter + self.titles[self.img_idx]
        self.redraw_plots()

    @on_trait_change('tmp_size')
    def update_max_pos(self):
        """ Keep the template inside the image when its size changes.

        'top' indexes rows (axis 0) and is bounded by max_pos_x.  'left'
        indexes columns (axis 1) and is bounded by max_pos_y.  This matches
        the Range definitions and the data slicing used for the template.
        Each coordinate is clamped against its own axis's bound before that
        bound is updated.
        """
        max_pos_x = self.sig.data.shape[0] - self.tmp_size - 1
        if self.top > max_pos_x:
            self.top = max_pos_x
        self.max_pos_x = max_pos_x

        max_pos_y = self.sig.data.shape[1] - self.tmp_size - 1
        if self.left > max_pos_y:
            self.left = max_pos_y
        self.max_pos_y = max_pos_y

    def increase_img_idx(self, info):
        """ Step to the next image, wrapping from the last to the first.

        'info' is not used.
        """
        at_last = self.img_idx == (self.numfiles - 1)
        self.img_idx = 0 if at_last else self.img_idx + 1

    def decrease_img_idx(self, info):
        """ Step to the previous image, wrapping from the first to the last.

        'info' is not used.
        """
        at_first = self.img_idx == 0
        self.img_idx = (self.numfiles - 1) if at_first else self.img_idx - 1

    @on_trait_change('left, top')
    def update_csr_position(self):
        """ Move the cursor tool to the template's (left, top) corner. """
        corner = (self.left, self.top)
        self.csr.current_position = corner

    @on_trait_change('csr:current_position')
    def update_top_left(self):
        """ Copy the cursor tool's (x, y) position into (left, top). """
        cursor_x, cursor_y = self.csr.current_position
        self.left = cursor_x
        self.top = cursor_y

    @on_trait_change('left, top, tmp_size')
    def update_tmp_plot(self):
        """ Refresh the template preview from the current image.

        Shows the tmp_size x tmp_size patch at (top, left), resizes the
        preview grid to match, and records in tmp_img_idx which image the
        template was taken from.
        """
        size = self.tmp_size
        top, left = self.top, self.left
        patch = self.sig.data[top:top + size, left:left + size, self.img_idx]
        self.tmp_plotdata.set_data("imagedata", patch)

        grid = self.tmp_plot.range2d.sources[0]
        grid.set_data(np.arange(self.tmp_size), np.arange(self.tmp_size))
        self.tmp_img_idx = self.img_idx

    @on_trait_change('left, top, tmp_size')
    def update_CC(self):
        """ Recompute the cross-correlation after the template changes.

        When ShowCC is set, the template is taken from image tmp_img_idx
        and cross-correlated with the current image.  The result is shown,
        and the image grid is resized to the shape of the result.  Any
        previously found peaks are replaced by a single placeholder row.
        """
        if self.ShowCC:
            size = self.tmp_size
            template = self.sig.data[self.top:self.top + size,
                                     self.left:self.left + size,
                                     self.tmp_img_idx]
            current = self.sig.data[:, :, self.img_idx]
            self.CC = cv_funcs.xcorr(template, current)
            self.img_plotdata.set_data("imagedata", self.CC)
            cc_shape = self.CC.shape
            grid = self.img_plot.range2d.sources[0]
            grid.set_data(np.arange(cc_shape[1]), np.arange(cc_shape[0]))
        if self.numpeaks_total > 0:
            self.peaks = [np.array([[0, 0, -1]])]

    @on_trait_change('cbar_selection:selection')
    def update_thresh(self):
        """ Apply a range selected on the colour bar as the peak threshold.

        Copies the selection into self.thresh, thresh_lower/thresh_upper and
        the scatter renderer's 'selections' metadata, then redraws.
        """
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
            self.cmap_renderer.color_data.metadata['selections'] = thresh
            self.thresh_lower = thresh[0]
            self.thresh_upper = thresh[1]
            self.cmap_renderer.color_data.metadata_changed = {
                'selections': thresh
            }
            self.container.request_redraw()
            self.img_container.request_redraw()
        except Exception:
            # Best effort: the selection may be cleared (None) or the
            # colour bar / renderer may not exist yet.  Unlike a bare
            # except, this does not swallow KeyboardInterrupt/SystemExit.
            pass

    @on_trait_change('thresh_upper,thresh_lower')
    def manual_thresh_update(self):
        """Apply ``thresh_lower``/``thresh_upper`` to the colormap selection and redraw.

        Stores the pair in ``thresh``, publishes it as the renderer's
        ``selections`` metadata (plus a ``metadata_changed`` event), then
        requests a redraw of both containers.
        """
        self.thresh = [self.thresh_lower, self.thresh_upper]
        color_data = self.cmap_renderer.color_data
        # Read ``self.thresh`` back: the stored value is what gets published.
        color_data.metadata['selections'] = self.thresh
        color_data.metadata_changed = {'selections': self.thresh}
        for view in (self.container, self.img_container):
            view.request_redraw()

    @on_trait_change('peaks,cbar_selection:selection,img_idx')
    def calc_numpeaks(self):
        """Count the peaks whose score lies inside the current threshold range.

        The score is column 2 of each peak array.  Two counts are set:

        * ``numpeaks_total``: the in-range count over every entry of
          ``self.peaks``;
        * ``numpeaks_img``: the count for the current image
          (``self.peaks[self.img_idx]``), or 0 when ``self.peaks`` has no
          entry for that index.

        When no colorbar selection is available (missing tool, or an empty
        or None selection), the range defaults to (0, 1).
        """
        try:
            thresh = self.cbar_selection.selection
            self.thresh = thresh
        except Exception:
            # Best-effort: the selection tool may not exist yet, or the value
            # may be rejected by the ``thresh`` trait.
            thresh = None
        # ``len`` instead of ``== []``/``== ()`` so array selections work too.
        if thresh is None or len(thresh) == 0:
            thresh = (0, 1)

        def count_in_range(pks):
            # masked_inside masks the scores within the threshold range, so
            # the number of masked entries is the number of in-range peaks.
            return int(np.sum(
                np.ma.masked_inside(pks[:, 2], thresh[0], thresh[1]).mask))

        self.numpeaks_total = sum(count_in_range(pks) for pks in self.peaks)
        # Every entry was already counted above, so indexing is the only
        # thing that can newly fail here.
        try:
            current = self.peaks[self.img_idx]
        except IndexError:
            # e.g. the single-entry placeholder list installed by update_CC
            self.numpeaks_img = 0
        else:
            self.numpeaks_img = count_in_range(current)

    @on_trait_change('findpeaks')
    def locate_peaks(self):
        """Find template-match peaks in every image and store them in ``peaks``.

        For each image index in ``0 .. numfiles - 1``:

        * the current template is cross-correlated against that image with
          ``cv_funcs.xcorr``.  The template is the ``tmp_size`` window at
          (``top``, ``left``) of image ``tmp_img_idx``;
        * the result is passed to hyperspy's ``two_dim_findpeaks``.

        ``self.peaks`` is assigned once, at the end, with one peak array per
        image.

        NOTE(review): ``self.CC`` is overwritten for every image, so afterwards
        it holds the correlation of the last image rather than of ``img_idx``.
        """
        from hyperspy import peak_char as pc

        def peaks_in(image_number):
            window = self.sig.data[self.top:self.top + self.tmp_size,
                                   self.left:self.left + self.tmp_size,
                                   self.tmp_img_idx]
            self.CC = cv_funcs.xcorr(window, self.sig.data[:, :, image_number])
            # The peak finder needs peaks greater than 1, so scale the
            # correlation by 255 going in and divide column 2 back out.
            found = pc.two_dim_findpeaks(self.CC * 255,
                                         peak_width=self.peak_width,
                                         medfilt_radius=None)
            found[:, 2] = found[:, 2] / 255.
            return found

        self.peaks = [peaks_in(n) for n in xrange(self.numfiles)]

    def mask_peaks(self, idx):
        """Return the peaks of image ``idx`` with out-of-threshold scores masked.

        The result is a masked copy of ``self.peaks[idx]``.  Column 2 (the
        score) is masked wherever it falls outside the current colorbar
        selection, so ``np.ma.compress_rows`` on the result keeps only the
        in-range peaks.

        When there is no usable selection (missing selection tool, or an
        empty or None selection), the (0, 1) range is used, as in
        ``calc_numpeaks``.
        """
        try:
            thresh = self.cbar_selection.selection
        except AttributeError:
            # no colorbar selection tool yet
            thresh = None
        if thresh is None or len(thresh) == 0:
            thresh = (0, 1)
        # Copy so the masked assignment below never writes through a view
        # into the caller-visible self.peaks[idx].
        mpeaks = np.ma.array(self.peaks[idx], copy=True)
        mpeaks[:, 2] = np.ma.masked_outside(mpeaks[:, 2], thresh[0], thresh[1])
        return mpeaks

    @on_trait_change("peaks")
    def redraw_plots(self):
        """Rebuild the correlation image and the peak overlays after ``peaks`` changes.

        Steps:

        1. Replace ``img_plot`` in ``container`` with a fresh
           ``render_image()``.
        2. Remove any previous scatter overlay (from ``container``) and
           colorbar (from ``img_container``).
        3. If the current image has peaks (``numpeaks_img > 0``), add a new
           scatter overlay and colorbar.
        4. Request a redraw of both containers.
        """
        oldplot = self.img_plot
        self.container.remove(oldplot)
        newplot = self.render_image()
        self.container.add(newplot)
        self.img_plot = newplot

        # Remove the previous overlay and colorbar independently, so a failure
        # on one does not leave the other behind.  Removal can fail when the
        # attribute was never created (first redraw), or when it still refers
        # to a component removed on an earlier redraw: the references below
        # are not reassigned when numpeaks_img == 0.
        for holder, attr in ((self.container, 'scatplot'),
                             (self.img_container, 'colorbar')):
            stale = getattr(self, attr, None)
            if stale is None:
                continue
            try:
                holder.remove(stale)
            except Exception:
                # best-effort; not a bare except so KeyboardInterrupt propagates
                pass

        if self.numpeaks_img > 0:
            newscat = self.render_scatplot()
            self.container.add(newscat)
            self.scatplot = newscat
            colorbar = self.draw_colorbar()
            self.img_container.add(colorbar)
            self.colorbar = colorbar

        self.container.request_redraw()
        self.img_container.request_redraw()

    def crop_cells_stack(self):
        """Crop cells from every image and store the result in ``crop_sig``.

        With exactly one image, ``crop_sig`` is the ``crop_cells()`` result
        itself.  Otherwise it is an ``AggregateCells`` built from one
        ``crop_cells(i)`` result per image, in index order.
        """
        from eelslab.signals.aggregate import AggregateCells
        if self.numfiles != 1:
            per_image = [self.crop_cells(i) for i in xrange(self.numfiles)]
            self.crop_sig = AggregateCells(*per_image)
        else:
            self.crop_sig = self.crop_cells()

    def crop_cells(self, idx=0):
        """Crop a template-sized window at every in-threshold peak of image ``idx``.

        Peaks are taken from ``mask_peaks(idx)`` with masked (out-of-range)
        rows dropped.  Column 0 of each row is used as the column offset and
        column 1 as the row offset of a ``tmp_size`` x ``tmp_size`` window
        in ``sig.data[..., idx]``.

        Returns a hyperspy ``Image`` whose data has shape
        ``(tmp_size, tmp_size, n_peaks)``.  Its mapped parameters record
        the peak locations and the parent signal: the matching entry of
        ``original_files`` when present, otherwise ``self.sig``.
        """
        print("cropping cells...")
        from hyperspy.signals.image import Image
        # filter the peaks that are outside the selected threshold
        peaks = np.ma.compress_rows(self.mask_peaks(idx))
        tmp_sz = self.tmp_size
        data = np.zeros((tmp_sz, tmp_sz, peaks.shape[0]))
        if not hasattr(self.sig.mapped_parameters, "original_files"):
            parent = self.sig
        else:
            parent = self.sig.mapped_parameters.original_files[
                self.titles[idx]]
        for i in range(peaks.shape[0]):
            # The peak array also holds float scores, so the coordinates may
            # be floats.  numpy >= 1.12 rejects float slice bounds, so
            # truncate to int explicitly (as older numpy did implicitly).
            row = int(peaks[i, 1])
            col = int(peaks[i, 0])
            data[:, :, i] = self.sig.data[row:row + tmp_sz,
                                          col:col + tmp_sz, idx]
        # Build the signal once, after every crop is filled.  This also keeps
        # crop_sig defined when no peak passes the threshold.
        crop_sig = Image({
            'data': data,
            'mapped_parameters': {
                'name': 'Cropped cells from %s' % self.titles[idx],
                'record_by': 'image',
                'locations': peaks,
                'parent': parent,
            }
        })
        print("Complete.  ")
        return crop_sig