Beispiel #1
0
class CorrelationData(HasTraits):
    """Holds count-rate and correlation curves together with their plots.

    Assigning a two-column array to ``count_rate``, ``correlation`` or
    ``correlation_avg`` copies column 0 (x values) and column 1 (y values)
    into the shared ``plotdata`` so both plots update.
    """
    cr_plot = Instance(Plot, ())
    corr_plot = Instance(Plot, ())
    plotdata = Instance(ArrayPlotData)
    count_rate = Array
    correlation = Array
    correlation_avg = Array
    index = Int

    traits_view = View(
        Item('cr_plot', editor=ComponentEditor(), show_label=False),
        Item('corr_plot', editor=ComponentEditor(), show_label=False),
        width=400,
        height=800,
        resizable=True,
    )

    def _plotdata_default(self):
        # Placeholder series: 100 points over [-14, 14] with a flat zero curve
        # (the same zero array is shared by every y series).
        xs = linspace(-14, 14, 100)
        flat = xs * 0
        return ArrayPlotData(time=xs, cr=flat, lag=xs, corr=flat, avg=flat)

    def __init__(self, **kw):
        super(CorrelationData, self).__init__(**kw)
        rate_plot = Plot(self.plotdata)
        lag_plot = Plot(self.plotdata)

        # Upper plot: count rate versus time.
        rate_plot.plot(("time", "cr"), type="line", color="blue")

        # Lower plot: current and averaged correlation on a log lag axis.
        lag_plot.plot(("lag", "corr"), type="line", color="green")
        lag_plot.plot(("lag", "avg"), type="line", color="red")
        lag_plot.index_scale = 'log'

        self.cr_plot = rate_plot
        self.corr_plot = lag_plot

    def _push_columns(self, data, x_name, y_name):
        # Column 0 goes to the x series, column 1 to the y series.
        self.plotdata.set_data(x_name, data[:, 0])
        self.plotdata.set_data(y_name, data[:, 1])

    def _count_rate_changed(self, data):
        self._push_columns(data, 'time', 'cr')

    def _correlation_changed(self, data):
        self._push_columns(data, 'lag', 'corr')

    def _correlation_avg_changed(self, data):
        self._push_columns(data, 'lag', 'avg')
class Demo(HasTraits):
    """Window that displays the component returned by ``_create_plot_component``."""
    plot = Instance(Component)

    traits_view = View(
        Group(
            Item('plot', editor=ComponentEditor(size=size), show_label=False),
            orientation="vertical",
        ),
        title=title,
        resizable=True,
    )

    def _plot_default(self):
        # Built on first access of the ``plot`` trait.
        component = _create_plot_component()
        return component
Beispiel #3
0
class EnableDemo(HasTraits):
    """Window showing a DrawLineComponent plus editors for its line width/colour."""
    box = Instance(DrawLineComponent)

    view = View(
        HGroup(
            Item("object.box.line_width"),
            Item("object.box.line_color", style="custom"),
        ),
        Item("box", editor=ComponentEditor(), show_label=False),
        title="Draw Lines",
        width=500,
        height=500,
        resizable=True,
    )

    def __init__(self, **traits):
        super(EnableDemo, self).__init__(**traits)
        # Every demo instance starts with its own fresh drawing component.
        self.box = DrawLineComponent()
class StarDesign(HasTraits):
    """Window showing a StarComponent with editors for its edge count and colour."""
    box = Instance(StarComponent)

    # Labels: "number of vertices", "colour"; title: "star design".
    view = View(
        HGroup(
            Item("object.box.edges", label="顶角数"),
            Item("object.box.star_color", label="颜色"),
        ),
        Item("box", editor=ComponentEditor(), show_label=False),
        title="星空设计",
        width=600,
        height=400,
        resizable=True,
    )

    def __init__(self, **traits):
        super(StarDesign, self).__init__(**traits)
        # Each window gets its own star component to draw into.
        self.box = StarComponent()
class DoublePendulumGUI(HasTraits):
    """Animated double-pendulum demo.

    The m1, m2, l1, l2 sliders are copied into the ``DoublePendulum`` model
    whenever a new chunk of trajectory is integrated. A pyface ``Timer``
    calls ``on_timer`` every 10 ms to show the next precomputed position.
    """
    pendulum = Instance(DoublePendulum)
    m1 = Range(1.0, 10.0, 2.0)
    m2 = Range(1.0, 10.0, 2.0)
    l1 = Range(1.0, 10.0, 2.0)
    l2 = Range(1.0, 10.0, 2.0)
    # Trajectory returned by double_pendulum_odeint; one array per coordinate
    # (presumably x1, y1, x2, y2 -- confirm against double_pendulum_odeint).
    positions = Tuple
    # Index of the next frame to display within ``positions``.
    index = Int(0)
    timer = Instance(Timer)
    graph = Instance(DoublePendulumComponent)
    # When False, each integration covers a single tiny step (see on_timer).
    animation = Bool(True)

    # Title: "double pendulum demo".
    view = View(
        HGroup(
            VGroup(
                Item("m1"),
                Item("m2"),
                Item("l1"),
                Item("l2"),
            ),
            Item("graph", editor=ComponentEditor(), show_label=False),
        ),
        width=600,
        height=400,
        title="双摆演示",
        resizable=True)

    def __init__(self, **traits):
        # HasTraits.__init__ must run so trait initialisation completes and
        # keyword traits (e.g. m1=3.0) are applied; it was previously skipped.
        super(DoublePendulumGUI, self).__init__(**traits)
        self.pendulum = DoublePendulum(self.m1, self.m2, self.l1, self.l2)
        self.pendulum.init_status[:] = 1.0, 2.0, 0, 0
        self.graph = DoublePendulumComponent()
        # The component reads back into the GUI object through this link.
        self.graph.gui = self
        self.timer = Timer(10, self.on_timer)

    def on_timer(self, *args):
        """Show the next frame, integrating a new trajectory chunk when exhausted."""
        if len(self.positions) == 0 or self.index == len(self.positions[0]):
            # Pick up any slider changes before integrating further.
            self.pendulum.m1 = self.m1
            self.pendulum.m2 = self.m2
            self.pendulum.l1 = self.l1
            self.pendulum.l2 = self.l2
            if self.animation:
                self.positions = double_pendulum_odeint(
                    self.pendulum, 0, 0.5, 0.02)
            else:
                self.positions = double_pendulum_odeint(
                    self.pendulum, 0, 0.00001, 0.00001)
            self.index = 0
        self.graph.p = tuple(array[self.index] for array in self.positions)
        self.index += 1
        self.graph.request_redraw()
Beispiel #6
0
 def default_traits_view(self):
     """Return the Mandelbrot viewer layout: a control row above the plot.

     The controls are the colour-map chooser (choices read from
     ``object.color_maps``), a reverse-map toggle and a read-only position.
     """
     # Labels: "colour map", "reverse colours", "position".
     controls = HGroup(
         Item("current_map", label=u"颜色映射",
              editor=EnumEditor(name="object.color_maps")),
         Item("reverse_map", label=u"反转颜色"),
         Item("position", label=u"位置", style="readonly"),
     )
     canvas = Item("plot", show_label=False, editor=ComponentEditor())
     return View(VGroup(controls, canvas),
                 resizable=True,
                 width=550,
                 height=300,
                 title=u"Mandelbrot观察器")
Beispiel #7
0
class FigureInspectorData(FigureInspector):
    """See :class:`Figure`. In addition, defines a ``filename`` attribute
    used to load images from file.
    """
    filename = File()

    traits_view = View('filename',
                       Group(Item('container',
                                  editor=ComponentEditor(size=size,
                                                         bgcolor=bg_color),
                                  show_label=False),
                             orientation="vertical"),
                       resizable=True)

    def _filename_changed(self, new):
        """Load the newly selected file and display it with ``plot_image``."""
        # Clearing the file field produces an empty path. There is nothing to
        # load then, and ImageData.fromfile would presumably fail on it.
        if not new:
            return
        image = ImageData.fromfile(new)
        self.plot_image(image._data)
class FloodFillDemo(HasTraits):
    """Interactive ``cv.floodFill`` demo on ``lena.jpg``.

    A PointPicker overlay on the image is given this object as its
    application (presumably it updates ``point`` on click). The fill is
    recomputed whenever ``point``, ``lo_diff``, ``hi_diff`` or ``option``
    changes.
    """
    # Builtin ``float`` is exactly what the removed ``np.float`` alias meant.
    lo_diff = Array(float, (1, 4))
    hi_diff = Array(float, (1, 4))
    plot = Instance(Plot)
    # Seed point of the fill.
    point = Tuple((0, 0))
    # Default: "relative to neighbouring pixel, 4-connected"; the mapped
    # value ``option_`` is passed to floodFill as its flags.
    option = Trait(u"以邻点为标准-4联通", Options)

    # Labels: "negative range", "positive range", "algorithm flags";
    # title: "FloodFill Demo control panel".
    view = View(VGroup(
        VGroup(Item("lo_diff", label=u"负方向范围"),
               Item("hi_diff", label=u"正方向范围"),
               Item("option", label=u"算法标志")),
        Item("plot", editor=ComponentEditor(), show_label=False),
    ),
                title=u"FloodFill Demo控制面板",
                width=500,
                height=450,
                resizable=True)

    def __init__(self, *args, **kwargs):
        # Run HasTraits initialisation so keyword traits are honoured; it was
        # previously skipped. Positional args are still accepted for
        # compatibility but are not forwarded.
        super(FloodFillDemo, self).__init__(**kwargs)
        # Default tolerance of 5 on every channel unless the caller gave one.
        if "lo_diff" not in kwargs:
            self.lo_diff.fill(5)
        if "hi_diff" not in kwargs:
            self.hi_diff.fill(5)
        self.img = cv.imread("lena.jpg")
        # Reverse the channel axis for display (OpenCV stores BGR).
        self.data = ArrayPlotData(img=self.img[:, :, ::-1])
        w = self.img.size().width
        h = self.img.size().height
        self.plot = Plot(self.data, padding=10, aspect_ratio=float(w) / h)
        self.plot.x_axis.visible = False
        self.plot.y_axis.visible = False
        self.imgplot = self.plot.img_plot("img", origin="top left")[0]
        self.imgplot.interpolation = "nearest"
        self.imgplot.overlays.append(
            PointPicker(application=self, component=self.imgplot))

        # Registered last so nothing redraws before the image exists.
        self.on_trait_change(self.redraw, "point,lo_diff,hi_diff,option")

    def redraw(self):
        """Flood-fill a copy of the source image from ``point`` and display it."""
        img = self.img.clone()
        cv.floodFill(img,
                     cv.Point(*self.point),
                     cv.Scalar(255, 0, 0, 255),
                     loDiff=cv.asScalar(self.lo_diff[0]),
                     upDiff=cv.asScalar(self.hi_diff[0]),
                     flags=self.option_)
        self.data["img"] = img[:, :, ::-1]
Beispiel #9
0
class ScatterPlotTraits(HasTraits):
    """Scatter plot of ``sin(x) * x**3`` with editable marker colour, shape and size."""

    plot = Instance(Plot)
    color = ColorTrait("blue")
    marker = marker_trait
    marker_size = Int(4)
    x = Array()
    y = Array()

    traits_view = View(
        Group(
            Item('color', label="Color", style="custom"),
            Item('marker', label="Marker"),
            Item('marker_size', label="Size"),
            Item('plot', editor=ComponentEditor(), show_label=False),
            orientation="vertical",
        ),
        width=800,
        height=600,
        resizable=True,
        title="Chaco Plot",
    )

    def __init__(self):
        super(ScatterPlotTraits, self).__init__()
        xs = linspace(-14, 14, 10)
        ys = sin(xs) * xs**3
        self.x, self.y = xs, ys

        chart = Plot(ArrayPlotData(x=xs, y=ys))
        # plot() returns a list of renderers; keep the single scatter renderer
        # so the change handlers below can restyle it.
        renderers = chart.plot(("x", "y"), type="scatter", color="blue")
        self.renderer = renderers[0]
        self.plot = chart

    def _sync_renderer(self, name):
        # Copy the trait of the same name onto the scatter renderer.
        setattr(self.renderer, name, getattr(self, name))

    def _color_changed(self):
        self._sync_renderer("color")

    def _marker_changed(self):
        self._sync_renderer("marker")

    def _marker_size_changed(self):
        self._sync_renderer("marker_size")
class MatrixViewer(HasTraits):
    """Colour-mapped viewer for one of several named 2-D matrices.

    Provides pan/zoom tools, a colour bar with range selection, and
    read-only "From" (row), "To" (column) and "Value" fields that follow
    the cursor position reported by ``CustomTool``.
    """

    tplot = Instance(Plot)
    plot = Instance(Component)
    custtool = Instance(CustomTool)
    colorbar = Instance(ColorBar)

    edge_para = Any
    # Key into ``data`` of the matrix currently shown.
    data_name = Enum("a", "b")

    # Row index, column index and value of the cell under the cursor.
    fro = Int
    to = Int
    data = None
    val = Float

    traits_view = View(Group(Item('plot',
                                  editor=ComponentEditor(size=(800, 600)),
                                  show_label=False),
                             HGroup(
                                 Item('fro',
                                      label="From",
                                      style='readonly',
                                      springy=True),
                                 Item('to',
                                      label="To",
                                      style='readonly',
                                      springy=True),
                                 Item('val',
                                      label="Value",
                                      style='readonly',
                                      springy=True),
                             ),
                             orientation="vertical"),
                       Item('data_name', label="Image data"),
                       handler=CustomHandler(),
                       resizable=True,
                       title="Matrix Viewer")

    def __init__(self, data, **traits):
        """Create the viewer.

        :param data: mapping from data-set name to a 2-D numpy array. Keys
            must be valid ``data_name`` values ("a" or "b"). The first key
            is displayed initially.
        """
        # Name this class (not HasTraits) so the whole MRO is initialised.
        super(MatrixViewer, self).__init__(**traits)

        # Store the data before choosing data_name: that assignment can fire
        # _data_name_changed, which indexes self.data.
        self.data = data
        # next(iter(...)) works on Python 2 and 3; data.keys()[0] is Py2-only.
        self.data_name = next(iter(data))
        self.plot = self._create_plot_component()

        # Refresh From/To/Value whenever the cursor tool reports a new position.
        self.custtool.on_trait_change(self._update_fields, "xval")
        self.custtool.on_trait_change(self._update_fields, "yval")

    def _data_name_changed(self, old, new):
        """Display the newly selected matrix."""
        # During __init__ this can fire before the plot is built. The plot is
        # then built from the current data_name anyway, so skip the update.
        if getattr(self, "pd", None) is None:
            return
        self.pd.set_data("imagedata", self.data[self.data_name])
        self.my_plot.set_value_selection((0, 2))

    def _update_fields(self):
        """Map the cursor position to a matrix cell and publish its index/value."""
        from numpy import trunc

        # Cursor y selects the row, cursor x the column.
        row = int(trunc(self.custtool.yval))
        col = int(trunc(self.custtool.xval))

        # Matrix shape is (# of rows, # of columns); ignore positions outside it.
        n_rows, n_cols = self.data[self.data_name].shape[:2]
        if 0 <= row < n_rows and 0 <= col < n_cols:
            self.fro = row
            self.to = col
            self.val = self.data[self.data_name][self.fro, self.to]

    def _create_plot_component(self):
        """Build the image plot, tools and colour bar; return their container."""

        # Create a plot data object and give it the selected matrix.
        self.pd = ArrayPlotData()
        self.pd.set_data("imagedata", self.data[self.data_name])

        # Create the plot with row 0 at the top and the x axis on top.
        self.tplot = Plot(self.pd, default_origin="top left")
        self.tplot.x_axis.orientation = "top"
        self.tplot.img_plot(
            "imagedata",
            name="my_plot",
            colormap=jet)

        # Tweak some of the plot properties.
        self.tplot.title = "Matrix"
        self.tplot.padding = 50

        # Some tools need the actual CMapImage renderer rather than the Plot.
        self.my_plot = self.tplot.plots["my_plot"][0]

        # Attach pan and box-zoom tools to the plot.
        self.tplot.tools.append(PanTool(self.tplot))
        zoom = ZoomTool(component=self.tplot, tool_mode="box", always_on=False)
        self.tplot.overlays.append(zoom)

        # Custom tool that reports the cursor position (see _update_fields).
        self.custtool = CustomTool(self.tplot)
        self.tplot.tools.append(self.custtool)

        # Create the colour bar, sharing the image's range and colour map.
        colormap = self.my_plot.color_mapper
        self.colorbar = ColorBar(
            index_mapper=LinearMapper(range=colormap.range),
            color_mapper=colormap,
            plot=self.my_plot,
            orientation='v',
            resizable='v',
            width=30,
            padding=20)

        self.colorbar.padding_top = self.tplot.padding_top
        self.colorbar.padding_bottom = self.tplot.padding_bottom

        # Create a range selection on the colour bar.
        self.range_selection = RangeSelection(component=self.colorbar)
        self.colorbar.tools.append(self.range_selection)
        self.colorbar.overlays.append(
            RangeSelectionOverlay(component=self.colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray"))

        # Let the range selection also inform the image plot of the selection.
        self.range_selection.listeners.append(self.my_plot)

        # Place the plot and the colour bar side by side.
        container = HPlotContainer(use_backbuffer=True)
        container.add(self.tplot)
        container.add(self.colorbar)
        container.bgcolor = "white"

        return container
class FFT2Demo(HasTraits):
    """2-D Fourier-transform filtering demo.

    Three panels show the log-magnitude spectrum of ``lena_full.jpg``
    (resized to N x N, grey-scale), the mask drawn with a lasso on the
    spectrum, and the image reconstructed from the masked spectrum.
    """
    plot = Instance(HPlotContainer)
    filtered_img = Array()
    timer = Instance(Timer)
    # Set by lasso events; a 50 ms timer redraws only when this is True.
    need_redraw = Bool(True)

    # Title: "2-D Fourier transform filtering demo".
    traits_view = View(
        Group(Item('plot', editor=ComponentEditor(), show_label=False),
              orientation="vertical"),
        resizable=True,
        title=u"二维傅立叶变换滤波演示",
        width=260 * 3,
        height=260,
    )

    def __init__(self):
        # HasTraits.__init__ was previously skipped; run it first.
        super(FFT2Demo, self).__init__()

        # Load the image, convert to grey-scale and resize to N x N.
        src = cv.imread("lena_full.jpg")
        gray = cv.Mat()
        cv.cvtColor(src, gray, cv.CV_BGR2GRAY)
        resized = cv.Mat()
        cv.resize(gray, resized, cv.Size(N, N))
        self.fimg = fft.fft2(resized[:])  # frequency-domain signal of the image
        mag_img = np.log10(np.abs(self.fimg))

        # Working arrays (builtin float == the removed np.float alias).
        filtered_img = np.zeros((N, N), dtype=float)
        self.mask = np.zeros((N, N), dtype=float)
        # cv view of self.mask, used to draw the lasso polygons into it.
        self.mask_img = cv.asMat(self.mask)

        # Data source shared by the three panels.
        self.data = ArrayPlotData(mag_img=fft.fftshift(mag_img),
                                  filtered_img=filtered_img,
                                  mask_img=self.mask)

        # Three image panels side by side.
        mag_plot, mag_renderer = self.make_image_plot("mag_img")
        mask_plot, _ = self.make_image_plot("mask_img")
        filtered_plot, _ = self.make_image_plot("filtered_img")
        self.plot = HPlotContainer(mag_plot, mask_plot, filtered_plot)

        # Lasso tool on the spectrum panel.
        lasso_selection = LassoSelection(component=mag_renderer)
        lasso_overlay = LassoOverlay(lasso_selection=lasso_selection,
                                     component=mag_renderer,
                                     selection_alpha=0.3)
        mag_renderer.tools.append(lasso_selection)
        mag_renderer.overlays.append(lasso_overlay)
        self.lasso_selection = lasso_selection

        # Listen for lasso changes and start the redraw timer.
        lasso_selection.on_trait_change(self.lasso_updated,
                                        "disjoint_selections")
        self.timer = Timer(50, self.on_timer)

    def make_image_plot(self, img_data):
        """Return ``(plot, image_renderer)`` for the named ``self.data`` entry."""
        p = Plot(self.data, aspect_ratio=1)
        p.x_axis.visible = False
        p.y_axis.visible = False
        p.padding = [1, 1, 1, 1]
        return p, p.img_plot(img_data, colormap=gray, origin="top left")[0]

    def lasso_updated(self):
        self.need_redraw = True

    def on_timer(self, *args):
        """Rebuild the mask from the lasso selections and refilter the image."""
        if not self.need_redraw:
            return
        self.need_redraw = False

        self.mask.fill(0)
        if len(self.lasso_selection.dataspace_points) == 0:
            return

        def convert_poly(poly):
            tmp = cv.asvector_Point2i(poly)
            return cv.vector_vector_Point2i([tmp])

        # Draw each lasso polygon, plus its point-mirrored copy (N - poly),
        # into the mask array.
        for poly in self.lasso_selection.disjoint_selections:
            poly = poly.astype(int)  # builtin int == the removed np.int alias
            cv.fillPoly(self.mask_img, convert_poly(poly),
                        cv.Scalar(1, 1, 1, 1))
            poly = N - poly
            cv.fillPoly(self.mask_img, convert_poly(poly),
                        cv.Scalar(1, 1, 1, 1))

        # Update the mask panel.
        self.data["mask_img"] = self.mask

        # Update the filtered panel: inverse FFT of the masked spectrum (the
        # mask is drawn in fftshift-ed coordinates, so shift it back).
        data = self.data["filtered_img"]
        data[:] = fft.ifft2(self.fimg * fft.fftshift(self.mask)).real
        self.data["filtered_img"] = data
Beispiel #12
0
class TriangleWave(HasTraits):
    """Triangle-wave FFT demo.

    Plots one period of a configurable triangle wave, its FFT magnitude in
    dB, and the wave re-synthesised from its first ``N`` frequencies.
    """
    # Narrowest and widest allowed wave width. Range cannot mix constants with
    # trait names, so these constants are stored as traits.
    low = Float(0.02)
    hi = Float(1.0)

    # Width of the triangle wave.
    wave_width = Range("low", "hi", 0.5)

    # x coordinate of the apex C.
    length_c = Range("low", "wave_width", 0.5)

    # y coordinate of the apex.
    height_c = Float(1.0)

    # Number of FFT samples, chosen from a list via an Enum.
    fftsize = Enum([(2**x) for x in range(6, 12)])

    # Upper x-axis limit of the FFT spectrum plot.
    fft_graph_up_limit = Range(0, 400, 20)

    # Text listing of the FFT peaks.
    peak_list = Str

    # Number of frequencies used to re-synthesise the wave.
    N = Range(1, 40, 4)

    # Data source for all plots.
    plot_data = Instance(AbstractPlotData)

    # Waveform plot.
    plot_wave = Instance(Component)

    # FFT spectrum plot.
    plot_fft = Instance(Component)

    # Container holding both plots.
    container = Instance(Component)

    # The window size must be given so the plot container initialises properly.
    # Labels: wave width, apex x, apex y, spectrum range, FFT points,
    # number of synthesis frequencies; title: "triangle wave FFT demo".
    view = View(HSplit(
        VSplit(
            VGroup(Item("wave_width", editor=scrubber, label="波形宽度"),
                   Item("length_c", editor=scrubber, label="最高点x坐标"),
                   Item("height_c", editor=scrubber, label="最高点y坐标"),
                   Item("fft_graph_up_limit", editor=scrubber, label="频谱图范围"),
                   Item("fftsize", label="FFT点数"), Item("N", label="合成波频率数")),
            Item("peak_list",
                 style="custom",
                 show_label=False,
                 width=100,
                 height=250)),
        VGroup(Item("container",
                    editor=ComponentEditor(size=(600, 300)),
                    show_label=False),
               orientation="vertical")),
                resizable=True,
                width=800,
                height=600,
                title="三角波FFT演示")

    def _create_plot(self, data, name, type="line"):
        """Create a Plot of ``data`` (a pair of data names) with pan/zoom tools."""
        p = Plot(self.plot_data)
        p.plot(data, name=name, title=name, type=type)
        p.tools.append(PanTool(p))
        zoom = ZoomTool(component=p, tool_mode="box", always_on=False)
        p.overlays.append(zoom)
        p.title = name
        return p

    def __init__(self):
        super(TriangleWave, self).__init__()

        # Empty series, created only so the plots can refer to these names.
        self.plot_data = ArrayPlotData(x=[], y=[], f=[], p=[], x2=[], y2=[])

        # Vertical container stacking the waveform and spectrum plots.
        self.container = VPlotContainer()

        # Waveform plot: original wave (x, y) and synthesised wave (x2, y2).
        self.plot_wave = self._create_plot(("x", "y"), "Triangle Wave")
        self.plot_wave.plot(("x2", "y2"), color="red")

        # Spectrum plot from f and p.
        self.plot_fft = self._create_plot(("f", "p"), "FFT", type="scatter")

        self.container.add(self.plot_wave)
        self.container.add(self.plot_fft)

        # Axis titles.
        self.plot_wave.x_axis.title = "Samples"
        self.plot_fft.x_axis.title = "Frequency pins"
        self.plot_fft.y_axis.title = "(dB)"

        # The Enum defaults to its first value (64); switching to 1024 also
        # triggers the first update_plot.
        self.fftsize = 1024

    def _fft_graph_up_limit_changed(self):
        """Apply the spectrum x-axis upper limit to the FFT plot."""
        self.plot_fft.x_axis.mapper.range.high = self.fft_graph_up_limit

    def _N_changed(self):
        self.plot_sin_combine()

    @on_trait_change("wave_width, length_c, height_c, fftsize")
    def update_plot(self):
        """Recompute the wave, its spectrum, peak list and synthesised wave."""
        # Kept as a module global (existing module-level side effect).
        global y_data
        x_data = np.arange(0, 1.0, 1.0 / self.fftsize)
        func = self.triangle_func()
        # frompyfunc yields an object array; convert to float64 (np.cast was
        # removed in NumPy 2.0).
        y_data = np.asarray(func(x_data), dtype=np.float64)

        # Spectrum, normalised by the number of samples.
        fft_parameters = np.fft.fft(y_data) / len(y_data)

        # Amplitude in dB of frequencies 0..fftsize/2, clipped to +/-120.
        # Integer division keeps the slice index an int on Python 3.
        fft_data = np.clip(
            20 * np.log10(np.abs(fft_parameters))[:self.fftsize // 2 + 1], -120,
            120)

        self.plot_data.set_data("x", np.arange(0, self.fftsize))  # sample index
        self.plot_data.set_data("y", y_data)
        self.plot_data.set_data("f", np.arange(0, len(fft_data)))  # frequency index
        self.plot_data.set_data("p", fft_data)

        # The synthesised wave spans two periods of samples.
        self.plot_data.set_data("x2", np.arange(0, 2 * self.fftsize))

        self._fft_graph_up_limit_changed()

        # List (at most 20) frequencies whose amplitude exceeds -80 dB.
        peak_index = (fft_data > -80)
        peak_value = fft_data[peak_index][:20]
        self.peak_list = "\n".join(
            "%s : %s" % (f, v)
            for f, v in zip(np.flatnonzero(peak_index), peak_value))

        # Keep the FFT result and recompute the synthesised wave.
        self.fft_parameters = fft_parameters
        self.plot_sin_combine()

    def plot_sin_combine(self):
        """Synthesise two periods of the wave from its first ``N`` frequencies."""
        index, data = fft_combine(self.fft_parameters, self.N, 2)
        self.plot_data.set_data("y2", data)

    def triangle_func(self):
        """Return a ufunc evaluating the current triangle wave (period 1)."""
        c = self.wave_width
        c0 = self.length_c
        hc = self.height_c

        def trifunc(x):
            x = x - int(x)  # period is 1, so only the fractional part matters
            if x >= c: r = 0.0
            elif x < c0: r = x / c0 * hc
            else: r = (c - x) / (c - c0) * hc
            return r

        # frompyfunc lets trifunc operate on arrays, but its result is an
        # object array that callers must convert.
        return np.frompyfunc(trifunc, 1, 1)
Beispiel #13
0
class EqualizerDesigner(HasTraits):
    '''Main window of the equalizer designer: editors plus gain/phase plots.'''

    equalizers = Instance(Equalizers)

    # Data source for the plots.
    plot_data = Instance(AbstractPlotData)

    # Container holding the phase and gain plots.
    container = Instance(Component)

    plot_gain = Instance(Component)
    plot_phase = Instance(Component)
    save_button = Button("Save")
    load_button = Button("Load")
    export_button = Button("Export")

    view = View(VGroup(
        HGroup(Item("load_button"),
               Item("save_button"),
               Item("export_button"),
               show_labels=False),
        HSplit(
            VGroup(
                Item("equalizers", style="custom", show_label=False),
                show_border=True,
            ),
            Item("container",
                 editor=ComponentEditor(size=(800, 300)),
                 show_label=False),
        )),
                resizable=True,
                width=800,
                height=500,
                title="Equalizer Designer")

    def _create_plot(self, data, name, type="line"):
        """Create a log-frequency Plot of ``data`` with pan/zoom tools."""
        p = Plot(self.plot_data)
        p.plot(data, name=name, title=name, type=type)
        p.tools.append(PanTool(p))
        zoom = ZoomTool(component=p, tool_mode="box", always_on=False)
        p.overlays.append(zoom)
        p.title = name
        p.index_scale = "log"
        return p

    def __init__(self):
        super(EqualizerDesigner, self).__init__()
        self.plot_data = ArrayPlotData(f=FREQS, gain=[], phase=[])
        self.plot_gain = self._create_plot(("f", "gain"), "Gain(dB)")
        self.plot_phase = self._create_plot(("f", "phase"), "Phase(degree)")
        self.container = VPlotContainer()
        self.container.add(self.plot_phase)
        self.container.add(self.plot_gain)
        self.plot_gain.padding_bottom = 20
        self.plot_phase.padding_top = 20

    def _equalizers_default(self):
        return Equalizers()

    @on_trait_change("equalizers.h")
    def redraw(self):
        """Replot gain (dB) and phase (degrees) from the response ``equalizers.h``."""
        gain = 20 * np.log10(np.abs(self.equalizers.h))
        phase = np.angle(self.equalizers.h, deg=1)
        self.plot_data.set_data("gain", gain)
        self.plot_data.set_data("phase", phase)

    def _save_button_fired(self):
        """Pickle the current equalizers to a user-chosen ``.eq`` file."""
        dialog = FileDialog(action="save as", wildcard='EQ files (*.eq)|*.eq')
        result = dialog.open()
        if result == OK:
            # ``with`` closes the file even if pickling fails; ``file()`` is
            # Python 2 only.
            with open(dialog.path, "wb") as f:
                pickle.dump(self.equalizers, f)

    def _load_button_fired(self):
        """Load equalizers from a user-chosen ``.eq`` file."""
        dialog = FileDialog(action="open", wildcard='EQ files (*.eq)|*.eq')
        result = dialog.open()
        if result == OK:
            # SECURITY NOTE: pickle.load can execute arbitrary code; only
            # open .eq files from trusted sources.
            with open(dialog.path, "rb") as f:
                self.equalizers = pickle.load(f)

    def _export_button_fired(self):
        """Export the equalizers to a user-chosen ``.c`` file."""
        dialog = FileDialog(action="save as", wildcard='c files (*.c)|*.c')
        result = dialog.open()
        if result == OK:
            self.equalizers.export(dialog.path)
Beispiel #14
0
class Figure(HasTraits):
    """Image figure with pan/zoom, a selection printer and a pixel inspector.

    Setting ``file`` loads and displays an image. ``plot_data`` overlays
    named x/y line series on top of it.
    """
    pd = Instance(ArrayPlotData, transient = True)
    plot = Instance(Component,transient = True)
    # Callback handed to DataPrinter as its ``process`` function.
    process_selection = Function(transient = True)
    file = File('/home/andrej/Pictures/img_4406.jpg')

    traits_view = View(
                    Group(
                        Item('plot', editor=ComponentEditor(size=size,
                                                            bgcolor=bg_color),
                             show_label=False),
                        orientation = "vertical"),
                    resizable = True
                    )

    def __init__(self,**kwds):
        super(Figure,self).__init__(**kwds)
        # Build the data source and plot eagerly rather than on first access.
        self.pd = self._pd_default()
        self.plot = self._plot_default()

    def _process_selection_default(self):
        def process(point0, point1):
            print('selection', point0, point1)
        return process

    def _pd_default(self):
        # Start with a blank 300 x 400 image.
        image = zeros(shape = (300,400))
        pd = ArrayPlotData()
        pd.set_data("imagedata", toRGB(image))
        return pd

    def _plot_default(self):
        return self._create_plot_component()

    def _create_plot_component(self):
        """Build the image plot with its tools and overlays."""
        pd = self.pd

        # Create the plot; aspect ratio follows the image (columns / rows).
        plot = Plot(pd, default_origin="top left",orientation="h")
        shape = pd.get_data('imagedata').shape
        plot.aspect_ratio = float(shape[1]) / shape[0]
        plot.x_axis.orientation = "top"
        img_plot = plot.img_plot("imagedata",name = 'image', colormap = jet)[0]

        # Tweak some of the plot properties.
        plot.bgcolor = bg_color

        # Attach tools: pan (right drag), selection printer, box zoom.
        plot.tools.append(PanTool(plot,constrain_key="shift", drag_button = 'right'))
        printer = DataPrinter(component=plot, process = self.process_selection)
        plot.tools.append(printer)
        plot.overlays.append(ZoomTool(component=plot,
                                  tool_mode="box", always_on=False))

        # Pixel inspector on the image renderer.
        imgtool = ImageInspectorTool(img_plot)
        img_plot.tools.append(imgtool)
        plot.overlays.append(ImageInspectorOverlay(component=img_plot,
                                               image_inspector=imgtool))
        return plot

    def _file_changed(self, new):
        image = ImageData.fromfile(new)
        self.update_image(image.data)

    def update_image(self,data):
        """Replace the displayed image with ``data`` and rebuild its inspector."""
        image = toRGB(data)
        shape = image.shape
        self.pd.set_data("imagedata", image)
        self.plot.aspect_ratio = float(shape[1]) / shape[0]
        self.plot.delplot('image')
        img_plot = self.plot.img_plot("imagedata",name = 'image', colormap = jet)[0]
        imgtool = ImageInspectorTool(img_plot)
        img_plot.tools.append(imgtool)
        # Drop the old inspector overlay(s) only. Popping the last overlay
        # would remove an unrelated overlay if one had been added after it.
        stale = [overlay for overlay in self.plot.overlays
                 if isinstance(overlay, ImageInspectorOverlay)]
        for overlay in stale:
            self.plot.overlays.remove(overlay)
        self.plot.overlays.append(ImageInspectorOverlay(component=img_plot,
                                               image_inspector=imgtool))
        self.plot.request_redraw()

    def plot_data(self, x, y, name = 'data 0', color = 'black'):
        """Plot the line ``(x, y)`` as ``name``, replacing any plot of that name."""
        xname = 'x_' + name
        yname = 'y_' + name
        self.pd.set_data(xname,x)
        self.pd.set_data(yname,y)
        self.del_plot(name)
        self.plot.plot((xname,yname), name = name, color = color)
        self.plot.request_redraw()

    def del_plot(self, name):
        """Remove the sub-plot ``name`` if it exists; do nothing otherwise."""
        # Test membership instead of a bare ``except: pass``, which also
        # swallowed unrelated errors.
        if name in self.plot.plots:
            self.plot.delplot(name)
Beispiel #15
0
class Demo(HasTraits):
    plot = Instance(Component)
    fileName = "clusters.cpickle"
    case = List(UncertaintyValue)

    cases = {}
    defaultCase = []

    # Attributes to use for the plot view.
    size = (400, 1600)

    traits_view = View(Group(
        Group(Item('plot', editor=ComponentEditor(size=size),
                   show_label=False),
              orientation="vertical",
              show_border=True,
              scrollable=True),
        Group(Item('case',
                   editor=TabularEditor(adapter=CaseAdapter(can_edit=False)),
                   show_label=False),
              orientation="vertical",
              show_border=True),
        layout='split',
        orientation='horizontal'),
                       title='Interactive Lines',
                       resizable=True)

    def setFileName(self, newName):
        """Set ``fileName``, the pickle file that ``_plot_default`` loads results from.

        :param newName: path of the file to load.
        """
        self.fileName = newName

    def _update_case(self, name):
        """Show the feature values stored for the run labelled ``name``.

        Labels have the form ``"y<cluster>-<run>"`` (see ``_plot_default``).
        An empty/None ``name``, or a label with no entry in ``self.cases``,
        shows ``self.defaultCase`` instead.
        """
        if not name:
            self.case = self.defaultCase
            return
        # cases.get(name) used to return None for an unknown label, which was
        # then assigned to the List trait (presumably a TraitError). Fall back
        # to the default case instead.
        self.case = self.cases.get(name, self.defaultCase)

    def _plot_default(self):
        """Build the stacked per-cluster line plots from ``self.fileName``.

        The file is a pickled list with one entry per cluster; each entry is
        a list of runs in the 'results' format, where ``run[0][0]`` is a dict
        of clustering features and ``run[1]`` maps outcome names (including
        ``'TIME'``) to their series.

        Side effects: fills ``self.cases`` with the feature values of every
        run (keyed ``'y<cluster>-<run>'``, matching the line names used in
        the plots) and sets ``self.case`` / ``self.defaultCase`` to a
        placeholder case.

        Returns a ``GridContainer`` with one plot per cluster in a single
        column.
        """
        # NOTE(review): unpickling can execute arbitrary code; only load
        # results files produced by a trusted source.
        with open(self.fileName, 'rb') as f:
            resultsList = cPickle.load(f)

        # One outcome per cluster, taken from the cluster's first run.
        # 'TIME' is the shared x axis, so it is skipped whenever another
        # outcome is available.
        outcome = []
        for results in resultsList:
            outcomes = results[0][1]
            names = [key for key in outcomes if key != 'TIME']
            outcome.append(names[0] if names else list(outcomes)[0])

        # Shared time axis, taken from the first run of the first cluster.
        x = resultsList[0][0][1]['TIME']

        # Number of clustering features stored for each run.
        noFeatures = len(resultsList[0][0][0][0])

        # Register every run's features as a case, labelled e.g. 'y1-2' for
        # the 3rd run in the 2nd cluster (both indices are zero-based).
        for c, results in enumerate(resultsList):
            for j, run in enumerate(results):
                self.cases['y%d-%d' % (c, j)] = [
                    UncertaintyValue(name=key, value=value)
                    for key, value in run[0][0].items()
                ]

        # Placeholder case shown until a line is selected; one row per
        # feature.  The same list object backs both traits.
        case = [UncertaintyValue(name='Default', value='None')
                for _ in range(noFeatures)]
        self.case = case
        self.defaultCase = case

        # One ArrayPlotData per cluster: the shared index plus the selected
        # outcome series of each run in that cluster.
        pds = []
        for c, results in enumerate(resultsList):
            pd = ArrayPlotData(index=x)
            for j, run in enumerate(results):
                pd.set_data('y%d-%d' % (c, j),
                            np.array(run[1].get(outcome[c])))
            pds.append(pd)

        container = GridContainer(bgcolor="lightgray",
                                  use_backbuffer=True,
                                  shape=(len(resultsList), 1))

        tools = []
        for c, results in enumerate(resultsList):
            plot = Plot(pds[c],
                        title='Cluster ' + str(c),
                        border_visible=True,
                        border_width=1)
            plot.legend.visible = False

            # One line per run in this cluster.
            for i in range(len(results)):
                plotvalue = 'y%d-%d' % (c, i)
                plot.plot(("index", plotvalue),
                          name=plotvalue,
                          color=colors[i % len(colors)])

            # Make sure that the time axis runs in the right direction.
            for renderers in plot.plots.values():
                for renderer in renderers:
                    renderer.index.sort_order = 'ascending'

            # Line selector used to pick a run (and hence a case).
            selector = LineSelectorTool(component=plot)
            plot.tools.append(selector)
            tools.append(selector)

            # Pan and box-zoom navigation.
            plot.tools.append(PanTool(plot))
            plot.overlays.append(
                ZoomTool(component=plot, tool_mode="box", always_on=False))

            container.add(plot)

        # Give every selector tool a back-reference to this viewer.
        for tool in tools:
            tool._demo = self

        return container
Beispiel #16
0
class VariableMeshPannerView(HasTraits):
    """Pan-and-scan viewer for a ``VMImagePlot`` with an interactive colorbar.

    ``full_container`` holds a ``ColorBar`` followed by ``container``, which
    holds the variable-mesh image plot.  ``self.panner`` is forwarded to the
    ``VMImagePlot`` but is not declared as a trait here; presumably callers
    pass it as a constructor keyword (TODO confirm).
    """

    # Declared but never assigned in this class.
    plot = Instance(Plot)
    # Button trait; no handler for it is defined in this class.
    spawn_zoom = Button
    # The variable-mesh image plot, built by _vm_plot_default.
    vm_plot = Instance(VMImagePlot)
    # When True, __init__ attaches pan/zoom and image-inspector tools.
    use_tools = Bool(True)
    # Top-level container: the colorbar followed by ``container``.
    full_container = Instance(HPlotContainer)
    # Overlay container holding ``vm_plot.plot``.
    container = Instance(OverlayPlotContainer)

    # 'field' is not a class-level trait: __init__ adds it as a delegate
    # to ``vm_plot`` before the view can be shown.
    traits_view = View(
        Group(Item('full_container',
                   editor=ComponentEditor(size=(512, 512)),
                   show_label=False),
              Item('field', show_label=False),
              orientation="vertical"),
        width=800,
        height=800,
        resizable=True,
        title="Pan and Scan",
    )

    def _vm_plot_default(self):
        """Create the image plot from ``self.panner``."""
        return VMImagePlot(panner=self.panner)

    def __init__(self, **kwargs):
        """Wire up tools, the colorbar and the container layout."""
        super(VariableMeshPannerView, self).__init__(**kwargs)
        # Expose vm_plot's 'field' as a trait of this object so that
        # traits_view can display it.
        self.add_trait("field", DelegatesTo("vm_plot"))

        plot = self.vm_plot.plot
        img_plot = self.vm_plot.img_plot

        if self.use_tools:
            # Pan and box-zoom on the image.
            # NOTE(review): the ZoomTool's component is img_plot but it is
            # appended to plot.overlays -- confirm this is intended.
            plot.tools.append(PanTool(img_plot))
            zoom = ZoomTool(component=img_plot,
                            tool_mode="box",
                            always_on=False)
            plot.overlays.append(zoom)
            # Image inspector tool plus the overlay that displays its readout.
            imgtool = ImageInspectorTool(img_plot)
            img_plot.tools.append(imgtool)
            overlay = ImageInspectorOverlay(component=img_plot,
                                            image_inspector=imgtool,
                                            bgcolor="white",
                                            border_visible=True)
            img_plot.overlays.append(overlay)

        # Colorbar whose index axis spans the data range of vm_plot.fid
        # (presumably the image data source -- confirm).  It is stored as a
        # plain attribute, not a declared trait.
        image_value_range = DataRange1D(self.vm_plot.fid)
        cbar_index_mapper = LinearMapper(range=image_value_range)
        self.colorbar = ColorBar(index_mapper=cbar_index_mapper,
                                 plot=img_plot,
                                 padding_right=40,
                                 resizable='v',
                                 width=30)

        # Vertical-only panning, plus a right-drag range zoom along the
        # colorbar's index axis.
        self.colorbar.tools.append(
            PanTool(self.colorbar, constrain_direction="y", constrain=True))
        zoom_overlay = ZoomTool(self.colorbar,
                                axis="index",
                                tool_mode="range",
                                always_on=True,
                                drag_button="right")
        self.colorbar.overlays.append(zoom_overlay)

        # create a range selection for the colorbar
        range_selection = RangeSelection(component=self.colorbar)
        self.colorbar.tools.append(range_selection)
        self.colorbar.overlays.append(
            RangeSelectionOverlay(component=self.colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray"))

        # we also want to the range selection to inform the cmap plot of
        # the selection, so set that up as well
        range_selection.listeners.append(img_plot)

        # Layout: the colorbar, then the overlay container holding the image.
        self.full_container = HPlotContainer(padding=30)
        self.container = OverlayPlotContainer(padding=0)
        self.full_container.add(self.colorbar)
        self.full_container.add(self.container)
        self.container.add(self.vm_plot.plot)
Beispiel #17
0
class Figure(HasTraits):
    """Chaco-based figure for displaying images, with optional line overlays.

    Note that grayscale ``uint8`` images do not seem to work because of a bug
    in numpy 1.4.x; grayscale float images work.

    By default (``colormap == 'gray'``) images are passed through ``to_RGB``
    and displayed as RGB grayscale (much faster).

    >>> from scipy.misc.pilutil import imread
    >>> im = imread('testdata/noise.png')
    >>> f = Figure()
    >>> f.plot_image(im)
    >>> result = f.configure_traits()
    >>> f.plot_image(im[:,:,0])
    >>> result = f.configure_traits()
    """
    #: data source shared by the image and any overlaid line plots
    pd = Instance(ArrayPlotData, transient=True)
    #: the Chaco plot that renders the image
    plot = Instance(Plot, transient=True)
    #: callback ``process(point0, point1)`` handed to the DataPrinter tool
    process_selection = Function(transient=True)

    #: defines which color map to use for grayscale images ('gray' or 'color')
    colormap = Enum('gray', 'color')

    traits_view = View(Group(Item('plot',
                                  editor=ComponentEditor(size=size,
                                                         bgcolor=bg_color),
                                  show_label=False),
                             orientation="vertical"),
                       resizable=True)

    def __init__(self, **kwds):
        super(Figure, self).__init__(**kwds)
        # Build the data source and the plot eagerly (this replaces any
        # ``pd``/``plot`` passed in ``kwds``) so plot_image() works at once.
        self.pd = self._pd_default()
        self.plot = self._plot_default()

    def _process_selection_default(self):
        """Return the default selection callback, which prints both points."""
        def process(point0, point1):
            print('selection', point0, point1)

        return process

    def _pd_default(self):
        """Return plot data holding an empty ``'imagedata'`` array."""
        image = numpy.array([])
        pd = ArrayPlotData()
        pd.set_data("imagedata", image)
        return pd

    def _plot_default(self):
        """Return a newly created plot component."""
        return self._create_plot_component()

    def _create_plot_component(self):
        """Create the image plot with a DataPrinter tool and a box zoom.

        The plot uses a top-left origin with the x axis drawn at the top.
        """
        pd = self.pd
        plot = Plot(pd, default_origin="top left", orientation="h")
        plot.x_axis.orientation = "top"
        plot.padding = 20
        plot.bgcolor = bg_color

        # The DataPrinter forwards selections to ``self.process_selection``.
        printer = DataPrinter(component=plot, process=self.process_selection)
        plot.tools.append(printer)
        plot.overlays.append(
            ZoomTool(component=plot, tool_mode="box", always_on=False))
        return plot

    def plot_image(self, data):
        """Plots image from a given array

        :param array data:
            Input image array
        """
        self._plot_image(self.plot, data)

    def _plot_image(self, plot, data):
        """Show ``data`` in ``plot``, creating the image renderer on first use."""
        if self.colormap == 'gray':
            image = to_RGB(data)
        else:
            image = data
        self.pd.set_data("imagedata", image)
        # Match the plot's aspect ratio to the image (width / height).
        plot.aspect_ratio = float(image.shape[1]) / image.shape[0]
        if not plot.plots:
            plot.img_plot("imagedata", name='image')

        img_plot = plot.plots['image'][0]
        # NOTE(review): this opens a traits editor for the image renderer on
        # every call; confirm it is intended and not leftover debugging.
        img_plot.edit_traits()
        plot.redraw()

    def update_image(self, data):
        """Deprecated alias of :meth:`plot_image`."""
        import warnings
        warnings.warn('Use plot_image insteead!', DeprecationWarning, 2)
        self.plot_image(data)

    def plot_data(self, x, y, name='data 0', color='black'):
        """Plots additional line data on top of the image.

        :param x:
            x data
        :param y:
            y data
        :param name:
            name of a plot, you can call :meth:`del_plot` to delete it later
        :param color:
            color of plotted data
        """
        self._plot_data(self.plot, x, y, name, color)

    def _plot_data(self, plot, x, y, name='data 0', color='black'):
        """Store ``x``/``y`` as ``x_<name>``/``y_<name>`` and (re)draw line ``name``."""
        xname = 'x_' + name
        yname = 'y_' + name
        self.pd.set_data(xname, x)
        self.pd.set_data(yname, y)
        # Replace any existing plot of the same name.
        self._del_plot(plot, name)
        plot.plot((xname, yname), name=name, color=color)
        plot.request_redraw()

    def del_plot(self, name):
        """Delete a plot; an unknown ``name`` is silently ignored."""
        self._del_plot(self.plot, name)

    def _del_plot(self, plot, name):
        """Remove sub-plot ``name`` from ``plot`` if it exists."""
        try:
            plot.delplot(name)
        except KeyError:
            # Chaco's Plot.delplot raises KeyError for an unknown name,
            # which is expected e.g. the first time _plot_data uses a name.
            # Any other error is a real failure and now propagates.
            pass
Beispiel #18
0
#    else:
    return image

#===============================================================================
# Shared settings for the plot views defined below.
size = (600, 600)
title = "Simple image plot"
bg_color = "lightgray"

#===============================================================================
# View for embedding a figure (its 'file' trait plus the plot) in other
# traits UIs.
#===============================================================================

figure_view = View(
    Group(
        'file',
        Item('plot',
             editor=ComponentEditor(size=size, bgcolor=bg_color),
             show_label=False),
        orientation="vertical",
    ),
    resizable=True,
    title=title,
)


class Figure(HasTraits):
    pd = Instance(ArrayPlotData, transient = True)
    plot = Instance(Component,transient = True)
    process_selection = Function(transient = True)
    file = File('/home/andrej/Pictures/img_4406.jpg')
    
    traits_view = View(
                    Group(
                        Item('plot', editor=ComponentEditor(size=size,
Beispiel #19
0
class FigureInspector(Figure):
    """Same as :class:`Figure` in addition it displays horizontal and vertical line inspectors

    ``h_plot`` shows the image row and ``v_plot`` the image column at the
    position selected by the two ``LineInspector`` overlays.
    """
    container = Instance(Component, transient=True)
    h_plot = Instance(Plot)
    v_plot = Instance(Plot)

    traits_view = View(Group(Item('container',
                                  editor=ComponentEditor(size=size,
                                                         bgcolor=bg_color),
                                  show_label=False),
                             orientation="vertical"),
                       resizable=True)

    def _plot_image(self, plot, data):
        """Show ``data`` in ``plot`` and resize the profile plots to match.

        On the first call the image renderer is created together with two
        ``LineInspector`` overlays that write their selection into the image
        index metadata; changes there are handled by ``_metadata_changed``.
        """
        if self.colormap == 'gray':
            image = to_RGB(data)
        else:
            image = data
        self.pd.set_data("imagedata", image)
        plot.aspect_ratio = float(image.shape[1]) / image.shape[0]
        if not plot.plots:
            img_plot = plot.img_plot("imagedata", name='image')[0]

            img_plot.index.on_trait_change(self._metadata_changed,
                                           "metadata_changed")

            # One inspector line per image axis (x first, then y).
            for axis in ('index_x', 'index_y'):
                img_plot.overlays.append(
                    LineInspector(component=img_plot,
                                  axis=axis,
                                  inspect_mode="indexed",
                                  write_metadata=True,
                                  is_listener=False,
                                  color="white"))

        img_plot = plot.plots['image'][0]
        height, width = image.shape[:2]
        # Keep the profile plots' index axes linked to the image axes.
        self.h_plot.index_range = img_plot.index_range.x_range
        self.v_plot.index_range = img_plot.index_range.y_range
        self.pd.set_data('h_index', numpy.arange(width))
        self.pd.set_data('v_index', numpy.arange(height))
        self.plot.request_redraw()
        self.container.request_redraw()

    def _pd_default(self):
        """Return plot data with a 300x400 placeholder image and profiles."""
        image = numpy.ones(shape=(300, 400))
        pd = ArrayPlotData()
        pd.set_data("imagedata", image)
        pd.set_data('h_index', numpy.arange(400))
        pd.set_data('h_value', numpy.ones((400, )))
        pd.set_data('v_index', numpy.arange(300))
        pd.set_data('v_value', numpy.ones((300, )))
        return pd

    def _h_plot_default(self):
        """Create the horizontal profile plot shown above the image."""
        plot = Plot(self.pd, resizable="h")
        plot.height = 100
        plot.padding = 20
        plot.plot(("h_index", "h_value"))
        return plot

    def _v_plot_default(self):
        """Create the vertical profile plot shown beside the image."""
        plot = Plot(self.pd,
                    orientation="v",
                    resizable="v",
                    padding=20,
                    padding_bottom=160,
                    default_origin="top left")
        plot.height = 600
        plot.width = 100
        plot.plot(("v_index", "v_value"))
        return plot

    def _container_default(self):
        """Lay out [h_plot over image] next to v_plot."""
        container = HPlotContainer(padding=40,
                                   fill_padding=True,
                                   bgcolor="white",
                                   use_backbuffer=False)
        inner_cont = VPlotContainer(padding=0, use_backbuffer=True)
        inner_cont.add(self.h_plot)
        inner_cont.add(self.plot)
        container.add(inner_cont)
        container.add(self.v_plot)
        return container

    def _metadata_changed(self, old, new):
        """Update the profile plots from the line-inspector selection.

        Copies the image row at the selected y index into ``h_value`` and
        the column at the selected x index into ``v_value`` (first channel
        only for images with a channel axis), then fits each profile plot's
        value range to its slice.  If no selection is stored, both profiles
        are cleared.
        """
        img_plot = self.plot.plots['image'][0]
        image_index = img_plot.index
        image_value = img_plot.value
        if "selections" in image_index.metadata:
            x_ndx, y_ndx = image_index.metadata["selections"]
            # Compare against None: index 0 (first row/column) is a valid
            # selection even though it is falsy.
            if y_ndx is not None and x_ndx is not None:
                data = image_value.data
                h_slice = data[y_ndx, :]
                v_slice = data[:, x_ndx]
                if data.ndim > 2:
                    # Image with a channel axis: profile the first channel.
                    h_slice = h_slice[:, 0]
                    v_slice = v_slice[:, 0]
                self.pd.set_data("h_value", h_slice)
                self.h_plot.value_range.low = h_slice.min()
                self.h_plot.value_range.high = h_slice.max()
                self.pd.set_data("v_value", v_slice)
                self.v_plot.value_range.low = v_slice.min()
                self.v_plot.value_range.high = v_slice.max()

        else:
            self.pd.set_data("h_value", numpy.array([]))
            self.pd.set_data("v_value", numpy.array([]))
Beispiel #20
0
        return image


#===============================================================================
# Shared settings for the plot views defined below.
size = (500, 500)
title = "Simple image plot"
bg_color = "lightgray"

#===============================================================================
# View for embedding a figure (its 'file' trait plus the plot) in other
# traits UIs.
#===============================================================================

figure_view = View(
    Group(
        'file',
        Item('plot',
             editor=ComponentEditor(size=size, bgcolor=bg_color),
             show_label=False),
        orientation="vertical",
    ),
    resizable=True,
    title=title,
)


class Figure(HasTraits):
    """For displaying images, note that grayscale uint8 does not seem to work because of a bug in numpy 1.4.x
    graysale float works!
    
    By default grayscale images are displayed as a RGB grayscale (much faster)
    
    >>> from scipy.misc.pilutil import imread
    >>> im = imread('testdata/noise.png')
    >>> f = Figure() 
Beispiel #21
0
class InPaintDemo(HasTraits):
    """Interactive OpenCV inpainting demo.

    The user paints a selection mask onto the image with a ``CirclePainter``.
    The masked region of ``img`` is filled in with ``cv.inpaint`` into the
    preview ``img2``, which is shown in the plot.  "Apply" copies the
    preview back into ``img`` and clears the selection.
    """
    plot = Instance(Plot)
    painter = Instance(CirclePainter)
    r = Range(2.0, 20.0, 10.0)  # radius parameter passed to cv.inpaint
    method = Enum("INPAINT_NS", "INPAINT_TELEA")  # inpaint algorithm (name of a cv constant)
    show_mask = Bool(False)  # whether to overlay the selection mask
    clear_mask = Button("清除选区")  # label: "Clear selection"
    apply = Button("保存结果")  # label: "Save result"

    # Control panel: brush radius ("画笔半径"), inpaint radius, algorithm,
    # "show selection" toggle and the two buttons, with the image plot below.
    # Window title means "inpaint Demo control panel".
    view = View(VGroup(
        VGroup(
            Item("object.painter.r", label="画笔半径"), Item("r",
                                                         label="inpaint半径"),
            HGroup(
                Item("method", label="inpaint算法"),
                Item("show_mask", label="显示选区"),
                Item("clear_mask", show_label=False),
                Item("apply", show_label=False),
            )),
        Item("plot", editor=ComponentEditor(), show_label=False),
    ),
                title="inpaint Demo控制面板",
                width=500,
                height=450,
                resizable=True)

    def __init__(self, *args, **kwargs):
        super(InPaintDemo, self).__init__(*args, **kwargs)
        self.img = cv.imread("stuff.jpg")  # original (working) image
        self.img2 = self.img.clone()  # preview image holding the inpaint result
        self.mask = cv.Mat(self.img.size(), cv.CV_8UC1)  # single-channel image storing the selection
        self.mask[:] = 0
        # Channel order is reversed for display -- presumably BGR -> RGB.
        self.data = ArrayPlotData(img=self.img[:, :, ::-1])
        # Keep the plot's aspect ratio equal to the image's width / height.
        self.plot = Plot(self.data,
                         padding=10,
                         aspect_ratio=float(self.img.size().width) /
                         self.img.size().height)
        self.plot.x_axis.visible = False
        self.plot.y_axis.visible = False
        imgplot = self.plot.img_plot("img", origin="top left")[0]
        # The painter draws on top of the image and records the brush track.
        self.painter = CirclePainter(component=imgplot)
        imgplot.overlays.append(self.painter)

    @on_trait_change("r,method")
    def inpaint(self):
        """Recompute ``img2`` by inpainting ``img`` under ``mask``, then redraw."""
        cv.inpaint(self.img, self.mask, self.img2, self.r,
                   getattr(cv, self.method))
        self.draw()

    @on_trait_change("painter:updated")
    def painter_updated(self):
        """Burn the painter's stroke into the mask and refresh the preview."""
        # Each track entry's last two items are used as the circle centre.
        for _, _, x, y in self.painter.track:
            # Draw a circle onto the mask that stores the selection.
            cv.circle(self.mask,
                      cv.Point(int(x), int(y)),
                      int(self.painter.r),
                      cv.Scalar(255, 255, 255, 255),
                      thickness=-1)  # negative thickness means a filled circle
        self.inpaint()
        # Reset the stroke so it is not applied again on the next update.
        self.painter.track = []
        self.painter.request_redraw()

    def _clear_mask_fired(self):
        """Clear the selection mask and refresh the preview."""
        self.mask[:] = 0
        self.inpaint()

    def _apply_fired(self):
        """Keep the inpaint result and clear the selection."""
        self.img[:] = self.img2[:]
        self._clear_mask_fired()

    @on_trait_change("show_mask")
    def draw(self):
        """Update the displayed image.

        With ``show_mask`` set, shows the original ``img`` with masked
        pixels painted white; otherwise shows the inpainted preview ``img2``.
        """
        if self.show_mask:
            data = self.img[:, :, ::-1].copy()
            data[self.mask[:] > 0] = 255
            self.data["img"] = data
        else:
            self.data["img"] = self.img2[:, :, ::-1]
Beispiel #22
0
class FitGui(HasTraits):
    """
    This class represents the fitgui application state.
    """

    plot = Instance(Plot)
    colorbar = Instance(ColorBar)
    plotcontainer = Instance(HPlotContainer)
    tmodel = Instance(TraitedModel,allow_none=False)
    nomodel = Property
    newmodel = Button('New Model...')
    fitmodel = Button('Fit Model')
    showerror = Button('Fit Error')
    updatemodelplot = Button('Update Model Plot')
    autoupdate = Bool(True)
    data = Array(dtype=float,shape=(2,None))
    weights = Array
    weighttype = Enum(('custom','equal','lin bins','log bins'))
    weightsvary = Property(Bool)
    weights0rem = Bool(True)
    modelselector = NewModelSelector
    ytype = Enum(('data and model','residuals'))

    zoomtool = Instance(ZoomTool)
    pantool = Instance(PanTool)

    scattertool = Enum(None,'clicktoggle','clicksingle','clickimmediate','lassoadd','lassoremove','lassoinvert')
    selectedi = Property #indecies of the selected objects
    weightchangesel = Button('Set Selection To')
    weightchangeto = Float(1.0)
    delsel = Button('Delete Selected')
    unselectonaction = Bool(True)
    clearsel = Button('Clear Selections')
    lastselaction = Str('None')

    datasymb = Button('Data Symbol...')
    modline = Button('Model Line...')

    savews = Button('Save Weights')
    loadws = Button('Load Weights')
    _savedws = Array

    plotname = Property
    updatestats = Event
    chi2 = Property(Float,depends_on='updatestats')
    chi2r = Property(Float,depends_on='updatestats')


    nmod = Int(1024)
    #modelpanel = View(Label('empty'),kind='subpanel',title='model editor')
    modelpanel = View

    panel_view = View(VGroup(
                       Item('plot', editor=ComponentEditor(),show_label=False),
                       HGroup(Item('tmodel.modelname',show_label=False,style='readonly'),
                              Item('nmod',label='Number of model points'),
                              Item('updatemodelplot',show_label=False,enabled_when='not autoupdate'),
                              Item('autoupdate',label='Auto?'))
                      ),
                    title='Model Data Fitter'
                    )


    selection_view = View(Group(
                           Item('scattertool',label='Selection Mode',
                                 editor=EnumEditor(values={None:'1:No Selection',
                                                           'clicktoggle':'3:Toggle Select',
                                                           'clicksingle':'2:Single Select',
                                                           'clickimmediate':'7:Immediate',
                                                           'lassoadd':'4:Add with Lasso',
                                                           'lassoremove':'5:Remove with Lasso',
                                                           'lassoinvert':'6:Invert with Lasso'})),
                           Item('unselectonaction',label='Clear Selection on Action?'),
                           Item('clearsel',show_label=False),
                           Item('weightchangesel',show_label=False),
                           Item('weightchangeto',label='Weight'),
                           Item('delsel',show_label=False)
                         ),title='Selection Options')

    traits_view = View(VGroup(
                        HGroup(Item('object.plot.index_scale',label='x-scaling',
                                    enabled_when='object.plot.index_mapper.range.low>0 or object.plot.index_scale=="log"'),
                              spring,
                              Item('ytype',label='y-data'),
                              Item('object.plot.value_scale',label='y-scaling',
                                   enabled_when='object.plot.value_mapper.range.low>0 or object.plot.value_scale=="log"')
                              ),
                       Item('plotcontainer', editor=ComponentEditor(),show_label=False),
                       HGroup(VGroup(HGroup(Item('weighttype',label='Weights:'),
                                            Item('savews',show_label=False),
                                            Item('loadws',enabled_when='_savedws',show_label=False)),
                                Item('weights0rem',label='Remove 0-weight points for fit?'),
                                HGroup(Item('newmodel',show_label=False),
                                       Item('fitmodel',show_label=False),
                                       Item('showerror',show_label=False,enabled_when='tmodel.lastfitfailure'),
                                       VGroup(Item('chi2',label='Chi2:',style='readonly',format_str='%6.6g',visible_when='tmodel.model is not None'),
                                             Item('chi2r',label='reduced:',style='readonly',format_str='%6.6g',visible_when='tmodel.model is not None'))
                                       )#Item('selbutton',show_label=False))
                              ,springy=False),spring,
                              VGroup(HGroup(Item('autoupdate',label='Auto?'),
                              Item('updatemodelplot',show_label=False,enabled_when='not autoupdate')),
                              Item('nmod',label='Nmodel'),
                              HGroup(Item('datasymb',show_label=False),Item('modline',show_label=False)),springy=False),springy=True),
                       '_',
                       HGroup(Item('scattertool',label='Selection Mode',
                                 editor=EnumEditor(values={None:'1:No Selection',
                                                           'clicktoggle':'3:Toggle Select',
                                                           'clicksingle':'2:Single Select',
                                                           'clickimmediate':'7:Immediate',
                                                           'lassoadd':'4:Add with Lasso',
                                                           'lassoremove':'5:Remove with Lasso',
                                                           'lassoinvert':'6:Invert with Lasso'})),
                           Item('unselectonaction',label='Clear Selection on Action?'),
                           Item('clearsel',show_label=False),
                           Item('weightchangesel',show_label=False),
                           Item('weightchangeto',label='Weight'),
                           Item('delsel',show_label=False),
                         ),#layout='flow'),
                       Item('tmodel',show_label=False,style='custom',editor=InstanceEditor(kind='subpanel'))
                      ),
                    handler=FGHandler(),
                    resizable=True,
                    title='Data Fitting',
                    buttons=['OK','Cancel'],
                    width=700,
                    height=900
                    )


    def __init__(self,xdata=None,ydata=None,weights=None,model=None,
                 include_models=None,exclude_models=None,fittype=None,**traits):
        """

        :param xdata: the first dimension of the data to be fit
        :type xdata: array-like
        :param ydata: the second dimension of the data to be fit
        :type ydata: array-like
        :param weights:
            The weights to apply to the data. Statistically interpreted as inverse
            errors (*not* inverse variance). May be any of the following forms:

            * None for equal weights
            * an array of points that must match `ydata`
            * a 2-sequence of arrays (xierr,yierr) such that xierr matches the
              `xdata` and yierr matches `ydata`
            * a function called as f(params) that returns an array of weights
              that match one of the above two conditions

        :param model: the initial model to use to fit this data
        :type model:
            None, string, or :class:`pymodelfit.core.FunctionModel1D`
            instance.
        :param include_models:
            With `exclude_models`, specifies which models should be available in
            the "new model" dialog (see `models.list_models` for syntax).
        :param exclude_models:
            With `include_models`, specifies which models should be available in
            the "new model" dialog (see `models.list_models` for syntax).
        :param fittype:
            The fitting technique for the initial fit (see
            :class:`pymodelfit.core.FunctionModel`).
        :type fittype: string

        kwargs are passed in as any additional traits to apply to the
        application.

        """

        self.modelpanel = View(Label('empty'),kind='subpanel',title='model editor')

        self.tmodel = TraitedModel(model)

        if model is not None and fittype is not None:
            self.tmodel.model.fittype = fittype

        if xdata is None or ydata is None:
            if not hasattr(self.tmodel.model,'data') or self.tmodel.model.data is None:
                raise ValueError('data not provided and no data in model')
            if xdata is None:
                xdata = self.tmodel.model.data[0]
            if ydata is None:
                ydata = self.tmodel.model.data[1]
            if weights is None:
                weights = self.tmodel.model.data[2]

        self.on_trait_change(self._paramsChanged,'tmodel.paramchange')

        self.modelselector = NewModelSelector(include_models,exclude_models)

        self.data = [xdata,ydata]


        if weights is None:
            self.weights = np.ones_like(xdata)
            self.weighttype = 'equal'
        else:
            self.weights = np.array(weights,copy=True)
            self.savews = True

        weights1d = self.weights
        while len(weights1d.shape)>1:
            weights1d = np.sum(weights1d**2,axis=0)

        pd = ArrayPlotData(xdata=self.data[0],ydata=self.data[1],weights=weights1d)
        self.plot = plot = Plot(pd,resizable='hv')

        self.scatter = plot.plot(('xdata','ydata','weights'),name='data',
                         color_mapper=_cmapblack if self.weights0rem else _cmap,
                         type='cmap_scatter', marker='circle')[0]

        self.errorplots = None

        if not isinstance(model,FunctionModel1D):
            self.fitmodel = True

        self.updatemodelplot = False #force plot update - generates xmod and ymod
        plot.plot(('xmod','ymod'),name='model',type='line',line_style='dash',color='black',line_width=2)
        del plot.x_mapper.range.sources[-1]  #remove the line plot from the x_mapper source so only the data is tied to the scaling

        self.on_trait_change(self._rangeChanged,'plot.index_mapper.range.updated')

        self.pantool = PanTool(plot,drag_button='left')
        plot.tools.append(self.pantool)
        self.zoomtool = ZoomTool(plot)
        self.zoomtool.prev_state_key = KeySpec('a')
        self.zoomtool.next_state_key = KeySpec('s')
        plot.overlays.append(self.zoomtool)

        self.scattertool = None
        self.scatter.overlays.append(ScatterInspectorOverlay(self.scatter,
                        hover_color = "black",
                        selection_color="black",
                        selection_outline_color="red",
                        selection_line_width=2))


        self.colorbar = colorbar = ColorBar(index_mapper=LinearMapper(range=plot.color_mapper.range),
                                            color_mapper=plot.color_mapper.range,
                                            plot=plot,
                                            orientation='v',
                                            resizable='v',
                                            width = 30,
                                            padding = 5)
        colorbar.padding_top = plot.padding_top
        colorbar.padding_bottom = plot.padding_bottom
        colorbar._axis.title = 'Weights'

        self.plotcontainer = container = HPlotContainer(use_backbuffer=True)
        container.add(plot)
        container.add(colorbar)

        super(FitGui,self).__init__(**traits)

        self.on_trait_change(self._scale_change,'plot.value_scale,plot.index_scale')

        if weights is not None and len(weights)==2:
            self.weightsChanged() #update error bars

    def _weights0rem_changed(self,old,new):
        """
        Swap the plot colormap when the "remove zero weights" flag toggles.

        ``_cmapblack`` is used while the flag is set and ``_cmap`` otherwise.
        The color range of the current mapper is carried over to the new one,
        and a redraw is requested.
        """
        cmap_factory = _cmapblack if new else _cmap
        color_range = self.plot.color_mapper.range
        self.plot.color_mapper = cmap_factory(color_range)
        self.plot.request_redraw()
#        if old and self.filloverlay in self.plot.overlays:
#            self.plot.overlays.remove(self.filloverlay)
#        if new:
#            self.plot.overlays.append(self.filloverlay)
#        self.plot.request_redraw()

    def _paramsChanged(self):
        self.updatemodelplot = True

    def _nmod_changed(self):
        self.updatemodelplot = True

    def _rangeChanged(self):
        self.updatemodelplot = True

    #@on_trait_change('object.plot.value_scale,object.plot.index_scale',post_init=True)
    def _scale_change(self):
        self.plot.request_redraw()

    def _updatemodelplot_fired(self,new):
        #If the plot has not been generated yet, just skip the update
        if self.plot is None:
            return

        #if False (e.g. button click), update regardless, otherwise check for autoupdate
        if new and not self.autoupdate:
            return

        mod = self.tmodel.model
        if self.ytype == 'data and model':
            if mod:
                #xd = self.data[0]
                #xmod = np.linspace(np.min(xd),np.max(xd),self.nmod)
                xl = self.plot.index_range.low
                xh = self.plot.index_range.high
                if self.plot.index_scale=="log":
                    xmod = np.logspace(np.log10(xl),np.log10(xh),self.nmod)
                else:
                    xmod = np.linspace(xl,xh,self.nmod)
                ymod = self.tmodel.model(xmod)

                self.plot.data.set_data('xmod',xmod)
                self.plot.data.set_data('ymod',ymod)

            else:
                self.plot.data.set_data('xmod',[])
                self.plot.data.set_data('ymod',[])
        elif self.ytype == 'residuals':
            if mod:
                self.plot.data.set_data('xmod',[])
                self.plot.data.set_data('ymod',[])
                #residuals set the ydata instead of setting the model
                res = mod.residuals(*self.data)
                self.plot.data.set_data('ydata',res)
            else:
                self.ytype = 'data and model'
        else:
            assert True,'invalid Enum'


    def _fitmodel_fired(self):
        from warnings import warn

        preaup = self.autoupdate
        try:
            self.autoupdate = False
            xd,yd = self.data
            kwd = {'x':xd,'y':yd}
            if self.weights is not None:
                w = self.weights
                if self.weights0rem:
                    if xd.shape == w.shape:
                        m = w!=0
                        w = w[m]
                        kwd['x'] = kwd['x'][m]
                        kwd['y'] = kwd['y'][m]
                    elif np.any(w==0):
                        warn("can't remove 0-weighted points if weights don't match data")
                kwd['weights'] = w
            self.tmodel.fitdata = kwd
        finally:
            self.autoupdate = preaup

        self.updatemodelplot = True
        self.updatestats = True


#    def _tmodel_changed(self,old,new):
#        #old is only None before it is initialized
#        if new is not None and old is not None and new.model is not None:
#            self.fitmodel = True

    def _newmodel_fired(self,newval):
        """
        Install a new model.

        A model name string, a :class:`FunctionModel1D` instance, or a
        :class:`FunctionModel1D` subclass is wrapped in a :class:`TraitedModel`
        directly. Any other value opens the model-selector dialog:

        * if it is confirmed with no class chosen, the model is cleared;
        * otherwise the chosen class (given ``modelargnum`` for varargs
          models) becomes the model and a fit is triggered.

        Cancelling the dialog changes nothing.
        """
        from inspect import isclass

        is_direct = isinstance(newval, (basestring, FunctionModel1D)) or \
                    (isclass(newval) and issubclass(newval, FunctionModel1D))
        if is_direct:
            self.tmodel = TraitedModel(newval)
            return

        selector = self.modelselector
        if not selector.edit_traits(kind='modal').result:
            return  # dialog cancelled

        cls = selector.selectedmodelclass
        if cls is None:
            self.tmodel = TraitedModel(None)
            return

        if selector.isvarargmodel:
            self.tmodel = TraitedModel(cls(selector.modelargnum))
        else:
            self.tmodel = TraitedModel(cls())
        self.fitmodel = True

    def _showerror_fired(self,evt):
        """Open a dialog with the last fitting failure's type and message, if any."""
        failure = self.tmodel.lastfitfailure
        if not failure:
            return
        message = failure.__class__.__name__ + ': ' + str(failure)
        dialog = HasTraits(s=message)
        view = View(Item('s', style='custom', show_label=False),
                    resizable=True, buttons=['OK'],
                    title='Fitting error message')
        dialog.edit_traits(view=view)

    @cached_property
    def _get_chi2(self):
        """
        First element of the model's ``chi2Data()`` (the chi-squared value),
        or 0 if it cannot be computed (e.g. no model is set).

        Only ordinary exceptions are mapped to 0; ``KeyboardInterrupt`` and
        ``SystemExit`` are no longer swallowed by a bare ``except``.
        """
        try:
            return self.tmodel.model.chi2Data()[0]
        except Exception:
            return 0

    @cached_property
    def _get_chi2r(self):
        """
        Second element of the model's ``chi2Data()`` (the reduced
        chi-squared), or 0 if it cannot be computed (e.g. no model is set).

        Only ordinary exceptions are mapped to 0; ``KeyboardInterrupt`` and
        ``SystemExit`` are no longer swallowed by a bare ``except``.
        """
        try:
            return self.tmodel.model.chi2Data()[1]
        except Exception:
            return 0

    def _get_nomodel(self):
        return self.tmodel.model is None

    def _get_weightsvary(self):
        w = self.weights
        return np.any(w!=w[0])if len(w)>0 else False

    def _get_plotname(self):
        xlabel = self.plot.x_axis.title
        ylabel = self.plot.y_axis.title
        if xlabel == '' and ylabel == '':
            return ''
        else:
            return xlabel+' vs '+ylabel
    def _set_plotname(self,val):
        """
        Set the plot's axis titles from a plot name.

        :param val:
            Either a sequence ``(xtitle, ytitle)`` or a string. A string is
            split on ``' vs '`` (the format produced by the getter). If that
            finds no separator, it falls back to a bare ``'vs'`` and then to
            ``'-'``. Surrounding whitespace is stripped from each part. A
            string with no separator at all sets the x title and clears the
            y title.
        """
        if isinstance(val,basestring):
            parts = val.split(' vs ')
            if len(parts) == 1:
                parts = val.split('vs')
            if len(parts) == 1:
                # previously val.split('-') was called on the *list* -> AttributeError
                parts = val.split('-')
            parts = [p.strip() for p in parts]
            if len(parts) == 1:
                parts.append('')
            val = parts
        # the titles live on the plot's axes (see _get_plotname), not on self
        self.plot.x_axis.title = val[0]
        self.plot.y_axis.title = val[1]


    #selection-related
    def _scattertool_changed(self,old,new):
        if new == 'No Selection':
            self.plot.tools[0].drag_button='left'
        else:
            self.plot.tools[0].drag_button='right'
        if old is not None and 'lasso' in old:
            if new is not None and 'lasso' in new:
                #connect correct callbacks
                self.lassomode = new.replace('lasso','')
                return
            else:
                #TODO:test
                self.scatter.tools[-1].on_trait_change(self._lasso_handler,
                                            'selection_changed',remove=True)
                del self.scatter.overlays[-1]
                del self.lassomode
        elif old == 'clickimmediate':
            self.scatter.index.on_trait_change(self._immediate_handler,
                                            'metadata_changed',remove=True)

        self.scatter.tools = []
        if new is None:
            pass
        elif 'click' in new:
            smodemap = {'clickimmediate':'single','clicksingle':'single',
                        'clicktoggle':'toggle'}
            self.scatter.tools.append(ScatterInspector(self.scatter,
                                      selection_mode=smodemap[new]))
            if new == 'clickimmediate':
                self.clearsel = True
                self.scatter.index.on_trait_change(self._immediate_handler,
                                                    'metadata_changed')
        elif 'lasso' in new:
            lasso_selection = LassoSelection(component=self.scatter,
                                    selection_datasource=self.scatter.index)
            self.scatter.tools.append(lasso_selection)
            lasso_overlay = LassoOverlay(lasso_selection=lasso_selection,
                                         component=self.scatter)
            self.scatter.overlays.append(lasso_overlay)
            self.lassomode = new.replace('lasso','')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'selection_changed')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'selection_completed')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'updated')
        else:
            raise TraitsError('invalid scattertool value')

    def _weightchangesel_fired(self):
        self.weights[self.selectedi] = self.weightchangeto
        if self.unselectonaction:
            self.clearsel = True

        self._sel_alter_weights()
        self.lastselaction = 'weightchangesel'

    def _delsel_fired(self):
        self.weights[self.selectedi] = 0
        if self.unselectonaction:
            self.clearsel = True

        self._sel_alter_weights()
        self.lastselaction = 'delsel'

    def _sel_alter_weights(self):
        if self.weighttype != 'custom':
            self._customweights = self.weights
            self.weighttype = 'custom'
        self.weightsChanged()

    def _clearsel_fired(self,event):
        if isinstance(event,list):
            self.scatter.index.metadata['selections'] = event
        else:
            self.scatter.index.metadata['selections'] = list()

    def _lasso_handler(self,name,new):
        if name == 'selection_changed':
            lassomask = self.scatter.index.metadata['selection'].astype(int)
            clickmask = np.zeros_like(lassomask)
            clickmask[self.scatter.index.metadata['selections']] = 1

            if self.lassomode == 'add':
                mask = clickmask | lassomask
            elif self.lassomode == 'remove':
                mask = clickmask & ~lassomask
            elif self.lassomode == 'invert':
                mask = np.logical_xor(clickmask,lassomask)
            else:
                raise TraitsError('lassomode is in invalid state')

            self.scatter.index.metadata['selections'] = list(np.where(mask)[0])
        elif name == 'selection_completed':
            self.scatter.overlays[-1].visible = False
        elif name == 'updated':
            self.scatter.overlays[-1].visible = True
        else:
            raise ValueError('traits event name %s invalid'%name)

    def _immediate_handler(self):
        sel = self.selectedi
        if len(sel) > 1:
            self.clearsel = True
            raise TraitsError('selection error in immediate mode - more than 1 selection')
        elif len(sel)==1:
            if self.lastselaction != 'None':
                setattr(self,self.lastselaction,True)
            del sel[0]

    def _savews_fired(self):
        self._savedws = self.weights.copy()

    def _loadws_fired(self):
        self.weights = self._savedws
        self._savews_fired()

    def _get_selectedi(self):
        return self.scatter.index.metadata['selections']


    @on_trait_change('data,ytype',post_init=True)
    def dataChanged(self):
        """
        Updates the application state if the fit data are altered - the GUI will
        know if you give it a new data array, but not if the data is changed
        in-place.

        Pushes ``data[0]``/``data[1]`` into the plot's 'xdata'/'ydata' and
        forces a model-plot refresh. This also runs on ``ytype`` changes, so
        switching away from 'residuals' restores the raw y data.
        """
        plotdata = self.plot.data
        #TODO:make set_data apply to both simultaneously?
        for key, row in (('xdata', 0), ('ydata', 1)):
            plotdata.set_data(key, self.data[row])

        self.updatemodelplot = False

    @on_trait_change('weights',post_init=True)
    def weightsChanged(self):
        """
        Updates the application state if the weights/error bars for this model
        are changed - the GUI will automatically do this if you give it a new
        set of weights array, but not if they are changed in-place.

        Previously drawn error bars are removed first. If the weights have
        shape ``(2, N)``, the two rows are treated as inverse x and y errors
        and drawn as error bars. The weights are then collapsed to one value
        per point (repeated sum of squares over axis 0) to drive the scatter
        colormap. The colorbar is shown only while the weights vary.
        """
        weights = self.weights
        if 'errorplots' in self.trait_names():
            #TODO:switch this to updating error bar data/visibility changing
            if self.errorplots is not None:
                self.plot.remove(self.errorplots[0])
                self.plot.remove(self.errorplots[1])
                # forget the removed plots (this used to assign a misspelled
                # "errorbarplots", so the next call tried to remove them again)
                self.errorplots = None

            if len(weights.shape)==2 and weights.shape[0]==2:
                xerr,yerr = 1/weights

                # x error bars: index/value and their mappers are swapped so
                # the bars run along the x axis
                high = ArrayDataSource(self.scatter.index.get_data()+xerr)
                low = ArrayDataSource(self.scatter.index.get_data()-xerr)
                ebpx = ErrorBarPlot(orientation='v',
                                   value_high = high,
                                   value_low = low,
                                   index = self.scatter.value,
                                   value = self.scatter.index,
                                   index_mapper = self.scatter.value_mapper,
                                   value_mapper = self.scatter.index_mapper
                                )
                self.plot.add(ebpx)

                high = ArrayDataSource(self.scatter.value.get_data()+yerr)
                low = ArrayDataSource(self.scatter.value.get_data()-yerr)
                ebpy = ErrorBarPlot(value_high = high,
                                   value_low = low,
                                   index = self.scatter.index,
                                   value = self.scatter.value,
                                   index_mapper = self.scatter.index_mapper,
                                   value_mapper = self.scatter.value_mapper
                                )
                self.plot.add(ebpy)

                self.errorplots = (ebpx,ebpy)

        while len(weights.shape)>1:
            weights = np.sum(weights**2,axis=0)
        self.plot.data.set_data('weights',weights)
        self.plot.plots['data'][0].color_mapper.range.refresh()

        if self.weightsvary:
            if self.colorbar not in self.plotcontainer.components:
                self.plotcontainer.add(self.colorbar)
                self.plotcontainer.request_redraw()
        elif self.colorbar in self.plotcontainer.components:
            self.plotcontainer.remove(self.colorbar)
            self.plotcontainer.request_redraw()


    def _weighttype_changed(self, name, old, new):
        if old == 'custom':
            self._customweights = self.weights

        if new == 'custom':
            self.weights = self._customweights #if hasattr(self,'_customweights') else np.ones_like(self.data[0])
        elif new == 'equal':
            self.weights = np.ones_like(self.data[0])
        elif new == 'lin bins':
            self.weights = binned_weights(self.data[0],10,False)
        elif new == 'log bins':
            self.weights = binned_weights(self.data[0],10,True)
        else:
            raise TraitError('Invalid Enum value on weighttype')

    def getModelInitStr(self):
        """
        Generates a python code string that can be used to generate a model with
        parameters matching the model in this :class:`FitGui`.

        Each parameter is rendered as ``name=str(value)``. For varargs models
        (class ``_pars`` is None) the number of variable parameters is
        prepended as the first positional argument.

        :returns: initializer string, or ``'None'`` if there is no model

        """
        mod = self.tmodel.model
        if mod is None:
            return 'None'

        modcls = mod.__class__
        parstrs = [p + '=' + str(v) for p, v in mod.pardict.iteritems()]
        if modcls._pars is None:
            # varargs: total parameter count minus the static arguments
            nvar = len(mod.params) - len(modcls._statargs)
            parstrs = [str(nvar)] + parstrs
        return '%s(%s)' % (modcls.__name__, ','.join(parstrs))

    def getModelObject(self):
        """
        Gets the underlying object representing the model for this fit.

        :returns:
            The :class:`pymodelfit.core.FunctionModel1D` object, or None when
            no model is set.
        """
        traited = self.tmodel
        return traited.model
Beispiel #23
0
class Plot3D(HasTraits):
    """Orthogonal-slice viewer for a 3D dose cube.

    Three image plots share one GridPlotContainer: the xy plane in the
    center, yz on the right and xz at the bottom. Line inspectors on each
    image move the slice indices shown by the other views. ``current_dose``
    reports the dose at the current (slice_x, slice_y, slice_z) position.
    """

    plot = Instance(Component)
    name = 'Scan Plot'
    id = 'radpy.plugins.BeamAnalysis.ChacoPlot'
    current_dose = Float()

    traits_view = View(Group(Item('plot',
                                  editor=ComponentEditor(size=(400, 300)),
                                  show_label=False),
                             Item('current_dose'),
                             id='radpy.plugins.BeamAnalysis.ChacoPlotItems'),
                       resizable=True,
                       title="Scan Plot",
                       width=400,
                       height=300,
                       id='radpy.plugins.BeamAnalysis.ChacoPlotView')
    plot_type = String

    # Indices into the cube that each of the image plot views shows; all
    # views start at the first slice.
    slice_x = 0
    slice_y = 0
    slice_z = 0

    num_levels = Int(15)
    colormap = Any
    colorcube = Any

    #---------------------------------------------------------------------------
    # Private Traits
    #---------------------------------------------------------------------------

    _cmap = Trait(jet, Callable)

    def _image_renderers(self):
        """Return the center (xy), right (yz) and bottom (xz) image renderers."""
        return (self.center, self.right, self.bottom)

    def _update_indices(self, token, axis, index_x):
        """Forward an index change to every MyLineInspector on the images.

        Every inspector is called; ``token``, ``axis`` and ``index_x`` are
        passed through to ``MyLineInspector.update_index`` unchanged.
        """
        for renderer in self._image_renderers():
            for overlay in renderer.overlays:
                if isinstance(overlay, MyLineInspector):
                    overlay.update_index(token, axis, index_x)

    def _update_positions(self, token, value):
        """Update the text boxes tagged ``token`` with a slice position.

        The label names the axis perpendicular to the plane ``token`` (e.g.
        token 'xy' gives 'z = 1.23').
        """
        axis_name = set.difference(set("xyz"), set(token)).pop()
        label = axis_name + ' = %.2f' % value
        for renderer in self._image_renderers():
            for overlay in renderer.overlays:
                if isinstance(overlay, TextBoxOverlay) and overlay.token == token:
                    overlay.text = label

    def _index_callback(self, tool, axis, index, value):
        """Line-inspector callback: move the matching slice and redraw.

        ``tool.token`` names the plane the inspector sits on. ``axis`` says
        which of that plane's two axes moved. The corresponding ``slice_*``
        index is updated, the other views' inspectors and position labels are
        synchronized, and the images are regenerated.
        """
        plane = tool.token
        if plane == "xy":
            if axis == "index" or axis == "index_x":
                self.slice_x = index
                self._update_indices("xz", "index_x", index)
                self._update_positions("yz", value)
            else:
                self.slice_y = index
                self._update_indices("yz", "index_y", index)
                self._update_positions("xz", value)
        elif plane == "yz":
            if axis == "index" or axis == "index_x":
                self.slice_z = index
                self._update_indices("xz", "index_y", index)
                self._update_positions("xy", value)
            else:
                self.slice_y = index
                self._update_indices("xy", "index_y", index)
                self._update_positions("xz", value)
        elif plane == "xz":
            if axis == "index" or axis == "index_x":
                self.slice_x = index
                self._update_indices("xy", "index_x", index)
                self._update_positions("yz", value)
            else:
                self.slice_z = index
                self._update_indices("yz", "index_x", index)
                self._update_positions("xy", value)
        else:
            warnings.warn("Unrecognized plane for _index_callback: %s" % plane)

        self._update_images()
        self.center.invalidate_and_redraw()
        self.right.invalidate_and_redraw()
        self.bottom.invalidate_and_redraw()
        return

    def _plot_type_default(self):
        return '3D_dose'

    def _plot_default(self):
        return self._create_plot_component()

    def _add_plot_tools(self, imgplot, token):
        """ Add LineInspectors, ImageIndexTool, and ZoomTool to the image plots. """

        imgplot.overlays.append(
            ZoomTool(component=imgplot,
                     tool_mode="box",
                     enable_wheel=False,
                     always_on=False))
        imgplot.overlays.append(
            MyLineInspector(imgplot,
                            axis="index_y",
                            color="grey",
                            inspect_mode="indexed",
                            callback=self._index_callback,
                            token=token))

        imgplot.overlays.append(
            MyLineInspector(imgplot,
                            axis="index_x",
                            color="grey",
                            inspect_mode="indexed",
                            callback=self._index_callback,
                            token=token))
        imgplot.overlays.append(
            MyTextBoxOverlay(imgplot,
                             token=token,
                             align='lr',
                             bgcolor='white',
                             font='Arial 12'))

    def _create_plot_component(self):
        """Create the (initially empty) 2x2 grid container for the views."""
        container = GridPlotContainer(padding=30,
                                      fill_padding=True,
                                      bgcolor="white",
                                      use_backbuffer=True,
                                      shape=(2, 2),
                                      spacing=(30, 30))
        self.container = container
        return container

    def add_plot(self, label, beam):
        """Fill the grid with the xy/yz/xz slice plots of ``beam.Data``.

        :param label: returned unchanged.
        :param beam: object whose ``Data`` provides ``dose`` (indexed
            [x][y][z], see ``_update_images``) and ``x_axis``/``y_axis``/
            ``z_axis`` coordinate sequences.
        :returns: ``label``.
        """
        self.plotdata = ArrayPlotData()
        self.model = beam.Data
        # NOTE(review): this reverses the beam's own z axis in place, so
        # calling add_plot twice for the same beam flips it back - confirm.
        self.model.z_axis = self.model.z_axis[::-1]
        cmap = jet
        self._update_model(cmap)
        self.plotdata.set_data("xy", self.model.dose)
        self._update_images()

        # Center Plot
        centerplot = Plot(self.plotdata,
                          resizable='hv',
                          height=150,
                          width=150,
                          padding=0)
        centerplot.default_origin = 'top left'
        imgplot = centerplot.img_plot("xy",
                                      xbounds=(self.model.x_axis[0],
                                               self.model.x_axis[-1]),
                                      ybounds=(self.model.y_axis[0],
                                               self.model.y_axis[-1]),
                                      colormap=cmap)[0]

        imgplot.origin = 'top left'
        self._add_plot_tools(imgplot, "xy")
        left_axis = PlotAxis(centerplot, orientation='left', title='y')
        bottom_axis = PlotAxis(centerplot,
                               orientation='bottom',
                               title='x',
                               title_spacing=30)
        centerplot.underlays.append(left_axis)
        centerplot.underlays.append(bottom_axis)
        self.center = imgplot

        # Right Plot (shares the y range with the center plot)
        rightplot = Plot(self.plotdata,
                         height=150,
                         width=150,
                         resizable="hv",
                         padding=0)
        rightplot.default_origin = 'top left'
        rightplot.value_range = centerplot.value_range
        imgplot = rightplot.img_plot("yz",
                                     xbounds=(self.model.z_axis[0],
                                              self.model.z_axis[-1]),
                                     ybounds=(self.model.y_axis[0],
                                              self.model.y_axis[-1]),
                                     colormap=cmap)[0]
        imgplot.origin = 'top left'
        self._add_plot_tools(imgplot, "yz")
        left_axis = PlotAxis(rightplot, orientation='left', title='y')
        bottom_axis = PlotAxis(rightplot,
                               orientation='bottom',
                               title='z',
                               title_spacing=30)
        rightplot.underlays.append(left_axis)
        rightplot.underlays.append(bottom_axis)
        self.right = imgplot

        # Bottom Plot (shares the x range with the center plot)
        bottomplot = Plot(self.plotdata,
                          height=150,
                          width=150,
                          resizable="hv",
                          padding=0)
        bottomplot.index_range = centerplot.index_range
        imgplot = bottomplot.img_plot("xz",
                                      xbounds=(self.model.x_axis[0],
                                               self.model.x_axis[-1]),
                                      ybounds=(self.model.z_axis[0],
                                               self.model.z_axis[-1]),
                                      colormap=cmap)[0]
        self._add_plot_tools(imgplot, "xz")
        left_axis = PlotAxis(bottomplot, orientation='left', title='z')
        bottom_axis = PlotAxis(bottomplot,
                               orientation='bottom',
                               title='x',
                               title_spacing=30)
        bottomplot.underlays.append(left_axis)
        bottomplot.underlays.append(bottom_axis)
        self.bottom = imgplot

        # Add all plots to the container created by _create_plot_component
        self.container.add(centerplot)
        self.container.add(rightplot)
        self.container.add(bottomplot)

        return label

    def _update_images(self):
        """ Updates the image data in self.plotdata to correspond to the
        slices given, and refreshes ``current_dose``.
        """
        cube = self.colorcube
        pd = self.plotdata
        # These are transposed because img_plot() expects its data to be in
        # row-major order

        pd.set_data("xy", numpy.transpose(cube[:, :, self.slice_z], (1, 0, 2)))
        pd.set_data("xz", numpy.transpose(cube[:, self.slice_y, :], (1, 0, 2)))
        pd.set_data("yz", cube[self.slice_x, :, :])
        self.current_dose = self.model.dose[self.slice_x][self.slice_y][
            self.slice_z]
        return

    def _update_model(self, cmap):
        """Build ``colormap`` over the dose range and the uint8 ``colorcube``.

        ``colorcube`` holds the colormapped dose values scaled by 255. The
        transposes in ``_update_images`` imply a trailing color-component
        axis (presumably RGBA - confirm against the colormap used).
        """
        dose_range = DataRange1D(low=numpy.amin(self.model.dose),
                                 high=numpy.amax(self.model.dose))
        self.colormap = cmap(dose_range)
        self.colorcube = (self.colormap.map_screen(self.model.dose) *
                          255).astype(numpy.uint8)
class AnimationDemo(HasTraits):
    """Rotating 3D point cloud that keeps morphing into new random shapes.

    On each timer tick the view rotates, every point moves 4% of the way
    toward its target position and receives a little Gaussian jitter. Every
    300 frames a new random target shape is generated.
    """
    numdots = Int(3000)
    box = Instance(DotComponent)
    points3d = Array
    target_points3d = Array
    angle_x = Float()
    angle_y = Float()
    angle_z = Float()
    frame_count = Int(0)
    view = View(Item("box", editor=ComponentEditor(), show_label=False),
                resizable=True,
                width=600,
                height=600,
                title="变形",
                handler=AnimationHandler())

    def __init__(self, **traits):
        super(AnimationDemo, self).__init__(**traits)
        self.box = DotComponent()
        # Homogeneous coordinates: rows 0-2 hold x/y/z, row 3 is fixed at 1.
        self.points3d = np.zeros((4, self.numdots))
        self.points3d[3, :] = 1
        self.target_points3d = np.zeros((4, self.numdots))
        self.target_points3d[3, :] = 1

        # Start from one random shape and morph toward another.
        self.points3d[:3, :] = self.get_object_points()
        self.target_points3d[:3, :] = self.get_object_points()
        self.frame_count = 0

    def get_object_points(self):
        """Return (x, y, z) arrays for a randomly generated shape.

        Each coordinate is the product of two terms picked at random from
        cos/sin of a random latitude, cos/sin of an evenly spaced angle, and 1.
        """
        r = np.random.rand(self.numdots)
        t = np.linspace(0, np.pi * 2, self.numdots)
        a = r * np.pi - np.pi / 2
        l = [np.cos(a), np.cos(t), np.sin(a), np.sin(t), np.ones(self.numdots)]
        c = random.choice
        return c(l) * c(l), c(l) * c(l), c(l) * c(l)

    def make_matrix(self):
        """Return the 4x4 homogeneous rotation ``Rz . Ry . Rx``.

        Rx, Ry and Rz rotate about the x, y and z axes by ``angle_x``,
        ``angle_y`` and ``angle_z`` respectively.
        """
        s = np.sin([self.angle_x, self.angle_y, self.angle_z])
        c = np.cos([self.angle_x, self.angle_y, self.angle_z])
        # Build the rotation matrix for each of the three axes.
        # Previously Rx used angle_z and Rz used angle_x (swapped).
        mx = np.eye(4)
        mx[[1, 1, 2, 2], [1, 2, 1, 2]] = c[0], -s[0], s[0], c[0]
        my = np.eye(4)
        my[[0, 0, 2, 2], [0, 2, 0, 2]] = c[1], s[1], -s[1], c[1]
        mz = np.eye(4)
        mz[[0, 0, 1, 1], [0, 1, 0, 1]] = c[2], -s[2], s[2], c[2]
        # Compose the rotations.
        return np.dot(mz, np.dot(my, mx))

    def projection(self):
        """Rotate the current points and hand them to the dot component."""
        m = self.make_matrix()
        result = np.dot(m, self.points3d)
        self.box.projection_points(result)

    def on_timer(self):
        """Advance the animation by one frame."""
        self.angle_z += 0.02
        self.angle_x += 0.02
        self.projection()
        self.box.request_redraw()
        self.points3d[:3, :] += 0.04 * (self.target_points3d[:3, :] -
                                        self.points3d[:3, :])
        self.points3d[:3, :] += np.random.normal(0, 0.004, (3, self.numdots))
        self.frame_count += 1
        if self.frame_count > 300:
            self.frame_count = 0
            self.target_points3d[:3, :] = self.get_object_points()
class ViewportDefiner(HasTraits):
    """Interactive editor for the polygon of one virtual display's viewport.

    The user clicks polygon vertices on a ``height`` x ``width`` image. The
    filled polygon (optionally with a calibration grid) is sent to
    ``display_server``. Once the polygon has at least three vertices, it is
    written back into ``display_info['virtualDisplays']`` and the ROS
    parameter ``<display_name>/display/virtualDisplays``.
    """
    width = traits.Int
    height = traits.Int
    display_name = traits.String
    plot = Instance(Component)
    linedraw = Instance(LineSegmentTool)
    viewport_id = traits.String
    display_mode = traits.Trait('white on black', 'black on white')
    display_server = traits.Any
    display_info = traits.Any
    show_grid = traits.Bool

    traits_view = View(
        Group(Item('display_mode'),
              Item('display_name'),
              Item('viewport_id'),
              Item('plot', editor=ComponentEditor(), show_label=False),
              orientation="vertical"),
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        super(ViewportDefiner, self).__init__(*args, **kwargs)

        # find our index in the viewport list (the last match wins)
        virtual_displays = self.display_info['virtualDisplays']
        viewport_ids = [obj['id'] for obj in virtual_displays]
        self.viewport_idx = -1
        for i, obj_id in enumerate(viewport_ids):
            if obj_id == self.viewport_id:
                self.viewport_idx = i

        if self.viewport_idx == -1:
            raise Exception("Could not find viewport (available ids: %s)" %
                            ",".join(viewport_ids))

        # Builds the initial image; accessing linedraw here also creates the
        # plot and the line-drawing tool via their default initializers.
        self._update_image()

        self.fqdn = self.display_name + '/display/virtualDisplays'
        self.this_virtual_display = virtual_displays[self.viewport_idx]

        # only accept stored vertices that lie inside the image
        viewport = self.this_virtual_display['viewport']
        out_of_bounds = any((x >= self.width) or (y >= self.height)
                            for (x, y) in viewport)
        if out_of_bounds:
            self.linedraw.points = []
            rospy.logwarn('invalid points')
        else:
            self.linedraw.points = viewport
        self._update_image()

    def _update_image(self):
        """Rebuild the polygon image and propagate it.

        The image is pushed to the plot data (once the plot exists) and to the
        display server. With three or more vertices, the ROS parameters are
        updated as well.
        """
        self._image = np.zeros((self.height, self.width, 3), dtype=np.uint8)
        fill_polygon.fill_polygon(self.linedraw.points, self._image)

        if self.show_grid:
            # draw red horizontal stripes
            for i in range(0, self.height, 100):
                self._image[i:i + 10, :, 0] = 255

            # draw blue vertical stripes
            for i in range(0, self.width, 100):
                self._image[:, i:i + 10, 2] = 255

        if hasattr(self, '_pd'):
            self._pd.set_data("imagedata", self._image)
        self.send_array()
        if len(self.linedraw.points) >= 3:
            self.update_ROS_params()

    def _plot_default(self):
        """Create the image plot with pan (right button) and box-zoom tools."""
        self._pd = ArrayPlotData()
        self._pd.set_data("imagedata", self._image)

        plot = Plot(self._pd, default_origin="top left")
        plot.x_axis.orientation = "top"
        img_plot = plot.img_plot("imagedata")[0]

        plot.bgcolor = "white"

        # Tweak some of the plot properties
        plot.title = "Click to add points, press Enter to clear selection"
        plot.padding = 50
        plot.line_width = 1

        # Attach some tools to the plot
        pan = PanTool(plot, drag_button="right", constrain_key="shift")
        plot.tools.append(pan)
        zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
        plot.overlays.append(zoom)

        return plot

    def _linedraw_default(self):
        """Create the vertex-drawing tool and re-render on point changes."""
        linedraw = LineSegmentTool(self.plot, color=(0.5, 0.5, 0.9, 1.0))
        self.plot.overlays.append(linedraw)
        linedraw.on_trait_change(self.points_changed, 'points[]')
        return linedraw

    def points_changed(self):
        self._update_image()

    @traits.on_trait_change('display_mode')
    def send_array(self):
        """Send the current image to the display server.

        NOTE(review): this also runs when ``display_mode`` changes, but the
        mode does not affect the pixels sent. The foreground/background
        colors previously derived from it were never used and were removed.
        """
        self.display_server.show_pixels(self._image)

    def get_viewport_verts(self):
        """Polygon vertices clamped to the image as ``[[x, y], ...]`` lists."""
        # convert to integers
        pts = [(fill_polygon.posint(x, self.width - 1),
                fill_polygon.posint(y, self.height - 1))
               for (x, y) in self.linedraw.points]
        # convert to list of lists for maximal json compatibility
        return [list(x) for x in pts]

    def update_ROS_params(self):
        """Store the current vertices in display_info and the ROS param server."""
        viewport_verts = self.get_viewport_verts()
        self.this_virtual_display['viewport'] = viewport_verts
        self.display_info['virtualDisplays'][
            self.viewport_idx] = self.this_virtual_display
        rospy.set_param(self.fqdn, self.display_info['virtualDisplays'])
Beispiel #26
0
class ConnectionMatrixViewer(HasTraits):
    """Interactive viewer for a set of named connection matrices.

    The matrix selected by ``data_name`` is shown as a colour-mapped image
    with a colour bar. The row label, column label and value under the mouse
    cursor are reported in ``fro``, ``to`` and ``val``.
    """

    tplot = Instance(Plot)
    plot = Instance(Component)
    custtool = Instance(CustomTool)
    colorbar = Instance(ColorBar)

    fro = Any
    to = Any
    data = None
    val = Float
    nodelabels = Any

    traits_view = View(
        Group(Item('plot',
                   editor=ComponentEditor(size=(800, 600)),
                   show_label=False),
              HGroup(
                  Item('fro', label="From", style='readonly', springy=True),
                  Item('to', label="To", style='readonly', springy=True),
                  Item('val', label="Value", style='readonly', springy=True),
              ),
              orientation="vertical"),
        Item('data_name', label="Edge key"),
        # handler=CustomHandler(),
        resizable=True,
        title="Connection Matrix Viewer")

    def __init__(self, nodelabels, matdict, **traits):
        """ Starts a matrix inspector

        Parameters
        ----------
        nodelabels : list
            List of strings of labels for the rows of the matrix
        matdict : dictionary
            Keys are the edge type and values are NxN Numpy arrays """
        # Name this class (not HasTraits) so HasTraits' own __init__ is not
        # skipped in the MRO.
        super(ConnectionMatrixViewer, self).__init__(**traits)

        # Store the data before creating the ``data_name`` trait so nothing
        # triggered by it can see a missing matrix.
        self.data = matdict
        self.nodelabels = nodelabels
        # Backward-compatible alias for the historically misspelled attribute.
        self.nodelables = nodelabels

        # list() so this also works where dict.keys() returns a view object.
        keys = list(matdict.keys())
        self.add_trait('data_name', Enum(keys))
        self.data_name = keys[0]
        self.plot = self._create_plot_component()

        # set trait notification on customtool
        self.custtool.on_trait_change(self._update_fields, "xval")
        self.custtool.on_trait_change(self._update_fields, "yval")

    def _data_name_changed(self, old, new):
        """Show the newly selected matrix and update the plot title."""
        if self.tplot is None:
            # Plot not built yet; _create_plot_component reads data_name.
            return
        self.pd.set_data("imagedata", self.data[self.data_name])
        self.tplot.title = "Connection Matrix for %s" % self.data_name

    def _update_fields(self):
        """Report label/value for the matrix cell under the mouse cursor."""
        import math

        # The image is drawn with bounds (0.5, dim + 0.5), so 0-based cell i
        # covers [i + 0.5, i + 1.5] in data space. floor(coord - 0.5) gives
        # the index without depending on Python 2/3 ``round`` semantics.
        frotmp = int(math.floor(self.custtool.yval - 0.5))
        totmp = int(math.floor(self.custtool.xval - 0.5))

        # assume matrix whose shape is (# of rows, # of columns)
        sh = self.data[self.data_name].shape
        if 0 <= frotmp < sh[0] and 0 <= totmp < sh[1]:
            # Labels are shown with 1-based indices.
            self.fro = " %s (index: %i)" % (self.nodelabels[frotmp], frotmp + 1)
            self.to = " %s (index: %i)" % (self.nodelabels[totmp], totmp + 1)
            self.val = self.data[self.data_name][frotmp, totmp]

    def _create_plot_component(self):
        """Build the image plot, tools and colour bar; return the container."""
        # Create a plot data object and give it this data
        self.pd = ArrayPlotData()
        self.pd.set_data("imagedata", self.data[self.data_name])

        # find dimensions
        xdim = self.data[self.data_name].shape[1]
        ydim = self.data[self.data_name].shape[0]

        # Create the plot; bounds are offset by 0.5 so cell centres fall on
        # integer (1-based) coordinates.
        self.tplot = Plot(self.pd, default_origin="top left")
        self.tplot.x_axis.orientation = "top"
        self.tplot.img_plot("imagedata",
                            name="my_plot",
                            xbounds=(0.5, xdim + 0.5),
                            ybounds=(0.5, ydim + 0.5),
                            colormap=jet)

        # Tweak some of the plot properties
        self.tplot.title = "Connection Matrix for %s" % self.data_name
        self.tplot.padding = 80

        # Right now, some of the tools are a little invasive, and we need the
        # actual CMapImage object to give to them
        self.my_plot = self.tplot.plots["my_plot"][0]

        # Attach some tools to the plot
        self.tplot.tools.append(PanTool(self.tplot))
        zoom = ZoomTool(component=self.tplot, tool_mode="box", always_on=False)
        self.tplot.overlays.append(zoom)

        # my custom tool to get the connection information
        self.custtool = CustomTool(self.tplot)
        self.tplot.tools.append(self.custtool)

        # Create the colorbar, handing in the appropriate range and colormap
        colormap = self.my_plot.color_mapper
        self.colorbar = ColorBar(
            index_mapper=LinearMapper(range=colormap.range),
            color_mapper=colormap,
            plot=self.my_plot,
            orientation='v',
            resizable='v',
            width=30,
            padding=20)

        self.colorbar.padding_top = self.tplot.padding_top
        self.colorbar.padding_bottom = self.tplot.padding_bottom

        # create a range selection for the colorbar
        self.range_selection = RangeSelection(component=self.colorbar)
        self.colorbar.tools.append(self.range_selection)
        self.colorbar.overlays.append(
            RangeSelectionOverlay(component=self.colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray"))

        # we also want the range selection to inform the cmap plot of
        # the selection, so set that up as well
        self.range_selection.listeners.append(self.my_plot)

        # Create a container to position the plot and the colorbar side-by-side
        container = HPlotContainer(use_backbuffer=True)
        container.add(self.tplot)
        container.add(self.colorbar)
        container.bgcolor = "white"

        return container
Beispiel #27
0
class ImageTimeseriesViewer(BaseDataViewer):
    """Step through the XY slices of an image time series.

    The slice for ``time_index`` is shown in a Chaco image plot with zoom,
    pan and pixel-inspector tools. ``time`` holds the value returned by
    ``get_data_time`` for the current slice.
    """

    plots = Dict

    plotdata = Instance(ArrayPlotData)
    image = Any

    time_index = Int
    time_points = Int
    time = Any

    traits_view = View(
        Tabbed(
            HSplit(
                VGroup(
                    Item('plot',
                         editor=ComponentEditor(),
                         show_label=False,
                         resizable=True,
                         label='View'),
                    Item("time_index",
                         style='custom',
                         editor=RangeEditor(low=0, high_name='time_points')),
                    Item("time", style='readonly'),
                ),
                Item('tasks', style='custom', show_label=False, label='Tasks'),
            ),
            Item('results', style='custom', show_label=False, label='Results'),
        ), )

    # ---- trait defaults ------------------------------------------------

    def _time_points_default(self):
        # Upper end of the time slider: the last valid index along axis 0.
        last_index = self.data.shape[0] - 1
        return last_index

    def _time_index_default(self):
        first_index = 0
        return first_index

    def _time_default(self):
        return self.get_data_time(0, 0)

    def _plotdata_default(self):
        first_slice = self.get_data_slice(0, 0)
        pd = ArrayPlotData()
        pd.set_data('xy', first_slice)
        return pd

    # ---- change handlers -----------------------------------------------

    def _time_index_changed(self):
        self.select_xy_slice(self.time_index)

    def select_xy_slice(self, t):
        """Display the slice at time index ``t`` and update ``time``."""
        xy_slice = self.get_data_slice(t, 0)
        self.time = self.get_data_time(t, 0)
        self.plotdata.set_data('xy', xy_slice)
        self.image.invalidate_and_redraw()

    def reset(self):
        """Nothing to reset for this viewer."""

    def redraw(self):
        self.image.invalidate_and_redraw()

    # ---- plot construction ---------------------------------------------

    def get_plot(self):
        """Build the image Plot, remember its renderer and return the Plot.

        The preferred size scales each spatial dimension by its pixel size
        relative to the smallest pixel size.
        """
        pixel_sizes = self.data_source.pixel_sizes
        spatial_shape = self.data.shape[1:]
        smallest = min(pixel_sizes)
        scaled = [int(extent * size / smallest)
                  for extent, size in zip(spatial_shape, pixel_sizes)]

        plot = Plot(
            self.plotdata,
            padding=30,
            fixed_preferred_size=(scaled[1], scaled[0]),
        )
        renderer = plot.img_plot('xy', colormap=bone)[0]
        renderer.overlays.append(ZoomTool(renderer))
        renderer.tools.append(PanTool(renderer, drag_button='right'))

        inspector = ImageInspectorTool(renderer)
        renderer.tools.append(inspector)
        renderer.overlays.append(
            ImageInspectorOverlay(component=renderer,
                                  bgcolor='white',
                                  image_inspector=inspector))

        self.image = renderer
        self.plots = {'xy': renderer}
        return plot
class IFSDesigner(HasTraits):
    """Interactive designer for IFS (iterated function system) fractals.

    Points are entered with ``TrianglesTool``. Once at least nine points (a
    multiple of three) exist, a timer repeatedly iterates the IFS via
    ``ifs`` and appends the results to a colour-mapped scatter plot.
    Named designs can be saved to, deleted from and reloaded from a pickle
    file.
    """
    plot = Instance(Plot)
    clear = Bool(False)
    draw = Bool(False)
    timer = Instance(Timer)
    ifs_names = List()
    ifs_points = List()
    current_name = Str()
    save_button = Button(u"保存当前IFS")      # "Save current IFS"
    unsave_button = Button(u"删除当前IFS")    # "Delete current IFS"

    # Single file name for loading and saving. Previously designs were read
    # from "ifs_chaco.data" but written to "IFS_chaco.data", so on
    # case-sensitive file systems saved designs were never reloaded.
    DATA_FILE = "ifs_chaco.data"

    view = View(
        HGroup(
            Item("current_name", editor=EnumEditor(name="object.ifs_names")),
            Item("save_button"),
            Item("unsave_button"),
            show_labels=False
        ),
        Item("plot", editor=ComponentEditor(), show_label=False),
        resizable=True,
        width=500,
        height=500,
        title=u"IFS图形设计器"  # "IFS figure designer"
    )

    def __init__(self, **traits):
        super(IFSDesigner, self).__init__(**traits)
        self.data = ArrayPlotData()
        self.set_empty_data()
        self.plot = Plot(self.data, padding=10)
        self.plot.plot(("x", "y", "c"), type="cmap_scatter",
                       marker_size=1, color_mapper=make_color_map(),
                       line_width=0)
        self.plot.x_grid.visible = False
        self.plot.y_grid.visible = False
        self.plot.x_axis.visible = False
        self.plot.y_axis.visible = False
        self.tool = TrianglesTool(self.plot)
        self.plot.overlays.append(self.tool)

        # Best effort: a missing or unreadable data file just means we start
        # without saved designs. NOTE: pickle.load can execute arbitrary code;
        # only load files written by this program.
        try:
            with open(self.DATA_FILE, "rb") as f:
                saved = pickle.load(f)
            # Build both lists before assigning so they stay in sync.
            names = [item[0] for item in saved]
            points = [np.array(item[1]) for item in saved]
            self.ifs_names = names
            self.ifs_points = points
            if self.ifs_names:
                self.current_name = self.ifs_names[-1]
        except Exception:
            pass

        self.tool.on_trait_change(self.triangle_changed, 'changed')
        self.timer = Timer(10, self.ifs_calculate)

    def set_empty_data(self):
        """Clear the scatter plot data."""
        self.data["x"] = np.array([])
        self.data["y"] = np.array([])
        self.data["c"] = np.array([])

    def triangle_changed(self):
        """Handle edits from the triangle tool.

        The plot is cleared whenever the point count is a multiple of three.
        Drawing stops below nine points, and a recomputation is requested at
        nine or more.
        """
        count = len(self.tool.points)
        if count % 3 == 0:
            self.set_empty_data()

        if count < 9:
            self.draw = False

        if count >= 9 and count % 3 == 0:
            self.clear = True

    def ifs_calculate(self):
        """Timer callback: iterate the IFS and append points to the plot."""
        if self.clear:
            self.clear = False
            self.initpos = [0, 0]
            # Do not plot the first 100 iterated points.
            x, y, c = ifs(self.tool.get_areas(), self.tool.get_eqs(),
                          self.initpos, 100)
            self.initpos = [x[-1], y[-1]]
            self.draw = True

        if self.draw and len(self.data["x"]) < ITER_COUNT * ITER_TIMES:
            x, y, c = ifs(self.tool.get_areas(), self.tool.get_eqs(),
                          self.initpos, ITER_COUNT)
            ox, oy, oc = self.data["x"], self.data["y"], self.data["c"]
            # Skip batches whose coordinates blow up (this form also skips NaN).
            if np.max(np.abs(x)) < 1000000 and np.max(np.abs(y)) < 1000000:
                self.initpos = [x[-1], y[-1]]
                x, y, c = np.hstack((ox, x)), np.hstack((oy, y)), np.hstack((oc, c))
                self.data["x"], self.data["y"], self.data["c"] = x, y, c
                # Adjust the plot range so the X and Y axes keep a 1:1 scale.
                w, h = float(self.plot.width), float(self.plot.height)
                if w <= 0 or h <= 0:
                    # Plot not laid out yet; avoid dividing by zero.
                    return
                xmin, xmax = np.min(x), np.max(x)
                ymin, ymax = np.min(y), np.max(y)
                xptp, yptp = xmax - xmin, ymax - ymin
                xcenter, ycenter = (xmax + xmin) / 2.0, (ymax + ymin) / 2.0
                scale = max(xptp / w, yptp / h)
                half_w, half_h = 0.5 * scale * w, 0.5 * scale * h
                self.plot.index_range.low = xcenter - half_w
                self.plot.index_range.high = xcenter + half_w
                self.plot.value_range.low = ycenter - half_h
                # Previously "+ 0.5*scale" (missing "*h"), which broke the 1:1 ratio.
                self.plot.value_range.high = ycenter + half_h

    def _current_name_changed(self):
        """Load the design selected in the name combo box."""
        if self.current_name not in self.ifs_names:
            # e.g. the list was emptied by deleting its last design.
            return
        index = self.ifs_names.index(self.current_name)
        self.tool.points = list(self.ifs_points[index])
        self.tool.changed = True
        self.clear = True

    def _save_button_fired(self):
        """
        Save-button handler: store the current points under a chosen name.
        """
        ask = AskName(name=self.current_name)
        if ask.configure_traits():
            if ask.name not in self.ifs_names:
                self.ifs_names.append(ask.name)
                self.ifs_points.append(self.tool.points[:])
            else:
                # Overwrite the existing design with this name.
                index = self.ifs_names.index(ask.name)
                self.ifs_points[index] = self.tool.points[:]
            self.save_data()
            self.current_name = ask.name

    def _unsave_button_fired(self):
        """Delete the current design and select a neighbouring one."""
        if self.current_name not in self.ifs_names:
            return
        index = self.ifs_names.index(self.current_name)
        del self.ifs_names[index]
        del self.ifs_points[index]
        if self.ifs_names:
            # Select the design that moved into the freed slot, or the new
            # last design if the deleted one was last. (Previously the int
            # index was compared with a string name.)
            index = min(index, len(self.ifs_names) - 1)
            self.current_name = self.ifs_names[index]
        else:
            self.current_name = ""
        self.save_data()

    def save_data(self):
        """Pickle the list of (name, points) pairs to DATA_FILE."""
        with open(self.DATA_FILE, "wb") as f:
            # list(): a zip iterator is not picklable on Python 3.
            pickle.dump(list(zip(self.ifs_names, self.ifs_points)), f)
class GenerateProjectorCalibration(HasTraits):
    """Interactively define a projector viewport polygon.

    A test image with red horizontal and blue vertical stripes (10 px thick,
    every 100 px) is shown in a Chaco plot. Points clicked with the line tool
    form a polygon that is rasterised into the image once it has at least
    three vertices. Every update sends the image through
    ``blit_compressed_image_proxy``; with three or more vertices,
    ``update_ROS_params`` is also called.

    NOTE(review): ``self.width``/``self.height`` (trait declarations are
    commented out below) and ``update_ROS_params`` are not defined in this
    class. Presumably they are supplied as constructor keyword arguments or
    by other code; confirm.
    """
    #width = traits.Int
    #height = traits.Int
    display_id = traits.String
    plot = Instance(Component)
    linedraw = Instance(LineSegmentTool)
    viewport_id = traits.String('viewport_0')
    display_mode = traits.Trait('white on black', 'black on white')
    client = traits.Any
    blit_compressed_image_proxy = traits.Any

    set_display_server_mode_proxy = traits.Any

    traits_view = View(
        Group(
            Item('display_mode'),
            Item('viewport_id'),
            Item('plot', editor=ComponentEditor(),
                 show_label=False),
            orientation="vertical"),
        resizable=True,
    )

    def __init__(self, *args, **kwargs):
        display_coords_filename = kwargs.pop('display_coords_filename')
        super(GenerateProjectorCalibration, self).__init__(*args, **kwargs)

        # Pickles must be read in binary mode. The loaded object is currently
        # unused, but loading still fails loudly on a missing/corrupt file.
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(display_coords_filename, mode='rb') as fd:
            pickle.load(fd)

        self.param_name = 'virtual_display_config_json_string'
        self.fqdn = '/virtual_displays/' + self.display_id + '/' + self.viewport_id
        self.fqpn = self.fqdn + '/' + self.param_name
        self.client = dynamic_reconfigure.client.Client(self.fqdn)

        # First image build; this also creates the plot and line tool
        # (linedraw -> plot defaults), which need self._image.
        self._update_image()

        # Fetch the stored configuration once (it used to be fetched and
        # parsed twice, the first result being discarded).
        virtual_display_json_str = rospy.get_param(self.fqpn)
        this_virtual_display = json.loads(virtual_display_json_str)

        # Restore the stored viewport only if every vertex lies in the image.
        all_points_ok = True
        for (x, y) in this_virtual_display['viewport']:
            if (x >= self.width) or (y >= self.height):
                all_points_ok = False
                break
        if all_points_ok:
            self.linedraw.points = this_virtual_display['viewport']
        self._update_image()

    def _update_image(self):
        """Rebuild the test image, refresh the plot and send it out."""
        # Must exist before self.linedraw is first touched (see __init__).
        self._image = np.zeros((self.height, self.width, 3), dtype=np.uint8)
        # draw polygon; vertices are passed to mahotas as (y, x)
        if len(self.linedraw.points) >= 3:
            pts = [(posint(y, self.height - 1), posint(x, self.width - 1))
                   for (x, y) in self.linedraw.points]
            mahotas.polygon.fill_polygon(pts, self._image[:, :, 0])
            self._image[:, :, 0] *= 255
            self._image[:, :, 1] = self._image[:, :, 0]
            self._image[:, :, 2] = self._image[:, :, 0]

        # draw red horizontal stripes
        for i in range(0, self.height, 100):
            self._image[i:i + 10, :, 0] = 255

        # draw blue vertical stripes
        for i in range(0, self.width, 100):
            self._image[:, i:i + 10, 2] = 255

        # _pd only exists once _plot_default has run.
        if hasattr(self, '_pd'):
            self._pd.set_data("imagedata", self._image)
        self.send_array()
        if len(self.linedraw.points) >= 3:
            self.update_ROS_params()

    def _plot_default(self):
        """Create the image plot with pan and box-zoom tools."""
        self._pd = ArrayPlotData()
        self._pd.set_data("imagedata", self._image)

        plot = Plot(self._pd, default_origin="top left")
        plot.x_axis.orientation = "top"
        plot.img_plot("imagedata")

        plot.bgcolor = "white"

        # Tweak some of the plot properties
        plot.title = "Click to add points, press Enter to clear selection"
        plot.padding = 50
        plot.line_width = 1

        # Attach some tools to the plot
        pan = PanTool(plot, drag_button="right", constrain_key="shift")
        plot.tools.append(pan)
        zoom = ZoomTool(component=plot, tool_mode="box", always_on=False)
        plot.overlays.append(zoom)

        return plot

    def _linedraw_default(self):
        """Create the polygon line tool overlaid on the plot."""
        linedraw = LineSegmentTool(self.plot, color=(0.5, 0.5, 0.9, 1.0))
        self.plot.overlays.append(linedraw)
        linedraw.on_trait_change(self.points_changed, 'points[]')
        return linedraw

    def points_changed(self):
        self._update_image()

    @traits.on_trait_change('display_mode')
    def send_array(self):
        """Encode the current image as PNG and blit it to the display.

        Also invoked when ``display_mode`` changes. NOTE(review): the colours
        named by ``display_mode`` are not applied to the image (the original
        parsed them into unused locals, now removed).
        """
        # mkstemp creates the file atomically (mktemp is deprecated and racy).
        fd, fname = tempfile.mkstemp('.png')
        os.close(fd)
        try:
            scipy.misc.imsave(fname, self._image)
            image = freemovr_engine.msg.FreemoVRCompressedImage()
            image.format = 'png'
            # PNG is binary: read in 'rb' mode and close the handle.
            with open(fname, 'rb') as png_file:
                image.data = png_file.read()
            self.blit_compressed_image_proxy(image)
        finally:
            os.unlink(fname)

    def get_viewport_verts(self):
        """Return the polygon vertices as JSON-friendly ``[[x, y], ...]``."""
        # convert to integers
        pts = [(posint(x, self.width - 1), posint(y, self.height - 1))
               for (x, y) in self.linedraw.points]
        # convert to list of lists for maximal json compatibility
        return [list(x) for x in pts]
class CMatrixViewer(MatrixViewer):
    """Matrix viewer bound to a cnetwork's edge-attribute matrices.

    The matrix for the selected edge parameter is shown as a colour-mapped
    image with a colour bar. Labels and value under the mouse go to ``fro``,
    ``to`` and ``val``. Changing the selection also calls the network's
    ``set_to_edge_parameter``.
    """

    tplot = Instance(Plot)
    plot = Instance(Component)
    custtool = Instance(CustomTool)
    colorbar = Instance(ColorBar)

    edge_parameter = Instance(EdgeParameters)
    network_reference = Any
    matrix_data_ref = Any
    labels = Any
    fro = Any
    to = Any
    val = Float

    traits_view = View(Group(Item('plot',
                                  editor=ComponentEditor(size=(800, 600)),
                                  show_label=False),
                             HGroup(
                                 Item('fro',
                                      label="From",
                                      style='readonly',
                                      springy=True),
                                 Item('to',
                                      label="To",
                                      style='readonly',
                                      springy=True),
                                 Item('val',
                                      label="Value",
                                      style='readonly',
                                      springy=True),
                             ),
                             orientation="vertical"),
                       Item('edge_parameter_name', label="Choose edge"),
                       handler=CustomHandler(),
                       resizable=True,
                       title="Matrix Viewer")

    def __init__(self, net_ref, **traits):
        """ net_ref is a reference to a cnetwork """
        # NOTE(review): super(MatrixViewer, self) skips MatrixViewer.__init__.
        # Presumably intentional (this class does its own setup below);
        # confirm before changing.
        super(MatrixViewer, self).__init__(**traits)

        self.network_reference = net_ref
        self.edge_parameter = self.network_reference._edge_para
        srcobj = self.network_reference.datasourcemanager._srcobj
        self.matrix_data_ref = srcobj.edgeattributes_matrix_dict
        self.labels = srcobj.labels

        # get the currently selected edge
        self.curr_edge = self.edge_parameter.parameterset.name
        # create plot
        self.plot = self._create_plot_component()

        # set trait notification on customtool
        self.custtool.on_trait_change(self._update_fields, "xval")
        self.custtool.on_trait_change(self._update_fields, "yval")

        # add edge parameter enum; list() so this also works where
        # dict.keys() returns a view object
        self.add_trait('edge_parameter_name',
                       Enum(list(self.matrix_data_ref.keys())))
        self.edge_parameter_name = self.curr_edge

    def _edge_parameter_name_changed(self, new):
        """Sync the edge dialog, displayed matrix and title with the choice."""
        # update edge parameter dialog
        self.edge_parameter.set_to_edge_parameter(self.edge_parameter_name)

        # update the data
        self.pd.set_data("imagedata",
                         self.matrix_data_ref[self.edge_parameter_name])
        # keep the title in sync (it was only set once at creation)
        self.tplot.title = self.edge_parameter_name

    def _update_fields(self):
        """Report labels/value for the matrix cell under the mouse cursor."""
        import math

        # Without explicit bounds the image presumably spans [0, ncols] x
        # [0, nrows] (Chaco default), so cell i covers [i, i + 1). floor, not
        # trunc: trunc would map e.g. -0.5 (outside the image) to index 0.
        frotmp = int(math.floor(self.custtool.yval))
        totmp = int(math.floor(self.custtool.xval))

        # assume matrix whose shape is (# of rows, # of columns)
        matrix = self.matrix_data_ref[self.edge_parameter_name]
        sh = matrix.shape
        if 0 <= frotmp < sh[0] and 0 <= totmp < sh[1]:
            self.fro = self.labels[frotmp]
            self.to = self.labels[totmp]
            self.val = matrix[frotmp, totmp]

    def _create_plot_component(self):
        """Build the image plot, tools and colour bar; return the container."""
        # Create a plot data object and give it the currently selected matrix
        self.pd = ArrayPlotData()
        self.pd.set_data("imagedata", self.matrix_data_ref[self.curr_edge])

        # Create the plot
        self.tplot = Plot(self.pd, default_origin="top left")
        self.tplot.x_axis.orientation = "top"
        self.tplot.img_plot(
            "imagedata",
            name="my_plot",
            colormap=jet)

        # Tweak some of the plot properties
        self.tplot.title = self.curr_edge
        self.tplot.padding = 50

        # Right now, some of the tools are a little invasive, and we need the
        # actual CMapImage object to give to them
        self.my_plot = self.tplot.plots["my_plot"][0]

        # Attach some tools to the plot
        self.tplot.tools.append(PanTool(self.tplot))
        zoom = ZoomTool(component=self.tplot, tool_mode="box", always_on=False)
        self.tplot.overlays.append(zoom)

        # my custom tool to get the connection information
        self.custtool = CustomTool(self.tplot)
        self.tplot.tools.append(self.custtool)

        # Create the colorbar, handing in the appropriate range and colormap
        colormap = self.my_plot.color_mapper
        self.colorbar = ColorBar(
            index_mapper=LinearMapper(range=colormap.range),
            color_mapper=colormap,
            plot=self.my_plot,
            orientation='v',
            resizable='v',
            width=30,
            padding=20)

        self.colorbar.padding_top = self.tplot.padding_top
        self.colorbar.padding_bottom = self.tplot.padding_bottom

        # TODO: the range selection gives a Segmentation Fault,
        # but why, the matrix_viewer.py example works just fine!
        # create a range selection for the colorbar
        self.range_selection = RangeSelection(component=self.colorbar)
        self.colorbar.tools.append(self.range_selection)
        self.colorbar.overlays.append(
            RangeSelectionOverlay(component=self.colorbar,
                                  border_color="white",
                                  alpha=0.8,
                                  fill_color="lightgray"))

        # Informing the cmap plot of the selection is disabled, presumably
        # because of the segfault noted above:
        #self.range_selection.listeners.append(self.my_plot)

        # Create a container to position the plot and the colorbar side-by-side
        container = HPlotContainer(use_backbuffer=True)
        container.add(self.tplot)
        container.add(self.colorbar)
        container.bgcolor = "white"

        return container