Exemplo n.º 1
0
def init(data):
    """Build the two US overview figures and return them as named tuples.

    Stores *data* in the module-level global ``__data__`` and then creates:

    * figure 1 -- the two handles returned by
      ``new_cases(..., 'US', ..., 'new_cases', 5)``, exposed as ``bar`` and
      ``mov_ave``;
    * figure 2 -- one ``total_numbers`` result per metric, exposed as
      ``confirmed`` (red), ``recovered`` (green) and ``deaths`` (black).

    Args:
        data: object exposing ``country_large`` and ``dfs`` attributes
            (their exact types are defined elsewhere -- TODO confirm).

    Returns:
        A ``figs`` named tuple whose ``fig1`` / ``fig2`` fields are ``plot``
        named tuples of ``(fig, handles)``.
    """
    global __data__
    __data__ = data

    plot_type = namedtuple('plot', ['fig', 'handles'])
    daily_handles = namedtuple('Handles', ['bar', 'mov_ave'])
    total_handles = namedtuple('Handles', ['confirmed', 'recovered', 'deaths'])

    # Figure 1: daily new cases; new_cases must return exactly two handles.
    daily_fig = plt.figure()
    daily_marks = new_cases(__data__.country_large, 'US', __data__.dfs, 'new_cases', 5)
    daily = plot_type(daily_fig, daily_handles(*daily_marks))

    # Figure 2: one total_numbers call per metric, in this fixed order.
    totals_fig = plt.figure()
    metric_colors = (('confirmed', 'red'), ('recovered', 'green'), ('deaths', 'black'))
    total_marks = [
        total_numbers(__data__.country_large, 'US', __data__.dfs, metric, colors=[color])
        for metric, color in metric_colors
    ]
    totals = plot_type(totals_fig, total_handles(*total_marks))

    return namedtuple('figs', ['fig1', 'fig2'])(daily, totals)
Exemplo n.º 2
0
    def _make_plot(self):
        """(Re)create one bqplot figure per Stokes parameter I, Q, U and V.

        Figure numbers 1-4 are closed and then re-created, one per parameter,
        with a shared margin/layout. For each parameter ``X`` the figure is
        stored as ``self.X_fig`` and the line mark (``self.X`` against
        ``self.u``) as ``self.X_plot``.
        """
        margin = {'top': 25, 'bottom': 35, 'left': 35, 'right': 25}
        fig_layout = {'height': '100%', 'width': '100%'}
        for number, stokes in enumerate('IQUV', start=1):
            plt.close(number)
            figure = plt.figure(number,
                                title='Stokes ' + stokes,
                                fig_margin=margin,
                                layout=fig_layout)
            setattr(self, stokes + '_fig', figure)
            setattr(self, stokes + '_plot', plt.plot(self.u, getattr(self, stokes)))
            plt.xlabel("Δλ / ΔλD")
Exemplo n.º 3
0
    def create_widget(self, output, plot, dataset, limits):
        """Build the bqplot figure, image mark, pan/zoom and container widgets.

        Args:
            output: stored as ``self.output``.
            plot: stored as ``self.plot``.
            dataset: stored as ``self.dataset``.
            limits: ``[[xmin, xmax], [ymin, ymax]]`` as a nested list or
                array; normalised to plain Python numbers in ``self.limits``.
        """
        self.plot = plot
        self.output = output
        self.dataset = dataset
        # tolist() converts NumPy scalars (float32, int64, ...) to plain Python
        # numbers. The scales are built from these converted values, so they
        # accept both list and array input.
        self.limits = np.array(limits).tolist()
        self.scale_x = bqplot.LinearScale(min=self.limits[0][0], max=self.limits[0][1])
        self.scale_y = bqplot.LinearScale(min=self.limits[1][0], max=self.limits[1][1])
        self.scale_rotation = bqplot.LinearScale(min=0, max=1)
        self.scale_size = bqplot.LinearScale(min=0, max=1)
        self.scale_opacity = bqplot.LinearScale(min=0, max=1)
        self.scales = {'x': self.scale_x, 'y': self.scale_y, 'rotation': self.scale_rotation,
                       'size': self.scale_size, 'opacity': self.scale_opacity}

        margin = {'bottom': 30, 'left': 60, 'right': 0, 'top': 0}
        self.figure = plt.figure(self.figure_key, fig=self.figure, scales=self.scales, fig_margin=margin)
        # NOTE(review): max_width is set to 500px further down, which conflicts
        # with this 900px min_width -- confirm which constraint is intended.
        self.figure.layout.min_width = '900px'
        # Make this figure the current pyplot figure for the marks below.
        plt.figure(fig=self.figure)
        self.figure.padding_y = 0
        x = np.arange(0, 10)
        y = x ** 2
        # Invisible scatter bound to all scales (including rotation); presumably
        # keeps those scales attached to the figure -- TODO confirm.
        self._fix_scatter = plt.scatter(x, y, visible=False, rotation=x, scales=self.scales)
        self._fix_scatter.visible = False
        # Empty placeholder image source; the original hinted at
        # vaex.image.rgba_to_url(self._create_rgb_grid()) as the eventual source.
        src = ""
        # The image spans the full current scale ranges; the negative height is
        # kept as-is (presumably flips the image vertically -- TODO confirm).
        self.image = bqplot_image.Image(scales=self.scales, src=src, x=self.scale_x.min, y=self.scale_y.max,
                                        width=self.scale_x.max - self.scale_x.min,
                                        height=-(self.scale_y.max - self.scale_y.min))
        self.figure.marks = self.figure.marks + [self.image]
        self.figure.layout.width = '100%'
        self.figure.layout.max_width = '500px'
        self.scatter = plt.scatter(x, y, visible=False, rotation=x, scales=self.scales, size=x, marker="arrow")
        self.panzoom = bqplot.PanZoom(scales={'x': [self.scale_x], 'y': [self.scale_y]})
        self.figure.interaction = self.panzoom

        # Re-run _update_limits whenever either scale's range changes, and
        # _update_scales whenever the ``limits`` trait changes.
        for scale in (self.scale_x, self.scale_y):
            scale.observe(self._update_limits, "min")
            scale.observe(self._update_limits, "max")
        self.observe(self._update_scales, "limits")

        self.image.observe(self._on_view_count_change, 'view_count')
        self.control_widget = widgets.VBox()
        self.widget = widgets.VBox(children=[self.control_widget, self.figure])
        self.create_tools()
Exemplo n.º 4
0
    def create_widget(self, output, plot, dataset, limits):
        """Build the bqplot figure, image mark, pan/zoom and container widgets.

        Args:
            output: stored as ``self.output``.
            plot: stored as ``self.plot``.
            dataset: stored as ``self.dataset``.
            limits: ``[[xmin, xmax], [ymin, ymax]]`` as a nested list or
                array; normalised to plain Python numbers in ``self.limits``.
        """
        self.plot = plot
        self.output = output
        self.dataset = dataset
        # tolist() already yields plain Python numbers. Reading the scale
        # bounds from it replaces the former ``.item()`` calls, which raised
        # AttributeError when ``limits`` was a list of Python floats.
        self.limits = np.array(limits).tolist()
        self.scale_x = bqplot.LinearScale(min=self.limits[0][0], max=self.limits[0][1])
        self.scale_y = bqplot.LinearScale(min=self.limits[1][0], max=self.limits[1][1])
        self.scale_rotation = bqplot.LinearScale(min=0, max=1)
        self.scale_size = bqplot.LinearScale(min=0, max=1)
        self.scale_opacity = bqplot.LinearScale(min=0, max=1)
        self.scales = {'x': self.scale_x, 'y': self.scale_y, 'rotation': self.scale_rotation,
                       'size': self.scale_size, 'opacity': self.scale_opacity}

        margin = {'bottom': 30, 'left': 60, 'right': 0, 'top': 0}
        self.figure = plt.figure(self.figure_key, fig=self.figure, scales=self.scales, fig_margin=margin)
        # NOTE(review): max_width is set to 500px further down, which conflicts
        # with this 900px min_width -- confirm which constraint is intended.
        self.figure.layout.min_width = '900px'
        # Make this figure the current pyplot figure for the marks below.
        plt.figure(fig=self.figure)
        self.figure.padding_y = 0
        x = np.arange(0, 10)
        y = x ** 2
        # Invisible scatter bound to all scales (including rotation); presumably
        # keeps those scales attached to the figure -- TODO confirm.
        self._fix_scatter = plt.scatter(x, y, visible=False, rotation=x, scales=self.scales)
        self._fix_scatter.visible = False
        # Empty placeholder image source; the original hinted at
        # vaex.image.rgba_to_url(self._create_rgb_grid()) as the eventual source.
        src = ""
        # The image spans the full current scale ranges; the negative height is
        # kept as-is (presumably flips the image vertically -- TODO confirm).
        self.image = bqplot_image.Image(scales=self.scales, src=src, x=self.scale_x.min, y=self.scale_y.max,
                                        width=self.scale_x.max - self.scale_x.min,
                                        height=-(self.scale_y.max - self.scale_y.min))
        self.figure.marks = self.figure.marks + [self.image]
        self.figure.layout.width = '100%'
        self.figure.layout.max_width = '500px'
        self.scatter = plt.scatter(x, y, visible=False, rotation=x, scales=self.scales, size=x, marker="arrow")
        self.panzoom = bqplot.PanZoom(scales={'x': [self.scale_x], 'y': [self.scale_y]})
        self.figure.interaction = self.panzoom

        # Re-run _update_limits whenever either scale's range changes, and
        # _update_scales whenever the ``limits`` trait changes.
        for scale in (self.scale_x, self.scale_y):
            scale.observe(self._update_limits, "min")
            scale.observe(self._update_limits, "max")
        self.observe(self._update_scales, "limits")

        self.image.observe(self._on_view_count_change, 'view_count')
        self.control_widget = widgets.VBox()
        self.widget = widgets.VBox(children=[self.control_widget, self.figure])
        self.create_tools()
Exemplo n.º 5
0
def show_cohort(y_true, protected, backend="bqplot"):
    """Show the base rates in the cohort by protected attribute.

    ``confusion_matrix`` is reused purely as a 2x2 counter of
    ``(y_true, protected)`` pairs. The resulting counts are drawn as stacked
    bars at the module-level ``x_ticks`` positions: positives in green,
    negatives in red.

    Args:
        y_true: outcome labels (presumably binary -- the unpacking below
            requires exactly four counts).
        protected: protected-attribute labels, same length as ``y_true``.
        backend: ``"bqplot"`` for an interactive figure; any other value
            draws with matplotlib instead.

    Returns:
        The bqplot figure, or the matplotlib figure for other backends.
    """
    # Assuming sklearn's convention (rows = first argument, columns = second),
    # ravel() yields counts for (y=0,p=0), (y=0,p=1), (y=1,p=0), (y=1,p=1).
    mn, fn, mp, fp = confusion_matrix(y_true, protected).ravel()
    positives = [mp, fp]
    negatives = [mn, fn]
    values = [positives, negatives]
    colors = ["green", "red"]

    if backend != "bqplot":
        fig, ax = mplt.subplots(figsize=(5, 5), dpi=DPI)
        ax.bar(x_ticks, positives, color=colors[0])
        ax.bar(x_ticks, negatives, bottom=positives, color=colors[1])
        return fig

    fig = plt.figure(min_aspect_ratio=1, max_aspect_ratio=1)
    # Outer index of ``values`` selects the colour, inner index the x position.
    # Note: negative values render strangely in bqplot.
    plt.bar(x_ticks, values, colors=colors)
    side = "4in"
    fig.layout.width = side
    fig.layout.height = side
    return fig
Exemplo n.º 6
0
 def bqbars(values, colors, title, stake):
     """Draw a titled bqplot bar chart of percentages.

     Args:
         values: bar heights, plotted against the module-level ``x_ticks``.
         colors: bar colours passed to ``plt.bar``.
         title: figure title.
         stake: x-axis label.

     Returns:
         The new bqplot figure.
     """
     figure = plt.figure()
     plt.bar(x_ticks, values, colors=colors)
     plt.ylabel("Fraction (%)")
     plt.xlabel(stake)
     figure.title = title
     return figure
Exemplo n.º 7
0
    def build_widgets(self, *args, **kwargs):
        """Create the residuals-vs-predicted scatter figure and its layout.

        Extra positional and keyword arguments are accepted but unused.
        Sets ``self.residuals_fig``, ``self.residuals_plot`` (an initially
        empty scatter with a green zero line) and ``self.widgets_layout``.
        """
        fig_layout = Layout(width='960px',
                            height='600px',
                            overflow_x='hidden',
                            overflow_y='hidden')
        self.residuals_fig = plt.figure(title='Residuals vs Predicted Values',
                                        layout=fig_layout)

        y_axis = {'label': 'Residual', 'tick_format': '0.1f'}
        x_axis = {'label': 'Predicted Value'}
        self.residuals_plot = plt.scatter([], [],
                                          colors=['yellow'],
                                          default_size=16,
                                          stroke='black',
                                          axes_options={'y': y_axis, 'x': x_axis})
        # Horizontal reference line at residual == 0.
        plt.hline(level=0, colors=['limegreen'], stroke_width=3)

        self.widgets_layout = HBox([self.residuals_fig])
Exemplo n.º 8
0
    def interact(self,
                 max_semiangle: float = None,
                 phi: float = 0.,
                 sliders=None,
                 throttling=False):
        """Show a live 1D profile plot, optionally next to parameter sliders.

        Args:
            max_semiangle: forwarded to ``self.profiles``.
            phi: forwarded to ``self.profiles``.
            sliders: if truthy, keyword arguments for ``quick_sliders``; the
                resulting sliders are shown to the right of the figure.
            throttling: if truthy, passed to ``throttle`` to wrap the update
                callback.

        Returns:
            The figure, or an ``HBox`` of the figure and a ``VBox`` of sliders.
        """
        import bqplot.pyplot as plt
        from abtem.visualize.bqplot import show_measurement_1d
        from abtem.visualize.widgets import quick_sliders, throttle
        import ipywidgets as widgets

        margin = {'top': 0, 'bottom': 50, 'left': 50, 'right': 0}
        figure = plt.figure(fig_margin=margin)
        figure.layout.height = '250px'
        figure.layout.width = '300px'

        def current_profiles():
            return self.profiles(max_semiangle, phi).values()

        _, update = show_measurement_1d(current_profiles, figure)
        if throttling:
            update = throttle(throttling)(update)
        # Register the update callback with this object's ``changed`` event.
        self.changed.register(update)

        if not sliders:
            return figure
        slider_widgets = quick_sliders(self, **sliders)
        return widgets.HBox([figure, widgets.VBox(slider_widgets)])
Exemplo n.º 9
0
    def __init__(self, simu, net):
        """Build and display the Reconstruction Workbench.

        Shows one slider per simulation factor next to two heatmaps (the
        simulated sample and the reconstruction view), then displays the
        whole layout.

        Args:
            simu: The simulation object to simulate (must be a
                ``simulation.BaseSimulation``).
            net: The network to use to reconstruct data (must be a
                ``network.BaseNetwork``).

        Raises:
            AssertionError: if an argument has the wrong type (these checks
                are skipped when Python runs with ``-O``).
        """
        # We check input
        assert issubclass(type(simu), simulation.BaseSimulation)
        assert issubclass(type(net), network.BaseNetwork)

        # We instantiate arguments
        self._simulation = simu
        self._network = net
        nb_params = self._simulation.nb_params()
        # Query the labels once rather than once per slider.
        labels = self._simulation.get_factor_labels()
        self._sliders = [
            ipw.FloatSlider(min=0., max=1., step=0.01, description=labels[i])
            for i in range(nb_params)
        ]
        self._factors = list(np.ones(nb_params))

        # Both heatmaps start from the simulation drawn with the initial factors.
        # NOTE(review): the reconstruction heatmap is initialised from the
        # simulation, not from self._network -- confirm it is updated later.
        self._simu_fig = plt.figure()
        self._simu_plot = plt.heatmap(
            self._simulation.draw(self._factors, depth=1))
        self._recons_fig = plt.figure()
        self._recons_plot = plt.heatmap(
            self._simulation.draw(self._factors, depth=1))
        # Each slider gets its own callback bound to its factor index.
        for index, slider in enumerate(self._sliders):
            slider.observe(self._callback_closure(index), 'value')
        self._title = ipw.HTML('<h2>Reconstruction WorkBench</h2>')
        self._caption = ipw.HTML('Manipulate the generative factors:')

        # Layout: controls on the left, the two 25%-wide heatmaps on the right.
        left_elmts = [self._title, self._caption] + self._sliders
        left_layout = ipw.Layout(padding='50px 0px 0px 0px')
        left_pane = ipw.VBox(left_elmts, layout=left_layout)
        self._simu_fig.layout = ipw.Layout(width='25%')
        self._recons_fig.layout = ipw.Layout(width='25%')
        self._layout = ipw.HBox([left_pane, self._simu_fig, self._recons_fig])
        IPython.display.display(self._layout)
def LoadEvaluationPlot(result, prediction_size, conjunto, modelo):
    """Draw a figure showing how a model performed on a dataset.

    Scatters every actual value (``result.Y``) against its prediction
    (``result.Y_hat``), coloured by the position of the year inside its
    prediction window. Also draws the ``y = x`` baseline from 0 to the
    largest value. Nothing is returned; the figure is created through
    ``plt.figure``.

    Args:
        result (:obj:`TestResult`): TestResult instance holding the
            evaluation results.
        prediction_size (int): number of years predicted per window.
        conjunto (str): name of the evaluated dataset.
        modelo (str): prediction method used.
    """
    # Number of prediction windows contained in Y.
    n = result.Y.shape[0] // prediction_size
    # Upper end of the y = x baseline.
    m = max(np.amax(result.Y_hat), np.amax(result.Y))

    # One colour code (1..prediction_size) per predicted year, repeated for
    # every window so the colour array has one entry per point. The previous
    # code capped the codes at 5, which left the array too short whenever
    # prediction_size > 5. Codes beyond the five palette entries are left to
    # the ordinal scale (presumably reused cyclically -- TODO confirm).
    colors = np.tile(np.arange(1, prediction_size + 1), n)

    plt.figure(
        title='Proyección de matrícula escolar en %s utilizando %s' %
        (conjunto, modelo),
        legend_location='top-left',
        fig_margin=dict(top=50, bottom=70, left=100, right=100))

    plt.scales(
        scales={
            'color':
            OrdinalColorScale(
                colors=['Green', 'DodgerBlue', 'Yellow', 'Orange', 'Red'])
        })

    axes_options = {
        'x': dict(label='Valor real (alumnos)'),
        'y': dict(label='Valor predicho (alumnos)'),
        'color': dict(label='Año', side='right')
    }

    plt.scatter(result.Y,
                result.Y_hat,
                color=colors,
                stroke='black',
                axes_options=axes_options)

    # Perfect-prediction baseline.
    plt.plot(x=np.array([0, m]),
             y=np.array([0, m]),
             labels=['Línea base de predicción'],
             display_legend=True)
Exemplo n.º 11
0
 def test_figure(self):
     """Build a seeded random-walk line figure and save it as PNG."""
     np.random.seed(0)
     n_points, step_scale = 100, 100.0
     random_walk = np.cumsum(np.random.randn(n_points) * step_scale)
     figure = plt.figure(title='First Example')
     plt.plot(random_walk)
     figure.save_png()
Exemplo n.º 12
0
    def create_charts(self):
        """Create the epoch/view controls and the bar and histogram figures.

        Both marks start empty. The bar figure is also installed as
        ``self.graph.tooltip``.
        """
        self.epoch_slider = IntSlider(description='Epoch:', min=1,
                                      max=self.num_epochs, value=1)
        view_modes = ['Weights', 'Gradients', 'Activations']
        self.mode_dd = Dropdown(description='View', options=view_modes,
                                value='Weights')
        self.update_btn = Button(description='Update')

        self.bar_figure = plt.figure()
        self.bar_plot = plt.bar([], [], scales={'x': OrdinalScale()})

        self.hist_figure = plt.figure(title='Histogram of Activations')
        self.hist_plot = plt.hist([], bins=20)

        controls = [self.epoch_slider, self.mode_dd, self.update_btn]
        self.controls = HBox(controls)
        self.graph.tooltip = self.bar_figure
Exemplo n.º 13
0
    def _make_plot(self):
        """(Re)create the four line-formation figures, numbered 1-4.

        Figures 1-3 plot ``self.h`` (Voigt profile), ``self.xq`` (opacity
        ratio) and ``self.prof`` (intensity) against ``self.freq``. Figure 4
        plots the source function against lg(tau_500), labels the tau = 1
        points and draws a short vertical tick at each. Every plot gets its
        own ``LogScale`` for y. Each figure number is closed before it is
        re-created.
        """
        layout_args = {
            'fig_margin': {'top': 25, 'bottom': 35, 'left': 35, 'right': 25},
            'layout': {'height': '100%', 'width': '100%'},
            'max_aspect_ratio': 1.618,
        }
        frequency_label = "Δν / ΔνD"

        # (figure number, attribute prefix, title, attribute holding y data)
        spectra = (
            (1, 'voigt', 'Voigt profile', 'h'),
            (2, 'abs', '(αᶜ + αˡ) / α₅₀₀', 'xq'),
            (3, 'int', 'Intensity', 'prof'),
        )
        for number, prefix, title, y_attr in spectra:
            plt.close(number)
            setattr(self, prefix + '_fig',
                    plt.figure(number, title=title, **layout_args))
            setattr(self, prefix + '_plot',
                    plt.plot(self.freq, getattr(self, y_attr),
                             scales={'y': LogScale()}))
            plt.xlabel(frequency_label)

        plt.close(4)
        self.source_fig = plt.figure(4, title='Source Function', **layout_args)
        self.source_plot = plt.plot(np.log10(self.tau500), self.source_function,
                                    scales={'y': LogScale()})
        plt.xlabel("lg(τ₅₀₀)")

        lg_tau_cont = np.log10(self.tau500_cont)
        lg_tau_line = np.log10(self.tau500_line)
        self.tau_labels = plt.label(['τᶜ = 1', 'τˡ = 1'], colors=['black'],
                                    x=np.array([lg_tau_cont, lg_tau_line]),
                                    y=np.array([self.source_function_cont,
                                                self.source_function_line]),
                                    y_offset=-25, align='middle')

        def vertical_tick(x_pos, level):
            # Black segment at x_pos spanning level / 1.5 .. level * 1.5.
            return plt.plot(np.array([x_pos, x_pos]),
                            np.array([level / 1.5, level * 1.5]),
                            colors=['black'])

        self.tau_line_plot = vertical_tick(lg_tau_line, self.source_function_line)
        self.tau_cont_plot = vertical_tick(lg_tau_cont, self.source_function_cont)
Exemplo n.º 14
0
def show_measurement_1d(measurements_or_func,
                        figure=None,
                        throttling=False,
                        **kwargs):
    """Plot 1D measurements as lines on a bqplot figure.

    Args:
        measurements_or_func: an iterable of measurements, or a zero-argument
            callable returning one. Each measurement must expose ``array`` and
            ``calibrations[0]`` with ``offset`` and ``sampling``.
        figure: figure to draw into; a 300x250 px figure is created if None.
        throttling: forwarded to ``throttle`` when building the callback.
        **kwargs: extra keyword arguments for every ``plt.plot`` call.

    Returns:
        ``(figure, callback)`` when a callable was given. Calling
        ``callback`` re-evaluates the callable and refreshes every line's x
        and y. Otherwise returns just ``figure``.
    """
    import itertools

    if figure is None:
        figure = plt.figure(fig_margin={
            'top': 0,
            'bottom': 50,
            'left': 50,
            'right': 0
        })

        figure.layout.height = '250px'
        figure.layout.width = '300px'

    def calibrated_x(measurement):
        # Sample i sits at offset + i * sampling. endpoint=False keeps the
        # spacing equal to ``sampling``; including the endpoint would stretch
        # it to n * sampling / (n - 1).
        calibration = measurement.calibrations[0]
        n = len(measurement.array)
        return np.linspace(calibration.offset,
                           calibration.offset + n * calibration.sampling,
                           n,
                           endpoint=False)

    # Decide by type instead of catching TypeError, which would also swallow
    # TypeErrors raised *inside* the callable.
    return_callback = callable(measurements_or_func)
    if return_callback:
        measurements = measurements_or_func()
    else:
        measurements = measurements_or_func

    lines = []
    # Cycle the palette so measurements beyond its length are still drawn
    # (a plain zip would silently drop them).
    palette = itertools.cycle(TABLEAU_COLORS.values())
    for measurement, color in zip(measurements, palette):
        line = plt.plot(calibrated_x(measurement), measurement.array,
                        colors=[color], **kwargs)
        figure.axes[0].label = format_label(measurement.calibrations[0])
        lines.append(line)

    if not return_callback:
        return figure

    @throttle(throttling)
    def callback(*args, **callback_kwargs):
        # Recompute x from each new measurement's own calibration, rather than
        # reusing the last calibration seen in the plotting loop above.
        for line, measurement in zip(lines, measurements_or_func()):
            line.x = calibrated_x(measurement)
            line.y = measurement.array

    return figure, callback
Exemplo n.º 15
0
def generate_pie_chart(df, title="", show_decimal=False):
    """Return a bqplot figure containing a pie chart of *df*.

    Slice sizes come from ``df.values`` and labels from ``df.index``. Labels
    are drawn outside the pie, values are displayed, and colours are taken
    from ``chart_colors``.

    Args:
        df: data to plot (presumably a pandas Series -- TODO confirm).
        title: figure title.
        show_decimal: when False, slice values use the ``"0"`` format.
    """
    labels = df.index.values.tolist()
    slice_colors = chart_colors[:df.index.values.size]
    fig = plt.figure(title=title)
    pie = plt.pie(
        sizes=df.values.tolist(),
        labels=labels,
        display_labels="outside",
        colors=slice_colors,
        display_values=True,
    )
    if not show_decimal:
        pie.values_format = "0"
    return fig
Exemplo n.º 16
0
def addTimeseries(inMap, path, bands, new_plot):
    """Scatter per-band pixel time series and show them in a map widget.

    Opens the dataset at *path*. For each band it plots the values of pixel
    ``[0, 0]`` against the ``t`` coordinate, skipping NaN and zero samples.
    The figure is then placed in a bottom-right ``WidgetControl`` on
    ``inMap.map``, replacing any control previously stored in
    ``inMap.figure_widget``.

    Args:
        inMap: object exposing ``figure``, ``figure_widget`` and ``map``.
        path: dataset path readable by ``xr.open_dataset``.
        bands: names of the data variables to plot.
        new_plot: if True, always start a new figure; otherwise reuse
            ``inMap.figure`` when it is not None.
    """
    color = ['blue', 'red', 'green', 'yellow']

    # The context manager closes the file once the values have been read.
    with xr.open_dataset(path) as px_series:
        times = px_series.t.values
        stacked = px_series.to_array()
        for i, b in enumerate(bands):
            y_all = stacked.loc[dict(variable=b)][:, 0, 0].values
            # Keep only samples that are neither NaN nor zero.
            keep = ~np.isnan(y_all) & (y_all != 0)
            x_data = times[keep]
            y_data = y_all[keep]
            # Built fresh per call so a mark never shares the dict with another.
            axes_options = {'x': {'label': 'Time', 'side': 'bottom', 'num_ticks': 8, 'tick_format': '%b %y'},
                            'y': {'orientation': 'vertical', 'side': 'left', 'num_ticks': 10}}
            if i == 0:
                title = ''.join(x + ' ' for x in bands) + ' timeseries'
                if new_plot or inMap.figure is None:
                    inMap.figure = bqplt.figure(title=title, layout={'max_height': '250px', 'width': '600px'})
                # NOTE(review): a reused inMap.figure is not made current here,
                # so the scatters go to bqplot's current figure -- confirm.
            # Cycle the palette so more than four bands do not raise IndexError.
            bqplt.scatter(x_data, y_data, labels=[b], display_legend=True,
                          colors=[color[i % len(color)]], default_size=30,
                          axes_options=axes_options)

    widget_control = WidgetControl(widget=inMap.figure, position='bottomright')
    if inMap.figure_widget is not None:
        inMap.map.remove_control(inMap.figure_widget)
    inMap.figure_widget = widget_control
    inMap.map.add_control(inMap.figure_widget)
Exemplo n.º 17
0
def generate_group_bar(df, title="", scientific_notation=False):
    """Return a grouped bqplot bar chart of a DataFrame.

    The x categories are ``df``'s columns. ``df`` itself is passed as y,
    labelled by the index and coloured from ``chart_colors``. The y axis is
    limited to ``[0, max(df.values)]``.

    Args:
        df: DataFrame to plot.
        title: figure title.
        scientific_notation: when False, y-axis ticks use the ``".1f"`` format.
    """
    series_names = df.index.values.tolist()
    fig = plt.figure(title=title)
    plt.bar(
        x=df.columns.values.tolist(),
        y=df,
        labels=series_names,
        display_legend=False,
        type="grouped",
        colors=chart_colors[:df.index.values.size],
    )
    axis_name = df.columns.name
    if axis_name:
        # Label with the column-axis name minus its last space-separated word.
        plt.xlabel(axis_name.rsplit(" ", 1)[0])
    plt.ylim(0, np.amax(df.values))
    if scientific_notation:
        return fig
    fig.axes[1].tick_format = ".1f"
    return fig
Exemplo n.º 18
0
    def __init__(self, simu, net, emb_index=0):
        """Build and display the Latent Embedding Workbench.

        Shows one slider per simulation factor next to a line plot of the
        latent series, then displays the whole layout.

        Args:
            simu: The simulation object to simulate (must be a
                ``simulation.BaseSimulation``).
            net: The NN trained to extract embedding (must be a
                ``network.BaseNetwork``).
            emb_index: stored as ``self._emb_index``; it is used outside this
                method (presumably selects an embedding -- TODO confirm).

        Raises:
            AssertionError: if an argument has the wrong type (these checks
                are skipped when Python runs with ``-O``).
        """
        # We check input
        assert issubclass(type(simu), simulation.BaseSimulation)
        assert issubclass(type(net), network.BaseNetwork)

        # We instantiate arguments
        self._simulation = simu
        self._emb_index = emb_index
        self._network = net
        nb_params = self._simulation.nb_params()
        # Query the labels once rather than once per slider.
        labels = self._simulation.get_factor_labels()
        self._sliders = [
            ipw.FloatSlider(min=0., max=1., step=0.01, description=labels[i])
            for i in range(nb_params)
        ]
        self._factors = list(np.ones(nb_params))
        # One row per unit of the network's last latent dimension, one column
        # per sample (NB_SAMPLES is defined elsewhere in the module).
        self._latent_serie = np.zeros(
            [self._network.get_latent_size()[-1], NB_SAMPLES])

        # We generate view elements
        self._latent_fig = plt.figure()
        self._latent_plot = plt.plot(range(NB_SAMPLES), self._latent_serie)

        # Each slider gets its own callback bound to its factor index.
        for index, slider in enumerate(self._sliders):
            slider.observe(self._callback_closure(index), 'value')
        self._title = ipw.HTML('<h2>Latent Space WorkBench</h2>')
        self._caption = ipw.HTML('Manipulate the generative factors:')

        # Layout: controls on the left, the 50%-wide latent plot on the right.
        left_elmts = [self._title, self._caption] + self._sliders
        left_layout = ipw.Layout(padding='50px 0px 0px 0px')
        left_pane = ipw.VBox(left_elmts, layout=left_layout)
        self._latent_fig.layout = ipw.Layout(width='50%')
        self._layout = ipw.HBox([left_pane, self._latent_fig])
        IPython.display.display(self._layout)
Exemplo n.º 19
0
 def create_figure(self):
     """Draw ``self.tree`` using the node positions in ``self.coords``.

     Each parent->child edge is drawn as its own ``'b'`` line and stored in
     ``self.lines`` under ``(parent, child)``, in drawing order. All node
     positions are then drawn once with the ``'o'`` marker and stored in
     ``self.points``. The figure is kept in ``self.fig``.
     """
     coords = self.coords
     tree = self.tree
     figure = plt.figure(figsize=(8, 0.5))
     edges = OrderedDict()
     for parent in tree.nodes:
         start = coords[parent]
         for child in tree.children(parent):
             end = coords[child]
             edges[(parent, child)] = plt.plot([start[0], end[0]],
                                               [start[1], end[1]], 'b')
     positions = list(coords.values())
     xs = [pos[0] for pos in positions]
     ys = [pos[1] for pos in positions]
     markers = plt.plot(xs, ys, 'o')
     self.fig = figure
     self.lines = edges
     self.points = markers
Exemplo n.º 20
0
def generate_bar(df, title="", scientific_notation=False, small_xlabel=False):
    """Return a bqplot bar chart of *df* with shortened index labels.

    Index labels (which must be strings) are cut to their first two
    space-separated words. The x tick font shrinks to size 6 when
    ``small_xlabel`` is set or when there are more than five bars.

    Args:
        df: data to plot; bars are coloured from ``chart_colors``.
        title: figure title.
        scientific_notation: when False, y-axis ticks use the ``".1f"`` format.
        small_xlabel: force the small x tick font.
    """
    fig = plt.figure(title=title)
    labels = df.index.values.tolist()
    use_small_font = small_xlabel or len(labels) > 5
    # Splitting and re-joining on the same separator is lossless, so labels
    # with fewer than three words come back unchanged.
    short_labels = [" ".join(label.split(" ")[:2]) for label in labels]
    plt.bar(x=short_labels,
            y=df,
            colors=chart_colors[:df.index.values.size])
    if use_small_font:
        fig.axes[0].tick_style = {"font-size": "6"}
    if not scientific_notation:
        fig.axes[1].tick_format = ".1f"
    return fig
Exemplo n.º 21
0
 def __init__(self,
              animation_duration=100,
              aspect_ratio=1,
              tail_len=1000,
              lim=2):
     """Create an animated figure with an empty line, a hidden hline and markers.

     Args:
         animation_duration: figure animation duration passed to bqplot.
         aspect_ratio: used as both the minimum and maximum aspect ratio.
         tail_len: stored as ``self.tail_len``.
         lim: both axes are limited to ``[-lim, lim]``.
     """
     self.fig = plt.figure(animation_duration=animation_duration,
                           min_aspect_ratio=aspect_ratio,
                           max_aspect_ratio=aspect_ratio)
     axis_labels = {'x': {'label': 'x'}, 'y': {'label': 'y'}}
     self.line = plt.plot([], [], marker_str='b-', axes_options=axis_labels)
     # Horizontal line at 0, created fully transparent (opacity 0).
     self.hline = plt.hline(0, opacities=[0], colors=['red'], stroke_width=3)
     self.scat = plt.plot([], [], 'ro')
     self.tail_len = tail_len
     for set_limits in (plt.xlim, plt.ylim):
         set_limits(-lim, lim)
Exemplo n.º 22
0
    def __init__(self, x_range=(0, 0), y_range=None,
                 width=450,
                 height=300,
                 x_label=None, y_label=None,
                 x_axis_type=None, y_axis_type=None):
        """Create a matplotlib figure/axes pair configured from the arguments.

        Args:
            x_range: ``(min, max)`` passed to ``set_xlim``.
            y_range: optional ``(min, max)`` passed to ``set_ylim``.
            width: divided by 75 and truncated with ``int()`` to get the
                figure width passed as ``figsize``.
            height: same conversion as *width*, for the figure height.
            x_label: optional x-axis label.
            y_label: optional y-axis label.
            x_axis_type: optional argument for ``set_xscale``.
            y_axis_type: optional argument for ``set_yscale``.
        """
        # subplots() creates the figure itself, so no separate plt.figure()
        # call is made (it would only leave an unused figure open).
        # NOTE(review): int() truncates, so width/height below 75 produce a
        # zero figure dimension -- confirm whether fractional sizes are wanted.
        self.fig, self.ax = plt.subplots(figsize=(int(width / 75), int(height / 75)))
        self.ax.set_xlim(*x_range)
        if y_range is not None:
            self.ax.set_ylim(*y_range)
        if x_label is not None:
            self.ax.set_xlabel(x_label)
        if y_label is not None:
            self.ax.set_ylabel(y_label)
        self.ax.grid(False)
        if x_axis_type is not None:
            self.ax.set_xscale(x_axis_type)
        if y_axis_type is not None:
            self.ax.set_yscale(y_axis_type)
Exemplo n.º 23
0
def feature_byProperty(features, xProperties, seriesProperty, **kwargs):
    """Generates a Chart from a set of features. Plots property values of one or more features.
    Reference: https://developers.google.com/earth-engine/guides/charts_feature#uichartfeaturebyproperty

    Args:
        features (ee.FeatureCollection): The features to include in the chart.
        xProperties (list | dict): One of (1) a list of properties to be plotted on the x-axis; or (2) a (property, label) dictionary specifying labels for properties to be used as values on the x-axis.
        seriesProperty (str): The name of the property used to label each feature in the legend.
        **kwargs: optional chart settings: ``ylim`` (``(min, max)``; by default
            the data range plus 20% headroom), ``title`` (default ``""``),
            ``legend_location`` (default ``"top-left"``), ``display_legend``
            (default True), ``width``, ``height``, ``colors``, ``xlabel`` and
            ``ylabel``.

    Raises:
        Exception: If the provided xProperties is not a list or dict.
        Exception: If the chart fails to create. The original error is
            chained as ``__cause__``.
    """
    try:
        df = ee_to_df(features)

        if isinstance(xProperties, list):
            x_data = xProperties
            y_data = df[xProperties].values
        elif isinstance(xProperties, dict):
            x_data = list(xProperties.values())
            y_data = df[list(xProperties.keys())].values
        else:
            raise Exception("xProperties must be a list or dictionary.")

        labels = list(df[seriesProperty])

        if "ylim" in kwargs:
            min_value = kwargs["ylim"][0]
            max_value = kwargs["ylim"][1]
        else:
            min_value = y_data.min()
            max_value = y_data.max()
            # Leave 20% headroom above the tallest bar.
            max_value = max_value + 0.2 * (max_value - min_value)

        fig = plt.figure(
            title=kwargs.get("title", ""),
            legend_location=kwargs.get("legend_location", "top-left"),
        )

        if "width" in kwargs:
            fig.layout.width = kwargs["width"]
        if "height" in kwargs:
            fig.layout.height = kwargs["height"]

        bar_chart = plt.bar(x=x_data,
                            y=y_data,
                            labels=labels,
                            display_legend=kwargs.get("display_legend", True))

        bar_chart.type = "grouped"

        if "colors" in kwargs:
            bar_chart.colors = kwargs["colors"]

        if "xlabel" in kwargs:
            plt.xlabel(kwargs["xlabel"])
        if "ylabel" in kwargs:
            plt.ylabel(kwargs["ylabel"])
        plt.ylim(min_value, max_value)

        if "xlabel" in kwargs and "ylabel" in kwargs:
            bar_chart.tooltip = Tooltip(
                fields=["x", "y"], labels=[kwargs["xlabel"], kwargs["ylabel"]])
        else:
            bar_chart.tooltip = Tooltip(fields=["x", "y"])

        plt.show()

    except Exception as e:
        # Keep the documented Exception type, but chain the original error so
        # its real type and traceback are not lost.
        raise Exception(e) from e
Exemplo n.º 24
0
def feature_histogram(features,
                      property,
                      maxBuckets=None,
                      minBucketWidth=None,
                      **kwargs):
    """
    Generates a Chart from a set of features.
    Computes and plots a histogram of the given property.
    - X-axis = Histogram buckets (of property value).
    - Y-axis = Frequency

    Reference:
    https://developers.google.com/earth-engine/guides/charts_feature#uichartfeaturehistogram

    Args:
        features  (ee.FeatureCollection): The features to include in the chart.
        property                   (str): The name of the property to generate the histogram for.
        maxBuckets       (int, optional): The maximum number of buckets (bins) to use when building a histogram;
                                          will be rounded up to a power of 2.
        minBucketWidth (float, optional): The minimum histogram bucket width, or null to allow any power of 2.
        **kwargs: Optional chart settings read by this function:
            ``ylim`` ((min, max) pair overriding the data range used for binning),
            ``title``, ``width``, ``height``, ``xlabel``, ``ylabel``,
            ``colors``, ``stroke`` and ``stroke_width``.

    Raises:
        Exception: If features is not an ee.FeatureCollection.
        Exception: If property is not a property of the first feature.
        Exception: If minBucketWidth is negative.
        Exception: If the chart fails to create.
    """
    import math

    if not isinstance(features, ee.FeatureCollection):
        raise Exception("features must be an ee.FeatureCollection")

    # A negative width would make grow_bin() below double forever.
    if minBucketWidth is not None and minBucketWidth < 0:
        raise Exception("minBucketWidth must not be negative")

    first = features.first()
    props = first.propertyNames().getInfo()
    if property not in props:
        raise Exception(
            f"property {property} not found. Available properties: {', '.join(props)}"
        )

    def nextPowerOf2(n):
        """Return the smallest power of 2 that is >= n (n must be > 0)."""
        return pow(2, math.ceil(math.log2(n)))

    def grow_bin(bin_size, ref):
        """Double bin_size until it is at least ref."""
        while bin_size < ref:
            bin_size *= 2
        return bin_size

    try:

        raw_data = pd.to_numeric(
            pd.Series(features.aggregate_array(property).getInfo()))
        y_data = raw_data.tolist()

        if "ylim" in kwargs:
            min_value = kwargs["ylim"][0]
            max_value = kwargs["ylim"][1]
        else:
            min_value = raw_data.min()
            max_value = raw_data.max()

        data_range = max_value - min_value

        if data_range <= 0:
            # All values identical (or a degenerate ylim): log2(0) would raise
            # and a zero-width bin would divide by zero below, so start from a
            # unit-width bin instead.
            initial_bin_size = 1
        elif maxBuckets:
            initial_bin_size = math.ceil(data_range / nextPowerOf2(maxBuckets))
        else:
            # Default: aim for about 2**8 buckets, snapped to a power-of-2 width.
            initial_bin_size = nextPowerOf2(data_range / pow(2, 8))

        if minBucketWidth:
            # Honour the minimum width; when it is narrower than the computed
            # width, double it until it covers that width.
            if minBucketWidth < initial_bin_size:
                bin_size = grow_bin(minBucketWidth, initial_bin_size)
            else:
                bin_size = minBucketWidth
        else:
            bin_size = initial_bin_size

        # Outer bin edges: snap the range to multiples of bin_size, then pad by
        # half a bin on each side. With bin_size > 0 these always lie strictly
        # outside [min_value, max_value].
        start_bins = (math.floor(min_value / bin_size) *
                      bin_size) - (bin_size / 2)
        end_bins = (math.ceil(max_value / bin_size) * bin_size) + (bin_size /
                                                                   2)

        # The edges are appended as samples so the histogram spans exactly
        # [start_bins, end_bins] (presumably because the hist mark derives its
        # range from the sample). Note this adds one count to each outer bin.
        y_data.append(start_bins)
        y_data.append(end_bins)

        num_bins = math.floor((end_bins - start_bins) / bin_size)

        title = kwargs.get("title", "")

        fig = plt.figure(title=title)

        if "width" in kwargs:
            fig.layout.width = kwargs["width"]
        if "height" in kwargs:
            fig.layout.height = kwargs["height"]

        xlabel = kwargs.get("xlabel", "")
        ylabel = kwargs.get("ylabel", "")

        histogram = plt.hist(
            sample=y_data,
            bins=num_bins,
            axes_options={
                "count": {
                    "label": ylabel
                },
                "sample": {
                    "label": xlabel
                }
            },
        )

        if "colors" in kwargs:
            histogram.colors = kwargs["colors"]
        if "stroke" in kwargs:
            histogram.stroke = kwargs["stroke"]
        else:
            # Fully transparent outline by default.
            histogram.stroke = "#ffffff00"
        if "stroke_width" in kwargs:
            histogram.stroke_width = kwargs["stroke_width"]
        else:
            histogram.stroke_width = 0

        if ("xlabel" in kwargs) and ("ylabel" in kwargs):
            histogram.tooltip = Tooltip(
                fields=["midpoint", "count"],
                labels=[kwargs["xlabel"], kwargs["ylabel"]],
            )
        else:
            histogram.tooltip = Tooltip(fields=["midpoint", "count"])
        plt.show()

    except Exception as e:
        raise Exception(e) from e
Exemplo n.º 25
0
def feature_groups(features, xProperty, yProperty, seriesProperty, **kwargs):
    """Generates a Chart from a set of features.
    Plots the value of one property for each feature.
    Reference:
    https://developers.google.com/earth-engine/guides/charts_feature#uichartfeaturegroups
    Args:
        features (ee.FeatureCollection): The feature collection to make a chart from.
        xProperty (str): Features labeled by xProperty.
        yProperty (str): Features labeled by yProperty.
        seriesProperty (str): The property used to label each feature in the legend.
        **kwargs: Optional chart settings read by this function: ``labels``
            (legend labels, default: the distinct seriesProperty values),
            ``ylim`` ((min, max) pair), ``title``, ``legend_location``
            (default "top-left"), ``width``, ``height``, ``display_legend``
            (default True), ``colors``, ``xlabel`` and ``ylabel``.
    Raises:
        Exception: Errors when creating the chart.
    """

    try:
        df = ee_to_df(features)
        df[yProperty] = pd.to_numeric(df[yProperty])
        unique_series_values = df[seriesProperty].unique().tolist()
        new_column_names = []

        # Build one column per series value holding yProperty where the feature
        # belongs to that series and 0 elsewhere, so each series is its own bar set.
        for value in unique_series_values:
            sample_filter = (df[seriesProperty] == value).map({
                True: 1,
                False: 0
            })
            column_name = str(yProperty) + "_" + str(value)
            df[column_name] = df[yProperty] * sample_filter
            new_column_names.append(column_name)

        if "labels" in kwargs:
            labels = kwargs["labels"]
        else:
            labels = [str(x) for x in unique_series_values]

        if "ylim" in kwargs:
            min_value = kwargs["ylim"][0]
            max_value = kwargs["ylim"][1]
        else:
            min_value = df[yProperty].to_numpy().min()
            max_value = df[yProperty].to_numpy().max()
            # Leave 20% headroom above the tallest bar (room for the legend).
            max_value = max_value + 0.2 * (max_value - min_value)

        title = kwargs.get("title", "")
        legend_location = kwargs.get("legend_location", "top-left")

        x_data = list(df[xProperty])
        y_data = [df[x] for x in new_column_names]

        # Create the figure BEFORE adding any mark: a bar created first would be
        # attached to whichever figure was previously current.
        fig = plt.figure(
            title=title,
            legend_location=legend_location,
        )

        if "width" in kwargs:
            fig.layout.width = kwargs["width"]
        if "height" in kwargs:
            fig.layout.height = kwargs["height"]

        display_legend = kwargs.get("display_legend", True)

        bar_chart = plt.bar(x_data,
                            y_data,
                            labels=labels,
                            display_legend=display_legend)

        if "colors" in kwargs:
            bar_chart.colors = kwargs["colors"]

        if "xlabel" in kwargs:
            plt.xlabel(kwargs["xlabel"])
        if "ylabel" in kwargs:
            plt.ylabel(kwargs["ylabel"])
        plt.ylim(min_value, max_value)

        if "xlabel" in kwargs and ("ylabel" in kwargs):
            bar_chart.tooltip = Tooltip(
                fields=["x", "y"], labels=[kwargs["xlabel"], kwargs["ylabel"]])
        else:
            bar_chart.tooltip = Tooltip(fields=["x", "y"])

        plt.show()

    except Exception as e:
        raise Exception(e) from e
Exemplo n.º 26
0
def feature_byFeature(features, xProperty, yProperties, **kwargs):
    """Generates a Chart from a set of features. Plots the value of one or more properties for each feature.
    Reference: https://developers.google.com/earth-engine/guides/charts_feature#uichartfeaturebyfeature

    Args:
        features (ee.FeatureCollection): The feature collection to generate a chart from.
        xProperty (str): Features labeled by xProperty.
        yProperties (list): Values of yProperties.
        **kwargs: Optional chart settings read by this function: ``ylim``
            ((min, max) pair), ``title``, ``legend_location`` (default
            "top-left"), ``width``, ``height``, ``labels`` (default:
            yProperties), ``display_legend`` (default True), ``colors``,
            ``xlabel`` and ``ylabel``.

    Raises:
        Exception: Errors when creating the chart.
    """

    try:

        df = ee_to_df(features)
        if "ylim" in kwargs:
            min_value = kwargs["ylim"][0]
            max_value = kwargs["ylim"][1]
        else:
            min_value = df[yProperties].to_numpy().min()
            max_value = df[yProperties].to_numpy().max()
            # Leave 20% headroom above the tallest bar (room for the legend).
            max_value = max_value + 0.2 * (max_value - min_value)

        title = kwargs.get("title", "")
        legend_location = kwargs.get("legend_location", "top-left")

        x_data = list(df[xProperty])
        # One row per property -> one bar series per property.
        y_data = df[yProperties].values.T.tolist()

        # Create the figure BEFORE adding any mark: a bar created first would be
        # attached to whichever figure was previously current.
        fig = plt.figure(
            title=title,
            legend_location=legend_location,
        )

        if "width" in kwargs:
            fig.layout.width = kwargs["width"]
        if "height" in kwargs:
            fig.layout.height = kwargs["height"]

        labels = kwargs.get("labels", yProperties)
        display_legend = kwargs.get("display_legend", True)

        bar_chart = plt.bar(x_data,
                            y_data,
                            labels=labels,
                            display_legend=display_legend)

        bar_chart.type = "grouped"

        if "colors" in kwargs:
            bar_chart.colors = kwargs["colors"]

        if "xlabel" in kwargs:
            plt.xlabel(kwargs["xlabel"])
        if "ylabel" in kwargs:
            plt.ylabel(kwargs["ylabel"])
        plt.ylim(min_value, max_value)

        if "xlabel" in kwargs and ("ylabel" in kwargs):
            bar_chart.tooltip = Tooltip(
                fields=["x", "y"], labels=[kwargs["xlabel"], kwargs["ylabel"]])
        else:
            bar_chart.tooltip = Tooltip(fields=["x", "y"])

        plt.show()

    except Exception as e:
        raise Exception(e) from e
from __future__ import print_function
from bqplot import pyplot as plt
from bqplot import topo_load
from bqplot.interacts import panzoom
from numpy import *
import pandas as pd

# Sample count for the random walks and the shared x grid.
size = 100
n = 100
x = linspace(0.0, 10.0, n)

# Seed NumPy's global generator so the walks are reproducible, then draw three
# cumulative-sum random walks (only the first one is plotted below).
random.seed(0)
y_data = cumsum(100.0 * random.randn(size))
y_data_2 = cumsum(random.randn(size))
y_data_3 = cumsum(100. * random.randn(size))

# Plot the first walk on figure 1 with dashed horizontal grid lines.
dashed_y_grid = {'y': {'grid_lines': 'dashed'}}
plt.figure(1)
plt.plot(x, y_data, axes_options=dashed_y_grid)
plt.show()
Exemplo n.º 28
0
    def vue_do_aper_phot(self, *args, **kwargs):
        """Run simple aperture photometry on the selected data and subset.

        Reads the background value, pixel area, counts conversion factor and
        flux scaling from the plugin's string fields, appends a row of results
        to ``self.app._aper_phot_results`` (a ``QTable``), fills
        ``self.results`` for the GUI and draws a radial profile of the
        background-subtracted pixels into ``self.radial_plot``.

        Errors raised while computing are not propagated: they are reported
        via a ``SnackbarMessage`` and the result/plot flags are reset.
        """
        if self._selected_data is None or self._selected_subset is None:
            self.result_available = False
            self.results = []
            self.plot_available = False
            self.radial_plot = ''
            self.hub.broadcast(SnackbarMessage(
                "No data for aperture photometry", color='error', sender=self))
            return

        data = self._selected_data
        reg = self._selected_subset

        try:
            comp = data.get_component(data.main_components[0])
            try:
                bg = float(self.background_value)
            except ValueError:  # Clearer error message
                raise ValueError('Missing or invalid background value')
            comp_no_bg = comp.data - bg

            # TODO: Use photutils when it supports astropy regions.
            if not isinstance(reg, RectanglePixelRegion):
                aper_mask = reg.to_mask(mode='exact')
            else:
                # TODO: https://github.com/astropy/regions/issues/404 (moot if we use photutils?)
                aper_mask = reg.to_mask(mode='subpixels', subpixels=32)
            npix = np.sum(aper_mask) * u.pix
            img = aper_mask.get_values(comp_no_bg, mask=None)
            # A 'center'-mode mask (whole pixels only) is used for the summary
            # statistics and the radial-profile cutout.
            aper_mask_stat = reg.to_mask(mode='center')
            comp_no_bg_cutout = aper_mask_stat.cutout(comp_no_bg)
            img_stat = aper_mask_stat.get_values(comp_no_bg, mask=None)
            include_pixarea_fac = False
            include_counts_fac = False
            include_flux_scale = False
            if comp.units:
                img_unit = u.Unit(comp.units)
                img = img * img_unit
                img_stat = img_stat * img_unit
                bg = bg * img_unit
                comp_no_bg_cutout = comp_no_bg_cutout * img_unit
                if u.sr in img_unit.bases:  # TODO: Better way to detect surface brightness unit?
                    try:
                        pixarea = float(self.pixel_area)
                    except ValueError:  # Clearer error message
                        raise ValueError('Missing or invalid pixel area')
                    if not np.allclose(pixarea, 0):
                        include_pixarea_fac = True
                if img_unit != u.count:
                    try:
                        ctfac = float(self.counts_factor)
                    except ValueError:  # Clearer error message
                        raise ValueError('Missing or invalid counts conversion factor')
                    if not np.allclose(ctfac, 0):
                        include_counts_fac = True
                try:
                    flux_scale = float(self.flux_scaling)
                except ValueError:  # Clearer error message
                    raise ValueError('Missing or invalid flux scaling')
                if not np.allclose(flux_scale, 0):
                    include_flux_scale = True
            rawsum = np.nansum(img)
            d = {'id': 1,
                 'xcenter': reg.center.x * u.pix,
                 'ycenter': reg.center.y * u.pix}
            if data.coords is not None:
                d['sky_center'] = data.coords.pixel_to_world(reg.center.x, reg.center.y)
            else:
                d['sky_center'] = None
            d.update({'background': bg,
                      'npix': npix})
            if include_pixarea_fac:
                pixarea = pixarea * (u.arcsec * u.arcsec / u.pix)
                pixarea_sr = pixarea.to(u.sr / u.pix)
                # rawsum already adds up one surface-brightness value per pixel,
                # so turning it into a flux needs the solid angle of ONE pixel.
                # Multiplying by npix as well would inflate the sum by npix.
                d.update({'aperture_sum': rawsum * (pixarea_sr * u.pix),
                          'pixarea_tot': npix * pixarea_sr})
            else:
                d.update({'aperture_sum': rawsum,
                          'pixarea_tot': None})
            if include_counts_fac:
                ctfac = ctfac * (rawsum.unit / u.count)
                sum_ct = rawsum / ctfac
                d.update({'aperture_sum_counts': sum_ct,
                          'aperture_sum_counts_err': np.sqrt(sum_ct.value) * sum_ct.unit,
                          'counts_fac': ctfac})
            else:
                d.update({'aperture_sum_counts': None,
                          'aperture_sum_counts_err': None,
                          'counts_fac': None})
            if include_flux_scale:
                flux_scale = flux_scale * rawsum.unit
                d.update({'aperture_sum_mag': -2.5 * np.log10(rawsum / flux_scale) * u.mag,
                          'flux_scaling': flux_scale})
            else:
                d.update({'aperture_sum_mag': None,
                          'flux_scaling': None})

            # Extra stats beyond photutils.
            d.update({'mean': np.nanmean(img_stat),
                      'stddev': np.nanstd(img_stat),
                      'median': np.nanmedian(img_stat),
                      'min': np.nanmin(img_stat),
                      'max': np.nanmax(img_stat),
                      'data_label': data.label,
                      'subset_label': reg.meta.get('label', ''),
                      'timestamp': Time(datetime.utcnow())})

            # Attach to app for Python extraction.
            if (not hasattr(self.app, '_aper_phot_results') or
                    not isinstance(self.app._aper_phot_results, QTable)):
                self.app._aper_phot_results = _qtable_from_dict(d)
            else:
                try:
                    d['id'] = self.app._aper_phot_results['id'].max() + 1
                    self.app._aper_phot_results.add_row(d.values())
                except Exception:  # Discard incompatible QTable
                    d['id'] = 1
                    self.app._aper_phot_results = _qtable_from_dict(d)

            # Radial profile
            reg_bb = reg.bounding_box
            reg_ogrid = np.ogrid[reg_bb.iymin:reg_bb.iymax, reg_bb.ixmin:reg_bb.ixmax]
            radial_dx = reg_ogrid[1] - reg.center.x
            radial_dy = reg_ogrid[0] - reg.center.y
            radial_r = np.hypot(radial_dx, radial_dy).ravel()  # pix
            radial_img = comp_no_bg_cutout.ravel()
            if comp.units:
                y_data = radial_img.value
                y_label = radial_img.unit.to_string()
            else:
                y_data = radial_img
                y_label = 'Value'
            bqplt.clear()
            # NOTE: default margin in bqplot is 60 in all directions
            fig = bqplt.figure(1, title='Radial profile from Subset center',
                               fig_margin={'top': 60, 'bottom': 60, 'left': 40, 'right': 10},
                               title_style={'font-size': '12px'})  # TODO: Jenn wants title at bottom. # noqa
            bqplt.plot(radial_r, y_data, 'go', figure=fig, default_size=1)
            bqplt.xlabel(label='pix', mark=fig.marks[-1], figure=fig)
            bqplt.ylabel(label=y_label, mark=fig.marks[-1], figure=fig)

        except Exception as e:  # pragma: no cover
            self.result_available = False
            self.results = []
            self.plot_available = False
            self.radial_plot = ''
            self.hub.broadcast(SnackbarMessage(
                f"Aperture photometry failed: {repr(e)}", color='error', sender=self))

        else:
            # Parse results for GUI.
            tmp = []
            for key, x in d.items():
                # Bookkeeping fields are kept in the table but not shown in the GUI.
                if key in ('id', 'data_label', 'subset_label', 'background', 'pixarea_tot',
                           'counts_fac', 'aperture_sum_counts_err', 'flux_scaling', 'timestamp'):
                    continue
                if (isinstance(x, (int, float, u.Quantity)) and
                        key not in ('xcenter', 'ycenter', 'sky_center', 'npix',
                                    'aperture_sum_counts')):
                    x = f'{x:.4e}'
                    tmp.append({'function': key, 'result': x})
                elif key == 'sky_center' and x is not None:
                    tmp.append({'function': 'RA center', 'result': f'{x.ra.deg:.4f} deg'})
                    tmp.append({'function': 'Dec center', 'result': f'{x.dec.deg:.4f} deg'})
                elif key in ('xcenter', 'ycenter', 'npix'):
                    x = f'{x:.1f}'
                    tmp.append({'function': key, 'result': x})
                elif key == 'aperture_sum_counts' and x is not None:
                    x = f'{x:.4e} ({d["aperture_sum_counts_err"]:.4e})'
                    tmp.append({'function': key, 'result': x})
                elif not isinstance(x, str):
                    x = str(x)
                    tmp.append({'function': key, 'result': x})
            self.results = tmp
            self.result_available = True
            self.radial_plot = fig
            self.bqplot_figs_resize = [fig]
            self.plot_available = True
Exemplo n.º 29
0
import numpy as np
import bqplot.pyplot as plt

size = 100

# Draw the x and y positions first, then a third array whose values are
# mapped through the scatter's color scale.
x_vals = np.random.randn(size)
y_vals = np.random.randn(size)
color_vals = np.random.randn(size)

plt.figure(title="Scatter plot with colors")
plt.scatter(x_vals, y_vals, color=color_vals)
plt.show()
#!/usr/bin/env python
# coding: utf-8

from bqplot import pyplot as plt
import ipyvuetify as v
import ipywidgets as widgets
import numpy as np

# generate some fake data: a seeded, integer-valued random walk of n points
np.random.seed(0)
n = 2000
x = np.linspace(0.0, 10.0, n)
y = np.cumsum(np.random.randn(n)*10).astype(int)

# create a bqplot figure; plt.hist below attaches to this (current) figure
fig_hist = plt.figure(title='Histogram')
hist = plt.hist(y, bins=25)

# slider: its v_model is linked to the histogram's bin count
slider = v.Slider(thumb_label='always', class_="px-4", v_model=30)
widgets.link((slider, 'v_model'), (hist, 'bins'))

# second figure: the raw walk as a line chart (becomes the current figure)
fig_lines = plt.figure( title='Line Chart')
lines = plt.plot(x, y)

# event handling
# NOTE(review): the selector presumably attaches to the current figure
# (fig_lines) — confirm.
selector = plt.brush_int_selector()
def update_range(*ignore):
    # Only react once the brush has a (min, max) pair selected.
    if selector.selected is not None and len(selector.selected) == 2:
        xmin, xmax = selector.selected
        # Points strictly inside the brushed x interval.
        # NOTE(review): mask is not used in what is visible here; the handler
        # body appears truncated — confirm against the full example.
        mask = (x > xmin) & (x < xmax)
Exemplo n.º 31
0
def plot_profitability_distributions(scenarios, backend):
    """Plot, for each scenario, profitable vs. non-profitable applicant counts
    in the training and deployment cohorts, split by the disadvantaged flag.

    Args:
        scenarios: Mapping of scenario id -> scenario object exposing
            ``description``, ``ticks``, ``gendata(kind)`` and
            ``get_modelling_data(data)``.
        backend: ``"matplotlib"`` saves one PNG per scenario under
            ``images/``; ``"bqplot"`` displays the cohort figures side by
            side in an ``HBox``. Any other value builds bqplot figures but
            displays nothing.
    """
    for scenario_id in scenarios:
        # Make scenario data
        scenario = scenarios[scenario_id]
        desc = scenario.description

        train_data = scenario.gendata("train")
        test_data = scenario.gendata("test")
        np.random.seed(42)
        X_train, y_train = scenario.get_modelling_data(train_data)
        X_test, y_test = scenario.get_modelling_data(test_data)
        data = {
            'training': (X_train, y_train, train_data.disadv_flag),
            'deployment': (X_test, y_test, test_data.disadv_flag),
        }

        figure_name = f"Profitability_{scenario_id}_{desc}"
        print(f"for Scenario {scenario_id}: {desc}")

        if backend == "matplotlib":
            n_cohorts = len(data)
            fig, ax = mplt.subplots(1, n_cohorts, figsize=(5 * n_cohorts, 5), dpi=DPI)
        else:
            plts = []

        for i, (cohort, (X, y_true, prot)) in enumerate(data.items()):
            dis = prot.astype(int)
            # ravel() order: (y=0, dis=0), (y=0, dis=1), (y=1, dis=0), (y=1, dis=1)
            mn, fn, mp, fp = confusion_matrix(y_true, dis).ravel()
            # Row 0 = profitable, row 1 = non-profitable; columns follow dis.
            values = [[mp, fp], [mn, fn]]
            # half way between selected and not
            colors = ["#ff4045", "#48B748"]
            title = f"Representation in {cohort} cohort."
            ylabel = "Number of Applicants"

            if backend == "matplotlib":
                bars = mplt_bars(ax[i], scenario.ticks, values, colors, ylabel,
                                 title)
                mplt.legend(
                    bars,
                    ["profitable customers", "non-profitable customers"],
                    bbox_to_anchor=(1.05, 0.5)  # sit outside plot...
                )

            else:
                fig = plt.figure(min_aspect_ratio=1, max_aspect_ratio=1)
                # First index is colour, second index is X

                # Note - putting negative does cool weird stuff
                bars = plt.bar(
                    scenario.ticks,
                    values,
                    colors=colors,
                )
                siz = "4in"
                fig.layout.width = siz
                fig.layout.height = siz

                fig.title = title
                fig.axes[0].color = fig.axes[1].color = "Black"
                plt.ylabel(ylabel)
                plts.append(fig)

        if backend == "matplotlib":
            fig.savefig("images/" + figure_name + ".png",
                        bbox_inches="tight",
                        dpi=300)

        elif backend == "bqplot":
            box = widgets.HBox(plts)
            box.layout.width = "90%"
            display(box)
Exemplo n.º 32
0
def plot_feature_importances(scenarios, backend):
    """Fit a logistic regression per scenario and plot its standardized
    feature importances.

    Each coefficient is divided by its feature's standard deviation in the
    training data. Bars are blue when the importance is >= 0 ("Increases
    Score") and orange otherwise.

    Args:
        scenarios: Mapping of scenario id -> scenario object exposing
            ``description``, ``ticks``, ``gendata(kind)`` and
            ``get_modelling_data(data)``; ids must be int-convertible.
        backend: ``"matplotlib"`` saves ``images/Importance_<id>_<desc>.png``;
            ``"bqplot"`` displays the figure. Both use the same y-axis range.
    """
    # Plot all the feature importances
    for scenario_id in scenarios:
        # Make scenario data
        scenario = scenarios[scenario_id]
        desc = scenario.description
        figure_name = f"Importance_{scenario_id}_{desc}"
        train_data = scenario.gendata("train")
        test_data = scenario.gendata("test")
        np.random.seed(42)
        X_train, y_train = scenario.get_modelling_data(train_data)
        X_test, y_test = scenario.get_modelling_data(test_data)

        dis = scenario.ticks[1]
        if dis != "SE Asian proxy":  # proper noun
            dis = dis.lower()

        if "pinterest" in train_data.columns:
            ren = {
                "pinterest": scenario.proxy_name,
                'disadv_flag': dis,
            }
            # warning: train data seems to be writable in-place
            X_train = X_train.rename(columns=ren)
            X_test = X_test.rename(columns=ren)
        clf = LogisticRegression()
        clf.fit(X_train, y_train)
        columns = list(X_train.columns)
        # Standardize: scale each coefficient by its feature's spread so the
        # importances are comparable across features.
        importance = clf.coef_[0] / X_train.std(axis=0)

        cols = ['orange', 'blue']
        colors = (np.array(cols)[(importance >= 0).astype(int)]).tolist()
        title = "Model Feature Importance"

        print("Importance: {}".format(importance))

        # Symmetric y-axis limit shared by both backends; scenario 5 uses much
        # larger weights, so it needs a wider range to avoid clipped bars.
        scale = 40 if int(scenario_id) == 5 else 10

        if backend == "matplotlib":
            fig, ax = mplt.subplots(figsize=(5, 5), dpi=DPI)
            ax.bar(columns, importance, color=colors)
            ax.axhline(0, color="Black", lw=1.)

            # Add some dummies for a legend
            c = columns[0]
            mplt.bar([c], [0], color=cols[1], label="Increases Score")
            mplt.bar([c], [0], color=cols[0], label="Decreases Score")
            mplt.plot([-0.5, len(columns) - 0.5], [0, 0], 'k')

            mplt.ylim(-scale, scale)
            mplt.legend(bbox_to_anchor=(1.5, 0.5))

            ax.set_title(title)
            fig.savefig("images/" + figure_name + ".png",
                        bbox_inches="tight",
                        dpi=300)

        elif backend == "bqplot":
            fig = plt.figure(title=title,
                             min_aspect_ratio=1,
                             max_aspect_ratio=1)
            for c, h, colr in zip(columns, importance, colors):
                plt.bar([c], [h], colors=[colr])  # each bar is its own bar
            plt.ylim(-scale, scale)
            fig.axes[0].color = fig.axes[1].color = "Black"
            fig.layout.width = fig.layout.height = "5in"
            display(fig)
Exemplo n.º 33
0
    def create_widget(self, output, plot, dataset, limits):
        self.plot = plot
        self.output = output
        self.dataset = dataset
        self.limits = np.array(limits).tolist()

        def fix(v):
            # bqplot is picky about float and numpy scalars
            if hasattr(v, 'item'):
                return v.item()
            else:
                return v

        self.scale_x = bqplot.LinearScale(min=fix(limits[0][0]),
                                          max=fix(limits[0][1]),
                                          allow_padding=False)
        self.scale_y = bqplot.LinearScale(min=fix(limits[1][0]),
                                          max=fix(limits[1][1]),
                                          allow_padding=False)
        self.scale_rotation = bqplot.LinearScale(min=0, max=1)
        self.scale_size = bqplot.LinearScale(min=0, max=1)
        self.scale_opacity = bqplot.LinearScale(min=0, max=1)
        self.scales = {
            'x': self.scale_x,
            'y': self.scale_y,
            'rotation': self.scale_rotation,
            'size': self.scale_size,
            'opacity': self.scale_opacity
        }

        margin = {'bottom': 35, 'left': 60, 'right': 5, 'top': 5}
        self.figure = plt.figure(self.figure_key,
                                 fig=self.figure,
                                 scales=self.scales,
                                 fig_margin=margin)
        self.figure.layout.min_width = '600px'
        plt.figure(fig=self.figure)
        self.figure.padding_y = 0
        x = np.arange(0, 10)
        y = x**2
        self._fix_scatter = s = plt.scatter(x,
                                            y,
                                            visible=False,
                                            rotation=x,
                                            scales=self.scales)
        self._fix_scatter.visible = False
        # self.scale_rotation = self.scales['rotation']
        src = ""  # vaex.image.rgba_to_url(self._create_rgb_grid())
        # self.scale_x.min, self.scale_x.max = self.limits[0]
        # self.scale_y.min, self.scale_y.max = self.limits[1]
        self.core_image = widgets.Image(format='png')
        self.core_image_fix = widgets.Image(format='png')

        self.image = bqplot.Image(scales=self.scales, image=self.core_image)
        self.figure.marks = self.figure.marks + [self.image]
        # self.figure.animation_duration = 500
        self.figure.layout.width = '100%'
        self.figure.layout.max_width = '500px'
        self.scatter = s = plt.scatter(x,
                                       y,
                                       visible=False,
                                       rotation=x,
                                       scales=self.scales,
                                       size=x,
                                       marker="arrow")
        self.panzoom = bqplot.PanZoom(scales={
            'x': [self.scale_x],
            'y': [self.scale_y]
        })
        self.figure.interaction = self.panzoom
        for axes in self.figure.axes:
            axes.grid_lines = 'none'
            axes.color = axes.grid_color = axes.label_color = blackish
        self.figure.axes[0].label = str(plot.x)
        self.figure.axes[1].label = str(plot.y)

        self.scale_x.observe(self._update_limits, "min")
        self.scale_x.observe(self._update_limits, "max")
        self.scale_y.observe(self._update_limits, "min")
        self.scale_y.observe(self._update_limits, "max")
        self.observe(self._update_scales, "limits")

        self.image.observe(self._on_view_count_change, 'view_count')
        self.control_widget = widgets.VBox()
        self.widget = widgets.VBox(children=[self.figure])
        self.create_tools()