コード例 #1
0
 def on_button_clicked(b):
     """Button callback: add points to the experiment and refresh the outputs.

     ``b`` is the clicked button and is not used. ``monte``,
     ``dropdown_newpoints``, ``plot_output``, ``html_output`` and
     ``html_message_template`` are not defined here; they must be provided by
     the enclosing (or module) scope.
     """
     how_many = dropdown_newpoints.value
     monte.addNewPoint(how_many)

     # Redraw the scatter plot in place, then close the figure so matplotlib
     # does not keep it alive after it has been displayed.
     with plot_output:
         clear_output(wait=True)
         scatter_fig = plotTheMonteCarloExperiment(monte.points)
         display(scatter_fig)
         close_fig(scatter_fig)

     # Refresh the HTML readout with the current estimate.
     estimate_text = str(monte)
     html_output.value = html_message_template.format(pi_value=estimate_text)
コード例 #2
0
def plot_collections(shifted_data: tuple, thresh: ThreshType, rate: int,
                     location: str):
    """Save a grid of per-channel plots of ``normal`` against ``delayed``.

    Each column of ``delayed`` gets its own subplot (at most 8 per row)
    showing ``normal[col]`` plotted against ``delayed[col]``. This is
    presumably a time-delay embedding of the filtered signal; confirm with the
    code that builds ``shifted_data``.

    Args:
        shifted_data: ``(normal, delayed)`` pair, both indexed by the column
            labels of ``delayed`` (assumed to be DataFrames with the same
            columns; TODO confirm).
        thresh: ``(low, high)`` band in Hz. Used in the title, in the file
            name and in the delay reported in the title.
        rate: combined with the band centre to compute the delay shown in the
            title (presumably the sampling rate in Hz).
        location: base directory. The file is written inside
            ``results_directory['general']`` under it.

    Returns:
        True once the figure has been saved and closed.

    Raises:
        ValueError: if ``delayed`` has no columns to plot.
    """
    normal, delayed = shifted_data
    low, high = thresh
    fig_format = 'pdf'  # NOTE: hard-coded; figure_settings['format'] is ignored here

    # Hairline traces for vector formats keep dense plots readable.
    linewidth = 0.01 if fig_format in ('svg', 'pdf') else 0.5

    name = f'time delay {low}-{high}Hz'
    f_name = name.replace(' ', '_')
    file_name = os_path.join(location, results_directory['general'],
                             f'{f_name}.{fig_format}')

    # Displayed in the title only: rate / (4 * band centre).
    delay_size = int32(round(rate / (((low + high) / 2) * 4)))
    sns.set(style='white')

    channels = delayed.columns
    n_channels = channels.size
    if n_channels == 0:
        raise ValueError('plot_collections: no channels to plot')

    max_cols = 8
    ncols = min(max_cols, n_channels)
    # Ceiling division: just enough rows to fit every channel.
    nrows = -(-n_channels // ncols)

    # squeeze=False always yields a 2-D array of Axes, so .ravel() also works
    # when there is a single subplot.
    fig, grid = subplots(nrows=nrows,
                         ncols=ncols,
                         figsize=[16, 20],
                         squeeze=False,
                         subplot_kw={
                             'aspect': 'equal',
                             'adjustable': 'box'
                         })
    fig.suptitle(f'{name.title()}\nDelay = {delay_size}', fontsize=12)
    axes = grid.ravel()

    for ax, col in zip(axes, channels):
        ax.plot(normal[col], delayed[col], lw=linewidth)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_title(f'Channel {col}')

    # Hide the grid cells left over when the channel count is not a multiple
    # of the column count.
    for ax in axes[n_channels:]:
        ax.set_axis_off()

    subplots_adjust(top=0.95, bottom=.01, left=.01, right=.99)

    fig.savefig(file_name,
                format=fig_format,
                bbox_inches='tight',
                dpi=figure_settings['resolution'])

    close_fig(fig)

    return True
コード例 #3
0
def plot_heatmap(matrix: ArrayOrDataFrame, location: str, thresh: ThreshType,
                 kind: str):
    """Render ``matrix`` as a seaborn heatmap and save it under ``location``.

    The figure is written to ``<location>/<kind>/`` as
    ``<kind> heatmap <low>-<high>Hz.<format>``. Spaces are replaced by
    underscores in both the directory and the file name. Format and DPI come
    from ``figure_settings``.

    Args:
        matrix: 2-D data to plot. Cells are drawn square when it has as many
            rows as columns.
        location: base output directory.
        thresh: ``(low, high)`` band in Hz, used in the title and file name.
        kind: label used in the title, the file name and the sub-directory.

    Returns:
        True once the figure has been saved and closed.
    """
    low, high = thresh
    out_format = figure_settings['format']

    title = f'{kind} heatmap {low}-{high}Hz'
    out_path = os_path.join(location,
                            kind.replace(' ', '_'),
                            f"{title.replace(' ', '_')}.{out_format}")

    sns.set(style='ticks')
    fig, ax = subplots()

    is_square = matrix.shape[0] == matrix.shape[1]
    heatmap_options = dict(square=is_square,
                           cmap=colormap,
                           xticklabels=10,
                           yticklabels=10,
                           robust=True,
                           cbar_kws={'shrink': .5})
    sns.heatmap(matrix, ax=ax, **heatmap_options)
    ax.set_title(title.title().replace('_', ' '), fontsize=12)

    subplots_adjust(top=0.98, bottom=.02, left=.06, right=.99)

    fig.savefig(out_path,
                format=out_format,
                bbox_inches='tight',
                dpi=figure_settings['resolution'])
    close_fig(fig)
    return True
コード例 #4
0
def returnPebblessWidget():
    """Build the interactive Monte Carlo estimate-of-pi widget.

    The result is laid out left to right:

    * the scatter-plot ``ipw.Output``;
    * a column containing the right-aligned controls (a dropdown choosing how
      many points one click adds, and an "Add points!" button) above an
      ``ipw.HTML`` panel showing the current estimate.

    Each click adds the selected number of points to a ``MontecarloApprox``
    instance, redraws the plot and refreshes the estimate.

    Returns:
        The assembled ``ipw.HBox``.
    """
    experiment = MontecarloApprox()

    # Outputs (plot and pi readout) plus the button that drives updates.
    plot_output = ipw.Output()
    html_output = ipw.HTML()
    update_button = ipw.Button(description="Add points!")

    # Number of points a single click adds.
    dropdown_newpoints = ipw.Dropdown(
        options=[1, 10, 100],
        value=1,
        description='New Points:',
        disabled=False,
    )

    html_message_template ="""<div class="jumbotron">
    <h3>
    π ≈ {pi_value}
    </h3>
    </div>
    """

    def show_plot(points):
        # Render and display the figure, then close it so matplotlib frees it.
        fig = plotTheMonteCarloExperiment(points)
        display(fig)
        close_fig(fig)

    def show_estimate():
        html_output.value = html_message_template.format(pi_value=str(experiment))

    # Initial state: an empty plot and the estimate before any points exist.
    with plot_output:
        show_plot([])
    show_estimate()

    def on_button_clicked(b):
        experiment.addNewPoint(dropdown_newpoints.value)
        with plot_output:
            clear_output(wait=True)
            show_plot(experiment.points)
        show_estimate()

    update_button.on_click(on_button_clicked)

    controls = ipw.VBox(
        children=[dropdown_newpoints, update_button],
        layout=ipw.Layout(align_items='flex-end'),
    )
    side_panel = ipw.VBox(children=[controls, html_output])

    return ipw.HBox(
        children=[plot_output, side_panel],
        layout=ipw.Layout(justify_content='center'),
    )
コード例 #5
0
def plot_windowed_comparison(dt: DataFrame, thresh: ThreshType, rate: int,
                             location: str, columns, corr_lag):
    """Plot and save a summary figure of moving-window results.

    The figure has three panels:

    * a heatmap of the data transposed, so each column of ``dt`` becomes a
      heatmap row;
    * a narrow heatmap of each column's sum ("sum of medians per channel");
    * a time plot of the per-row sum and the per-row sum of absolute values.

    Args:
        dt: data to plot, wrapped as ``DataFrame(dt, columns=columns)``.
        thresh: ``(low, high)`` band in Hz, used in the figure title and the
            file name.
        rate: converts row counts to seconds on the time axis and sets the
            heatmap x-tick spacing (``rate * 2``). Presumably samples per
            second; confirm.
        location: base directory. The file is written inside
            ``results_directory['windowed']`` under it.
        columns: column labels applied to ``dt``.
        corr_lag: shown in the heatmap title only.

    Returns:
        True once the figure has been saved and closed.
    """
    low, high = thresh
    primary_title = f'Moving Windows - Filter threshold {low}-{high}Hz'
    moving_mean = DataFrame(dt, columns=columns)
    # Number of rows, i.e. the length of the per-row sums plotted in panel 3.
    n_rows = moving_mean.shape[0]

    # Setting up the figure -------------------------------------------------------------

    sns.set(style="ticks")
    fig = figure(figsize=[16, 16])
    fig.suptitle(primary_title, fontsize=14)

    ax1 = subplot2grid((2, 5), (0, 0), colspan=4, rowspan=1)

    # Heatmap ---------------------------------------------------------------------------

    sns.heatmap(moving_mean.T,
                cmap=colormap,
                ax=ax1,
                xticklabels=rate * 2,
                yticklabels=10,
                robust=True,
                cbar_kws={
                    'orientation': 'horizontal',
                    'shrink': 0.20,
                    'pad': 0.07
                })
    sns.despine(ax=ax1, right=True, top=True)
    # NOTE(review): the reported window length comes from the module constant
    # SAMPLING_FREQ, not from ``rate``. Confirm this is intended.
    ax1.set_title(
        f'Auto-correlation (lag={corr_lag}) of running window (length={SAMPLING_FREQ})',
        fontsize=12)
    ax1.set_yticklabels(ax1.get_yticklabels(), rotation=60)

    # Horizontal sum of the heatmap -----------------------------------------------------

    ax2 = subplot2grid((2, 5), (0, 4), colspan=1, rowspan=1)
    # The column sums are duplicated into two identical columns so they render
    # as a strip beside the main heatmap.
    sns.heatmap(asarray([moving_mean.sum(axis=0)] * 2).T,
                cmap=colormap,
                ax=ax2,
                yticklabels=10,
                xticklabels=False,
                robust=True,
                cbar_kws={
                    'orientation': 'horizontal',
                    'pad': 0.07
                })
    sns.despine(ax=ax2, right=True, top=True)
    ax2.set_title('Sum of medians\nper channel', fontsize=12)
    ax2.set_yticklabels([])

    # Vertical sum of the heatmap -------------------------------------------------------

    ax3 = subplot2grid((4, 1), (2, 0), colspan=1, rowspan=1)
    # The x axis must match the length of the per-row sums (shape[0]).
    # max(shape) only worked when there were at least as many rows as columns.
    time_delta = linspace(0, n_rows / rate, n_rows)
    ax3.plot(time_delta,
             moving_mean.sum(axis=1).tolist(),
             '--',
             lw=.75,
             label='Normal',
             alpha=.5)
    ax3_se = ax3.twinx()
    ax3_se.plot(time_delta,
                abs(moving_mean).sum(axis=1).tolist(),
                lw=1,
                label='Modulus')
    ax3.legend(loc='upper left')
    ax3_se.legend(loc='upper right')

    ax3.set_title('Sum of medians per window (second)', fontsize=12)
    ax3.set_xlim(time_delta.min(), time_delta.max())

    # Saving the figure -----------------------------------------------------------------

    path = os_path.join(location, results_directory['windowed'],
                        primary_title + '.' + figure_settings['format'])

    fig.savefig(path,
                format=figure_settings['format'],
                bbox_inches='tight',
                dpi=figure_settings['resolution'])

    close_fig(fig)

    return True
コード例 #6
0
def plotter(data: DataFrame,
            location: str,
            freq: int,
            thresh: ThreshType,
            save=True):
    """Plot every column of ``data`` as a vertically stacked trace.

    Traces are built per column by ``_stack_channel``, which presumably
    returns one (time, value) segment per column; confirm. They are drawn
    bottom-up in reverse column order, so the first column appears at the top.
    Vertical spacing is twice the gap between the mean 2nd and mean 98th
    percentiles of the data.

    Args:
        data: samples (rows) by channels (columns).
        location: base directory. When saving, the file is written inside
            ``results_directory['general']`` under it.
        freq: converts the row count to seconds for the time axis (presumably
            the sampling rate in Hz).
        thresh: ``(low, high)`` band in Hz, used in the title and file name.
        save: when true, save the figure and close it. When false, keep it
            open, use a thicker line and return it.

    Returns:
        True when ``save`` is true; otherwise the ``(fig, ax)`` pair.
    """
    extension = 'pdf'  # NOTE: hard-coded; figure_settings['format'] is ignored here

    amplify = .5
    # Hairline traces for vector output; a readable width for on-screen use.
    if not save:
        linewidth = .6
    elif extension in ('svg', 'pdf'):
        linewidth = 0.01
    else:
        linewidth = 0.5

    low, high = thresh
    title = f'Filtered Data {low}-{high}Hz'
    index_size = data.index.size
    n_channels = data.columns.size
    index = linspace(0, index_size / freq, index_size, dtype=float32)

    sns.set(style='ticks')

    # Two width units per second of signal, capped at 60.
    fig_width = min(60, index.max() * 2)
    fig_height = n_channels / 4
    # figsize must be passed as a keyword. The previous `subplots(fig_kw=...)`
    # forwarded a literal `fig_kw` argument to Figure, so the size was never applied.
    fig, ax = subplots(figsize=(fig_width, fig_height))

    data_min, data_max = data.quantile(0.02).mean(), data.quantile(0.98).mean()

    dr = (data_max - data_min) * (1 / amplify)  # vertical spacing between traces

    func = partial(_stack_channel,
                   index=index,
                   index_size=index_size,
                   data=data)

    # Reversed so the first column ends up at the top of the plot.
    segs = [func(col) for col in data.columns[::-1]]

    offsets = zeros((n_channels, 2), dtype=float32)
    offsets[:, 1] = arange(-1, n_channels - 1) * dr

    lines = LineCollection(segs,
                           offsets=offsets,
                           linewidth=linewidth,
                           cmap=_edged_colormap())
    lines.set_array(arange(0, n_channels))
    ax.add_collection(lines)

    # Display settings:
    # ----------------------------------------------------------
    ax.set_title(title, fontsize=14)

    ax.set_xlim([index.min(), index.max()])
    ax.set_ylim([data_min - dr * 2, (n_channels - 1) * dr + data_max])

    ax.set_yticks(offsets[:, 1])

    # Reversed to match the bottom-up order of the traces.
    ax.set_yticklabels(data.columns[::-1], fontsize=12)

    ax.set_xlabel('Time [sec]')
    ax.set_ylabel(r'Channels [$\mu V$]')

    ax.set_xticks(arange(0, index.max() + 1))
    # `visible` is passed positionally: it was named `b` before matplotlib 3.5,
    # and the `b` name is no longer accepted in recent releases.
    ax.xaxis.grid(True,
                  which='major',
                  color='lightgray',
                  linestyle=':',
                  linewidth=1,
                  alpha=.3,
                  antialiased=True)

    subplots_adjust(top=0.98, bottom=.02, left=.08, right=.99)

    sns.despine(right=True, top=True, bottom=True, left=True)

    if not save:
        return fig, ax

    file_path = os_path.join(location, results_directory['general'],
                             title.replace(' ', '_').lower() + '.' + extension)
    fig.savefig(file_path,
                format=extension,
                bbox_inches='tight',
                dpi=figure_settings['resolution'])

    close_fig(fig)
    return True