Exemple #1
0
def do_surveys():
    """Draw the NASA-TLX and survey figures.

    Four figures are drawn, each inside the project's ``figure`` context
    manager (what ``figure`` does on exit, e.g. saving, is not visible from
    here -- TODO confirm):

    * ``tlx_results``       -- box plot plus swarm of the ``tlx`` column.
    * ``tlx_components``    -- bar plot of the six TLX component scores.
    * ``survey_results``    -- box plot plus swarm of the ``total`` column.
    * ``survey_components`` -- horizontal bars plus strip plot per question.

    Reads the module-level ``tlx`` and ``surveys`` data frames and the
    ``cmap_complement`` palette; returns nothing.
    """
    id_columns = ["user", "experiment", "order"]

    def _score_overview(name, frame, column, label):
        # Box plot per experiment with the raw observations swarmed on top.
        with figure(name, figsize=fig_size(0.44, 1)):
            sns.factorplot(x="experiment", y=column, data=frame, kind="box")
            sns.swarmplot(x="experiment", y=column, data=frame,
                          palette=cmap_complement, split=True)
            # Anchor the y-axis at zero, keep the automatic upper limit.
            plt.ylim(0, plt.ylim()[1])
            plt.ylabel(label)

    _score_overview("tlx_results", tlx, "tlx", "NASA-TLX weighted score")

    with figure("tlx_components", figsize=fig_size(0.44, 1)):
        components = ["mental", "physical", "temporal", "performance",
                      "effort", "frustration"]
        # Long format: one row per (user, experiment, order, component).
        molten = pd.melt(tlx, id_vars=id_columns,
                         value_vars=components,
                         var_name="component", value_name="score")
        sns.barplot(x="component", y="score", hue="experiment", data=molten)

        # Abbreviated tick labels, in the same order as `components`.
        plt.gca().set_xticklabels(["MD", "PD", "TD", "P", "E", "F"])

        plt.xlabel("NASA-TLX component")
        plt.ylabel("score")

    # Previously this call passed the size positionally, unlike every other
    # `figure(...)` call here; it now uses the same `figsize=` keyword.
    _score_overview("survey_results", surveys, "total", "survey score")

    with figure("survey_components", figsize=fig_size(0.9, 0.5)):
        questions = ["orientation_understanding",
                     "orientation_control",
                     "position_understanding",
                     "position_control",
                     "spacial_understanding",
                     "spacial_control"]
        molten = pd.melt(surveys, id_vars=id_columns,
                         value_vars=questions,
                         var_name="question", value_name="rating")
        g = sns.barplot(x="rating", y="question", hue="experiment",
                        data=molten)
        sns.stripplot(x="rating", y="question", data=molten, hue="experiment",
                      split=True, palette=cmap_complement, jitter=0.6, size=3)

        # Human-readable labels, in the same order as `questions`.
        plt.gca().set_yticklabels(
                ["angle aware", "angle control",
                 "position aware", "position control",
                 "rel. pos. aware", "rel. pos. control"])

        # Both plots register legend entries for `experiment`; skip the first
        # two handles so each experiment is listed only once.
        handles, labels = g.get_legend_handles_labels()
        plt.legend(handles[2:], labels[2:])
        plt.xlabel("rating")
        plt.title("Survey results")
Exemple #2
0
    def stripplot(self, x=None, y=None, hue=None, data=None, *args, **kwargs):
        """
        Draw one seaborn strip plot per requested column, stacked vertically.

        A strip plot shows every observation in each categorical bin. It is
        a good complement to a box or violin plot when you want to show all
        observations along with some representation of the underlying
        distribution.

        Parameters
        ----------
        x : the name of a variable in data that provides labels for categories

        y : a column name, or a list/tuple/array/Index of column names in
            data; one subplot is drawn for each of them

        hue : the name of a variable in data that provides labels for
            sub-categories in each big category

        data : pandas dataframe

        *args, **kwargs : forwarded unchanged to seaborn.stripplot, e.g.

            order, hue_order : lists of strings, optional

            jitter : float, True/1 is special-cased, optional.
                     Amount of jitter (only along the categorical axis)
                     to apply

            split : bool, optional

            orient : "v" | "h", optional

            color : matplotlib color, optional

            palette : palette name, list, or dict, optional

            size : float, optional

            edgecolor : matplotlib color, "gray" is special-cased, optional

            linewidth : float, optional

        Returns
        -------
        axes : the matplotlib Axes that were drawn on, one per entry of y
            (a one-element list when a single column was requested)

        Raises
        ------
        ValueError
            If data is not a pandas dataframe, or x, hue or any entry of y
            is not a column of data.

        References
        ----------
        Seaborn stripplot further documentation
        https://seaborn.pydata.org/generated/seaborn.stripplot.html
        """
        # check data
        if not isinstance(data, pd.DataFrame):
            raise ValueError('data must be pandas dataframe')

        # check x and hue
        for name in (x, hue):
            if name is not None and name not in data.columns.values:
                raise ValueError('{} is NOT in data'.format(name))

        # handle single string
        if not isinstance(y, (list, tuple, np.ndarray, pd.Index)):
            y = [y]

        # Validate every requested column up front, so that a bad name does
        # not close the current figure or leave a half-drawn one behind.
        for col in y:
            if col not in data.columns.values:
                raise ValueError('{} is NOT in data'.format(col))

        # create fig and axes
        nrows = len(y)
        plt.close()
        fig, axes = plt.subplots(nrows=nrows,
                                 ncols=1,
                                 sharex=self.sharex,
                                 figsize=(self.size[0], nrows * self.size[1]))
        # HACK: handle Axes indexing when only one ax in fig
        if nrows == 1:
            axes = [axes]
        # iterate thru y
        for i, col in enumerate(y):
            a = data[col]
            # `np.bool` was removed in NumPy 1.24; the builtin is equivalent.
            not_nan = np.ones(a.shape[0], dtype=bool)
            if np.logical_not(np.isfinite(a)).any():
                logger.warning('RUNTIME WARNING: {} column has inf or nan '
                               ''.format(col))
                a = a.replace([-np.inf, np.inf], np.nan)
                # keep only the rows whose value is finite
                not_nan = np.logical_not(a.isnull())
            # plot
            sns.stripplot(x=x,
                          y=col,
                          hue=hue,
                          data=data[not_nan],
                          ax=axes[i],
                          *args,
                          **kwargs)
            if x is not None:
                axes[i].set_title(
                    label='Stripplot Plot of {} With Respect To {} '
                    ''.format(col, x),
                    fontsize=self.title_fontsize)
                axes[i].set_xlabel(xlabel=x, fontsize=self.label_fontsize)
                axes[i].set_ylabel(ylabel=col, fontsize=self.label_fontsize)
            else:  # x is None
                axes[i].set_title(label='Stripplot Plot of {}'.format(col),
                                  fontsize=self.title_fontsize)
                axes[i].set_xlabel(xlabel=col, fontsize=self.label_fontsize)
                axes[i].set_ylabel(ylabel='value',
                                   fontsize=self.label_fontsize)
            # 'major' is the value matplotlib accepts; the former 'maj' was
            # not one of 'major' / 'minor' / 'both'.
            axes[i].tick_params(axis='both',
                                which='major',
                                labelsize=self.tick_fontsize)
            axes[i].legend(loc='lower right')
        # Lay the figure out once, after every subplot has been drawn.
        fig.subplots_adjust(wspace=0.5,
                            hspace=0.3,
                            left=0.125,
                            right=0.9,
                            top=0.9,
                            bottom=0.1)
        fig.tight_layout()
        plt.show()
        return axes
Exemple #3
0
def _ticks_spanning(candidate_ticks, lower, upper, meandiff):
    """Return the sorted candidate ticks inside [lower, upper], widened by one
    neighbouring tick when `meandiff` lies beyond the ticks that were kept."""
    kept = [tick for tick in candidate_ticks if lower <= tick <= upper]
    if np.max(kept) < meandiff:
        # Locate the largest kept tick among the candidates, take the next one.
        ind = np.where(candidate_ticks == np.max(kept))[0][0]
        kept.append(candidate_ticks[ind + 1])
    elif meandiff < np.min(kept):
        # Locate the smallest kept tick among the candidates, take the previous.
        ind = np.where(candidate_ticks == np.min(kept))[0][0]
        kept.append(candidate_ticks[ind - 1])
    kept = np.array(kept)
    kept.sort()
    return kept


def pairedcontrast(data,
                   x,
                   y,
                   idcol,
                   reps=3000,
                   statfunction=None,
                   idx=None,
                   figsize=None,
                   beforeAfterSpacer=0.01,
                   violinWidth=0.005,
                   floatOffset=0.05,
                   showRawData=False,
                   showAllYAxes=False,
                   floatContrast=True,
                   smoothboot=False,
                   floatViolinOffset=None,
                   showConnections=True,
                   summaryBar=False,
                   contrastYlim=None,
                   swarmYlim=None,
                   barWidth=0.005,
                   rawMarkerSize=8,
                   rawMarkerType='o',
                   summaryMarkerSize=10,
                   summaryMarkerType='o',
                   summaryBarColor='grey',
                   meansSummaryLineStyle='solid',
                   contrastZeroLineStyle='solid',
                   contrastEffectSizeLineStyle='solid',
                   contrastZeroLineColor='black',
                   contrastEffectSizeLineColor='black',
                   pal=None,
                   legendLoc=2,
                   legendFontSize=12,
                   legendMarkerScale=1,
                   axis_title_size=None,
                   yticksize=None,
                   xticksize=None,
                   tickAngle=45,
                   tickAlignment='right',
                   **kwargs):
    """Plot paired observations and the bootstrapped paired difference.

    For every (before, after) pair of levels in `idx`, the raw values of `y`
    are drawn per level of `x`, paired observations (matched through `idcol`)
    are optionally joined by lines, and the summary of the per-subject
    difference (after - before) is drawn with its bootstrap interval and a
    half-violin of the bootstrap distribution.

    Parameters
    ----------
    data : pandas DataFrame; rows containing NaN are dropped first.
    x, y : column names of the grouping variable and of the measured value.
    idcol : column name identifying the paired subject.
    reps, smoothboot : forwarded to `bootstrap`.
    statfunction : summary statistic; defaults to `np.mean`.
    idx : either two level names, or a tuple of two-element tuples/lists of
        level names (one sub-plot per pair). Defaults to the first two
        sorted levels of `x`.
    figsize : figure size; defaults to (6, 6), or (12, 12 / sqrt(2)) when
        more than two pairs are plotted.
    beforeAfterSpacer : horizontal gap between the two groups (forced to 1
        when `showRawData` is True).
    violinWidth : width of the bootstrap half-violin (replaced by the swarm
        span when `showRawData` is True).
    showRawData : draw a visible swarm plot; otherwise a strip plot is drawn
        only to obtain point positions and is then hidden.
    showAllYAxes : when `floatContrast` is False, keep the contrast y-axis of
        every sub-plot instead of only the first.
    floatContrast : draw the contrast on a twin axis beside the raw data
        (True) or in a separate axes below it (False).
    floatViolinOffset : gap between the after-group and the floating violin;
        defaults to `beforeAfterSpacer / 2`.
    showConnections : join each subject's before and after points.
    summaryBar, summaryBarColor, barWidth : optional bar of the group means.
    contrastYlim, swarmYlim : optional y-limits of the contrast / raw axes.
    rawMarkerSize, rawMarkerType, summaryMarkerSize : marker styling.
    contrastZeroLineStyle, contrastZeroLineColor : zero line styling used
        when `floatContrast` is False.
    pal : palette mapping; by default built over all levels of `x` (or of
        the `hue` column when `hue` is passed in `kwargs`).
    legendLoc, legendFontSize, legendMarkerScale : legend styling, used when
        `hue` is passed.
    axis_title_size, yticksize, xticksize : font sizes (15, 12, 12).
    tickAngle, tickAlignment : x tick label rotation and alignment.
    floatOffset, summaryMarkerType, meansSummaryLineStyle,
    contrastEffectSizeLineStyle, contrastEffectSizeLineColor : accepted for
        interface compatibility; not read by this implementation.
    **kwargs : forwarded to the seaborn swarm/strip plot call.

    Returns
    -------
    fig : the matplotlib figure.
    contrastList : pandas DataFrame with one column per pair, named
        '<after> v.s. <before>', holding the `bootstrap` result entries plus
        'ttest_pval' (one-sample t-test of the differences against 0).

    Notes
    -----
    Invalid `idx` terminates through `sys.exit` with an explanatory message
    (SystemExit). Global matplotlib rc settings are modified while plotting
    and reset with `rcdefaults()` and `sns.set()` before returning.
    """
    # Preliminaries.
    data = data.dropna()

    # Normalise `idx` to a tuple of (before, after) pairs *before* touching
    # any global matplotlib state, so that bad input leaves rc settings alone.
    if idx is None:
        # Default: the first two levels in sorted order, as a single pair.
        idx = (tuple(np.unique(data[x])[0:2]), )
    elif all(isinstance(element, str) for element in idx):
        # A single pair given as a flat list or tuple of level names.
        if len(idx) != 2:
            sys.exit('{} does not have length 2.'.format(idx))
        idx = (tuple(idx), )
    elif all(isinstance(element, (tuple, list)) for element in idx):
        # Several pairs: every one must name exactly two levels.
        for element in idx:
            if len(element) != 2:
                sys.exit('{} does not have length 2.'.format(element))

    if floatViolinOffset is None:
        floatViolinOffset = beforeAfterSpacer / 2
    if contrastYlim is not None:
        contrastYlim = np.array([contrastYlim[0], contrastYlim[1]])
    if swarmYlim is not None:
        swarmYlim = np.array([swarmYlim[0], swarmYlim[1]])

    # plot params
    if axis_title_size is None:
        axis_title_size = 15
    if yticksize is None:
        yticksize = 12
    if xticksize is None:
        xticksize = 12

    rc('axes', labelsize=axis_title_size)
    rc('xtick', labelsize=xticksize)
    rc('ytick', labelsize=yticksize)

    ## Here we define the palette on all the levels of the 'x' column.
    ## Thus, if the same pandas dataframe is re-used across different plots,
    ## the color identity of each group will be maintained.
    ## Set palette based on total number of categories in data['x'] or data['hue_column']
    if 'hue' in kwargs:
        u = kwargs['hue']
    else:
        u = x
    if ('color' not in kwargs and 'hue' not in kwargs):
        kwargs['color'] = 'k'

    if pal is None:
        levels = data[u].unique()
        pal = dict(zip(levels, sns.color_palette(n_colors=len(levels))))

    # Initialise figure.
    if figsize is None:
        if len(idx) > 2:
            figsize = (12, (12 / np.sqrt(2)))
        else:
            figsize = (6, 6)
    fig = plt.figure(figsize=figsize)

    # Initialise GridSpec: 1 row; one column per pair in `idx`.
    gsMain = gridspec.GridSpec(1, np.shape(idx)[0])
    # Set default statfunction
    if statfunction is None:
        statfunction = np.mean
    # Collect all the contrast results generated, and their names.
    contrastList = list()
    contrastListNames = list()

    ## Pivot to one row per subject and one column per level of `x`.
    ## This does not depend on the pair being plotted, so compute it once.
    data_pivot = data.pivot_table(index=idcol, columns=x, values=y)

    for gsIdx, xlevs in enumerate(idx):
        ### create convenient signifiers for column names.
        befX = str(xlevs[0] + '_x')
        aftX = str(xlevs[1] + '_x')

        # Start plotting!!
        if floatContrast is True:
            ax_raw = fig.add_subplot(gsMain[gsIdx], frame_on=False)
            ax_contrast = ax_raw.twinx()
        else:
            gsSubGridSpec = gridspec.GridSpecFromSubplotSpec(
                2, 1, subplot_spec=gsMain[gsIdx])
            ax_raw = plt.Subplot(fig, gsSubGridSpec[0, 0], frame_on=False)
            ax_contrast = plt.Subplot(fig,
                                      gsSubGridSpec[1, 0],
                                      sharex=ax_raw,
                                      frame_on=False)

        ## Plot raw data as swarmplot or stripplot.
        if showRawData is True:
            swarm_raw = sns.swarmplot(data=data,
                                      x=x,
                                      y=y,
                                      order=xlevs,
                                      ax=ax_raw,
                                      palette=pal,
                                      size=rawMarkerSize,
                                      marker=rawMarkerType,
                                      **kwargs)
        else:
            swarm_raw = sns.stripplot(data=data,
                                      x=x,
                                      y=y,
                                      order=xlevs,
                                      ax=ax_raw,
                                      palette=pal,
                                      **kwargs)
        swarm_raw.set_ylim(swarmYlim)

        before_points = swarm_raw.collections[0]
        after_points = swarm_raw.collections[1]

        ## Get some details about the raw data.
        maxXBefore = max(before_points.get_offsets().T[0])
        minXAfter = min(after_points.get_offsets().T[0])
        if showRawData is True:
            beforeAfterSpacer = 1
        xposAfter = maxXBefore + beforeAfterSpacer
        xAfterShift = minXAfter - xposAfter

        ## shift the after swarmpoints closer for aesthetic purposes.
        offsetSwarmX(after_points, -xAfterShift)

        ## pandas DataFrame of 'before' group
        before_xy = before_points.get_offsets().T
        before_rgb = before_points.get_facecolors().T
        x1 = pd.DataFrame({
            befX: pd.Series(before_xy[0]),
            xlevs[0]: pd.Series(before_xy[1]),
            '_R_': pd.Series(before_rgb[0]),
            '_G_': pd.Series(before_rgb[1]),
            '_B_': pd.Series(before_rgb[2]),
        })
        ## join the RGB columns into a tuple, then assign to a column.
        x1['_hue_'] = x1[['_R_', '_G_', '_B_']].apply(tuple, axis=1)
        # Points carry no subject id: sort both the points and the pivot by
        # value so the subject ids can be assigned by rank.
        x1 = x1.sort_values(by=xlevs[0])
        x1.index = data_pivot.sort_values(by=xlevs[0]).index

        ## pandas DataFrame of 'after' group
        after_xy = after_points.get_offsets().T
        x2 = pd.DataFrame({
            aftX: pd.Series(after_xy[0]),
            xlevs[1]: pd.Series(after_xy[1])
        })
        x2 = x2.sort_values(by=xlevs[1])
        x2.index = data_pivot.sort_values(by=xlevs[1]).index

        ## Join x1 and x2, on both their indexes.
        plotPoints = x1.merge(x2,
                              left_index=True,
                              right_index=True,
                              how='outer')

        ## Add the hue column if hue argument was passed.
        if 'hue' in kwargs:
            h = kwargs['hue']
            plotPoints[h] = data.pivot(index=idcol, columns=x,
                                       values=h)[xlevs[0]]
            swarm_raw.legend(loc=legendLoc,
                             fontsize=legendFontSize,
                             markerscale=legendMarkerScale)

        ## Plot the lines to join the 'before' points to their respective 'after' points.
        if showConnections is True:
            for i in plotPoints.index:
                # `.loc` replaces the `.ix` indexer removed in pandas 1.0.
                ax_raw.plot(
                    [plotPoints.loc[i, befX], plotPoints.loc[i, aftX]],
                    [plotPoints.loc[i, xlevs[0]],
                     plotPoints.loc[i, xlevs[1]]],
                    linestyle='solid',
                    color=plotPoints.loc[i, '_hue_'],
                    linewidth=0.75,
                    alpha=0.75)

        ## Hide the raw swarmplot data if so desired.
        if showRawData is False:
            before_points.set_visible(False)
            after_points.set_visible(False)

        if showRawData is True:
            maxSwarmSpan = 0.5
        else:
            maxSwarmSpan = barWidth

        ## Plot Summary Bar.
        if summaryBar is True:
            # Calculate means of `y` only; averaging the whole frame fails on
            # non-numeric columns with recent pandas.
            means = data.groupby([x], sort=True)[y].mean()

            ## Draw summary bar.
            bar_raw = sns.barplot(x=means.index,
                                  y=means.values,
                                  order=xlevs,
                                  ax=ax_raw,
                                  ci=0,
                                  facecolor=summaryBarColor,
                                  alpha=0.25)
            ## Draw zero reference line.
            ax_raw.add_artist(
                Line2D((ax_raw.xaxis.get_view_interval()[0],
                        ax_raw.xaxis.get_view_interval()[1]), (0, 0),
                       color='black',
                       linewidth=0.75))

            ## Re-centre each bar and set its width to the swarm span; the
            ## second bar follows the shifted 'after' points.
            for i, bar in enumerate(bar_raw.patches):
                centre = bar.get_x() + bar.get_width() / 2.
                if i == 0:
                    bar.set_x(centre - maxSwarmSpan / 2.)
                else:
                    bar.set_x(centre - xAfterShift - maxSwarmSpan / 2.)
                bar.set_width(maxSwarmSpan)

        # Get the extent of the (possibly shifted) raw points.
        beforeRaw = pd.DataFrame(before_points.get_offsets())
        afterRaw = pd.DataFrame(after_points.get_offsets())
        before_leftx = min(beforeRaw[0])
        after_rightx = max(afterRaw[0])
        # NOTE: despite its name this summarises the *before* group; it is
        # the value that the contrast axis' zero is aligned to further down.
        after_stat_summary = statfunction(beforeRaw[1])

        # Calculate the summary difference and CI.
        plotPoints['delta_y'] = plotPoints[xlevs[1]] - plotPoints[xlevs[0]]
        plotPoints['delta_x'] = [0] * np.shape(plotPoints)[0]

        bootsDelta = bootstrap(plotPoints['delta_y'],
                               statfunction=statfunction,
                               smoothboot=smoothboot,
                               reps=reps)
        summDelta = bootsDelta['summary']
        lowDelta = bootsDelta['bca_ci_low']
        highDelta = bootsDelta['bca_ci_high']

        # set new xpos for delta violin.
        if floatContrast is True:
            if showRawData is False:
                xposPlusViolin = deltaSwarmX = after_rightx + floatViolinOffset
            else:
                xposPlusViolin = deltaSwarmX = after_rightx + maxSwarmSpan
        else:
            xposPlusViolin = xposAfter
        if showRawData is True:
            # If showRawData is True, set violinwidth to the swarm span.
            violinWidth = maxSwarmSpan

        xmaxPlot = xposPlusViolin + violinWidth

        # Plot the summary measure.
        ax_contrast.plot(xposPlusViolin,
                         summDelta,
                         marker='o',
                         markerfacecolor='k',
                         markersize=summaryMarkerSize,
                         alpha=0.75)

        # Plot the CI.
        ax_contrast.plot([xposPlusViolin, xposPlusViolin],
                         [lowDelta, highDelta],
                         color='k',
                         alpha=0.75,
                         linestyle='solid')

        # Plot the violin-plot.
        v = ax_contrast.violinplot(bootsDelta['stat_array'], [xposPlusViolin],
                                   widths=violinWidth,
                                   showextrema=False,
                                   showmeans=False)
        halfviolin(v, half='right', color='k')

        # Remove left axes x-axis title.
        ax_raw.set_xlabel("")
        # Remove floating axes y-axis title.
        ax_contrast.set_ylabel("")

        # Set proper x-limits
        ax_raw.set_xlim(before_leftx - beforeAfterSpacer / 2, xmaxPlot)
        ax_raw.get_xaxis().set_view_interval(
            before_leftx - beforeAfterSpacer / 2,
            after_rightx + beforeAfterSpacer / 2)
        ax_contrast.set_xlim(ax_raw.get_xlim())

        if floatContrast is True:
            # Set the ticks locations for ax_raw.
            ax_raw.get_xaxis().set_ticks((0, xposAfter))

            # Make sure they have the same y-limits.
            ax_contrast.set_ylim(ax_raw.get_ylim())

            # Drawing in the x-axis for ax_raw.
            ## Set the tick labels!
            ax_raw.set_xticklabels(xlevs,
                                   rotation=tickAngle,
                                   horizontalalignment=tickAlignment)

            # Align the left axes and the floating axes.
            align_yaxis(ax_raw, statfunction(plotPoints[xlevs[0]]),
                        ax_contrast, 0)

            # Add label to floating axes. But on ax_raw!
            ax_raw.text(x=deltaSwarmX,
                        y=ax_raw.get_yaxis().get_view_interval()[0],
                        horizontalalignment='left',
                        s='Difference',
                        fontsize=15)

            # Set reference lines
            ## zero line, from the first major tick to the right edge.
            ax_contrast.hlines(0,
                               ax_contrast.xaxis.get_majorticklocs()[0],
                               ax_raw.xaxis.get_view_interval()[1],
                               linestyle='solid',
                               linewidth=0.75,
                               color='black')

            ## effect size line, from the second major tick to the right edge.
            ax_contrast.hlines(summDelta,
                               ax_contrast.xaxis.get_majorticklocs()[1],
                               ax_raw.xaxis.get_view_interval()[1],
                               linestyle='solid',
                               linewidth=0.75,
                               color='black')

            # Align the left axes and the floating axes.
            align_yaxis(ax_raw, after_stat_summary, ax_contrast, 0.)
        else:
            # Set the ticks locations for ax_raw.
            ax_raw.get_xaxis().set_ticks((0, xposAfter))

            fig.add_subplot(ax_raw)
            fig.add_subplot(ax_contrast)
        ax_contrast.set_ylim(contrastYlim)
        # Calculate p-values.
        # 1-sample t-test to see if the mean of the difference is different from 0.
        ttestresult = ttest_1samp(plotPoints['delta_y'], popmean=0)[1]
        bootsDelta['ttest_pval'] = ttestresult
        contrastList.append(bootsDelta)
        contrastListNames.append(str(xlevs[1]) + ' v.s. ' + str(xlevs[0]))

    # Turn contrastList into a pandas DataFrame,
    contrastList = pd.DataFrame(contrastList).T
    contrastList.columns = contrastListNames

    # Axes alternate raw, contrast, raw, contrast, ... in creation order.
    all_axes = fig.get_axes()

    # Now we iterate thru the contrast axes to normalize all the ylims.
    for j, i in enumerate(range(1, len(all_axes), 2)):
        axx = all_axes[i]
        ## Get max and min of the dataset. The column is selected by
        ## position because contrast names are not guaranteed to be unique.
        contrast = contrastList.iloc[:, j]
        lower = np.min(contrast['stat_array'])
        upper = np.max(contrast['stat_array'])
        meandiff = contrast['summary']

        ## Make sure we have zero in the limits.
        if lower > 0:
            lower = 0.
        if upper < 0:
            upper = 0.

        ## Get tick distance on raw axes.
        ## This will be the tick distance for the contrast axes.
        rawAxesTicks = all_axes[i - 1].yaxis.get_majorticklocs()
        rawAxesTickDist = rawAxesTicks[1] - rawAxesTicks[0]

        ## First re-draw of axis with new tick interval
        axx.yaxis.set_major_locator(MultipleLocator(rawAxesTickDist))
        newticks1 = axx.get_yticks()

        if floatContrast is False:
            if (showAllYAxes is False and i >= 2):
                # Only the first contrast axes keeps its y-axis.
                axx.get_yaxis().set_visible(showAllYAxes)
            else:
                ## Major ticks that comfortably encompass lower and upper.
                newticks2 = _ticks_spanning(newticks1, lower, upper, meandiff)
                axx.yaxis.set_major_locator(FixedLocator(locs=newticks2))

                ## Draw zero reference line.
                axx.hlines(y=0,
                           xmin=axx.get_xaxis().get_view_interval()[0],
                           xmax=axx.get_xaxis().get_view_interval()[1],
                           linestyle=contrastZeroLineStyle,
                           linewidth=0.75,
                           color=contrastZeroLineColor)

                sns.despine(ax=axx,
                            trim=True,
                            bottom=False,
                            right=True,
                            left=False,
                            top=True)

                ## Draw back the lines for the relevant y-axes.
                drawback_y(axx)

                ## Draw back the lines for the relevant x-axes.
                drawback_x(axx)

        elif floatContrast is True:
            ## Major ticks that comfortably encompass lower and upper.
            newticks2 = _ticks_spanning(newticks1, lower, upper, meandiff)

            ## Re-draw the axis.
            axx.yaxis.set_major_locator(FixedLocator(locs=newticks2))

            ## Despine and trim the axes.
            sns.despine(ax=axx,
                        trim=True,
                        bottom=False,
                        right=False,
                        left=True,
                        top=True)

    for i in range(0, len(all_axes), 2):
        # Loop through the raw data swarmplots and despine them appropriately.
        raw_axes = all_axes[i]
        if floatContrast is True:
            sns.despine(ax=raw_axes, trim=True, right=True)
        else:
            sns.despine(ax=raw_axes, trim=True, bottom=True, right=True)
            raw_axes.get_xaxis().set_visible(False)

        # Draw back the lines for the relevant y-axes.
        yticks = raw_axes.get_yaxis().get_majorticklocs()
        ymin = yticks[0]
        ymax = yticks[-1]
        xleft, _ = raw_axes.get_xaxis().get_view_interval()
        raw_axes.add_artist(
            Line2D((xleft, xleft), (ymin, ymax), color='black',
                   linewidth=1.5))

    # Zero gaps between plots on the same row, if floatContrast is False
    if (floatContrast is False and showAllYAxes is False):
        gsMain.update(wspace=0)
    else:
        # Tight Layout!
        gsMain.tight_layout(fig)

    # And we're done.
    rcdefaults()  # restore matplotlib defaults.
    sns.set()  # restore seaborn defaults.
    return fig, contrastList
# Mean intensity per image and channel. Each image array is indexed by its
# companion selector first (presumably a mask or index of the pixels to keep
# -- TODO confirm against the code that builds `redst` / `greenst`).
mean_red = [np.mean(red[rt]) for red, rt in zip(reds, redst)]
mean_green = [np.mean(green[gt]) for green, gt in zip(greens, greenst)]

# All outputs are written to the current user's Desktop folder.
_desktop_dir = os.path.join(os.path.expanduser('~'), 'Desktop')

# Put them in a dataframe and save it
df = pd.DataFrame({
    'filename': image_filenames,
    'class': classes,
    'red': mean_red,
    'green': mean_green
})
df.to_excel(os.path.join(_desktop_dir, 'intensities.xlsx'))

# Make a bar plot and save it.
# Only the numeric channels are aggregated: averaging the string `filename`
# column raises TypeError on recent pandas. The string aggregator names keep
# the groupby semantics that np.mean / np.std were mapped to.
means = df.groupby('class')[['red', 'green']].agg('mean')
errs = df.groupby('class')[['red', 'green']].agg('std')
means.plot(kind='bar', yerr=errs)
plt.xlabel('')
plt.tight_layout()
plt.savefig(os.path.join(_desktop_dir, 'intensities.png'),
            dpi=300)

# Make a jitter plot and save it
# Long format: one row per (image, channel) with its intensity.
td = pd.melt(df, id_vars=['filename', 'class'],
             value_vars=['red', 'green'], var_name='channel',
             value_name='intensity').set_index('filename')
ax = sns.stripplot(x='class', y='intensity', hue='channel', data=td,
                   hue_order=('green', 'red'),
                   split=True, jitter=True)
ax.figure.savefig(os.path.join(_desktop_dir, 'jitter.png'),
                  dpi=300)
Exemple #5
0
import matplotlib.pyplot as plt
import pandas as pd
import seaborn.apionly as sns

# Use the matplotlib style sheet registered under the name 'custom'.
plt.style.use('custom')

# Named colours for fail / pass / warn markers.
failcolor, passcolor, warncolor = '#C44E52', '#55A868', '#FFA574'

df = pd.read_csv('data.csv')

# 15 x 6 figure, converted with cm2inch (presumably centimetres to inches --
# the helper is defined outside this snippet).
fig = plt.figure(figsize=tuple(cm2inch(side) for side in (15, 6)))

# One jittered strip of `value` observations per `generator`.
ax = sns.stripplot(data=df, x='generator', y='value', jitter=True)
ax.set_xlabel(xlabel='Generador')
ax.set_ylabel(ylabel='p values')

fig.savefig('summary.pdf')
# for gen,gendf in df.groupby('generator'):
#     fig = plt.figure(figsize=(cm2inch(2),cm2inch(2)))
#     ax = fig.add_subplot(1,1,1)
#     ax.set_ylabel('')
#     ax.set_xlabel('')
#     L = len(gendf['value'])
#     vals = list(gendf['value'])
#     for j in range(L):
#         val = vals[j]
#         color = passcolor
#         marker = 'o'
def _normalise_paired_idx(data, x, idx):
    """Return `idx` as a sequence of 2-element (before, after) level pairs.

    * ``idx is None``: one pair made of the first two values returned by
      ``np.unique(data[x])``.
    * a flat sequence of two strings: wrapped into a 1-tuple holding that pair.
    * a sequence of tuples: each tuple must hold exactly two levels.

    Any other shape is returned unchanged.

    Raises
    ------
    ValueError
        If a flat `idx`, or any tuple inside it, does not have length 2.
    """
    if idx is None:
        # Wrap the default pair so the caller iterates over ONE pair rather
        # than over the two level names themselves.
        return (tuple(np.unique(data[x])[0:2]),)

    if all(isinstance(element, str) for element in idx):
        if len(idx) != 2:
            raise ValueError("{} does not have length 2.".format(idx))
        return (tuple(idx),)

    if all(isinstance(element, tuple) for element in idx):
        for element in idx:
            if len(element) != 2:
                raise ValueError(
                    "{} does not have length 2.".format(element))
    return idx


def _pick_contrast_ticks(candidate_ticks, lower, upper, meandiff):
    """Choose the contrast-axis ticks that cover ``[lower, upper]``.

    Keeps every candidate tick inside the interval; if `meandiff` falls
    beyond the kept ticks, the next candidate tick in that direction is
    added (when one exists). Returns a sorted numpy array.
    """
    candidate_ticks = np.asarray(candidate_ticks)
    kept = [tick for tick in candidate_ticks if lower <= tick <= upper]
    if not kept:
        # No candidate lies inside the interval: keep the locator's ticks
        # rather than failing on the max/min of an empty list.
        return np.sort(candidate_ticks)

    if np.max(kept) < meandiff:
        ind = np.where(candidate_ticks == np.max(kept))[0][0]
        if ind + 1 < len(candidate_ticks):
            kept.append(candidate_ticks[ind + 1])
    elif meandiff < np.min(kept):
        ind = np.where(candidate_ticks == np.min(kept))[0][0]
        # Guard against ind == 0: index -1 would wrap to the LAST tick.
        if ind - 1 >= 0:
            kept.append(candidate_ticks[ind - 1])

    kept = np.array(kept)
    kept.sort()
    return kept


def pairedcontrast(data, x, y, idcol, reps=3000,
                   statfunction=None, idx=None, figsize=None,
                   beforeAfterSpacer=0.01,
                   violinWidth=0.005,
                   floatOffset=0.05,
                   showRawData=False,
                   showAllYAxes=False,
                   floatContrast=True,
                   smoothboot=False,
                   floatViolinOffset=None,
                   showConnections=True,
                   summaryBar=False,
                   contrastYlim=None,
                   swarmYlim=None,
                   barWidth=0.005,
                   rawMarkerSize=8,
                   rawMarkerType='o',
                   summaryMarkerSize=10,
                   summaryMarkerType='o',
                   summaryBarColor='grey',
                   meansSummaryLineStyle='solid',
                   contrastZeroLineStyle='solid',
                   contrastEffectSizeLineStyle='solid',
                   contrastZeroLineColor='black',
                   contrastEffectSizeLineColor='black',
                   pal=None,
                   legendLoc=2, legendFontSize=12, legendMarkerScale=1,
                   axis_title_size=None,
                   yticksize=None,
                   xticksize=None,
                   tickAngle=45,
                   tickAlignment='right',
                   **kwargs):
    """Plot paired (before/after) data with a bootstrapped difference.

    For every (before, after) pair in `idx` one panel is drawn: the raw
    observations of the two groups, optional lines joining each subject's
    two observations, and the bootstrap distribution of the paired
    difference (after - before) as a half violin with its summary value and
    confidence interval.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-format data. Rows containing NaN are dropped first.
    x, y : str
        Column holding the group label and column holding the measurement.
    idcol : str
        Column identifying the subject; used to pair observations.
    reps : int
        Passed to ``bootstrap`` as `reps`.
    statfunction : callable, optional
        Summary statistic; defaults to ``np.mean``.
    idx : sequence, optional
        Either two level names, or a sequence of 2-tuples of level names
        (one panel per tuple). Defaults to the first two values of
        ``np.unique(data[x])``. Level names are concatenated with ``'_x'``,
        so they must be strings.
    figsize : tuple, optional
        Defaults to (6, 6), or (12, 12/sqrt(2)) for more than two panels.
    beforeAfterSpacer : float
        Horizontal gap between the two groups; forced to 1 when
        `showRawData` is True.
    violinWidth : float
        Violin width; replaced by the swarm span when `showRawData` is True.
    showRawData : bool
        True draws a visible swarm plot; False draws a strip plot that is
        used for point positions only and then hidden.
    showAllYAxes : bool
        With ``floatContrast=False``, show the contrast y-axis on every
        panel rather than only on the first.
    floatContrast : bool
        True puts the contrast on a twin y-axis beside the raw data; False
        puts it in a separate subplot underneath.
    smoothboot : bool
        Passed to ``bootstrap``.
    floatViolinOffset : float, optional
        Offset of the floating violin; defaults to ``beforeAfterSpacer / 2``.
    showConnections : bool
        Draw a line joining each subject's before and after points.
    summaryBar : bool
        Draw a bar at each group's mean behind the raw data.
    contrastYlim, swarmYlim : sequence of two numbers, optional
        y-limits for the contrast axes and raw-data axes.
    barWidth : float
        Summary-bar width when `showRawData` is False.
    rawMarkerSize, rawMarkerType
        Marker size/type for the swarm plot (``showRawData=True`` only).
    summaryMarkerSize, summaryMarkerType
        Marker size/type of the summary-difference point.
    summaryBarColor
        Face colour of the summary bars.
    contrastZeroLineStyle, contrastZeroLineColor
        Style/colour of the zero line on the contrast axes.
    contrastEffectSizeLineStyle, contrastEffectSizeLineColor
        Style/colour of the effect-size line (floating layout only).
    pal : dict, optional
        Palette mapping levels to colours; by default built from all levels
        of the hue column (or of `x` when no ``hue`` is given).
    legendLoc, legendFontSize, legendMarkerScale
        Legend options, used only when ``hue`` is passed in `kwargs`.
    axis_title_size, yticksize, xticksize : optional
        Font sizes written to matplotlib rc (defaults 15, 12, 12).
    tickAngle, tickAlignment
        Rotation and alignment of x tick labels (floating layout only).
    floatOffset, meansSummaryLineStyle
        Accepted for compatibility; currently unused.
    **kwargs
        Forwarded to ``sns.swarmplot`` / ``sns.stripplot``. If neither
        ``color`` nor ``hue`` is supplied, ``color='k'`` is added.

    Returns
    -------
    fig : matplotlib Figure
    contrastList : pandas.DataFrame
        One column per contrast, named ``"<after> v.s. <before>"``. Rows
        are the keys of the ``bootstrap`` result plus ``'ttest_pval'``
        (one-sample t-test of the differences against 0).

    Raises
    ------
    ValueError
        If `idx`, or any tuple inside it, does not have exactly two levels.

    Notes
    -----
    Matplotlib rc settings are modified while plotting; before returning,
    ``rcdefaults()`` and ``sns.set()`` reset matplotlib and seaborn to their
    defaults (not to the caller's previous settings).
    """
    # Preliminaries.
    data = data.dropna()

    # Validate `idx` before any global matplotlib state is modified, so that
    # bad input leaves the caller's rc settings untouched.
    idx = _normalise_paired_idx(data, x, idx)

    # plot params
    if axis_title_size is None:
        axis_title_size = 15
    if yticksize is None:
        yticksize = 12
    if xticksize is None:
        xticksize = 12

    rc('axes', labelsize=axis_title_size)
    rc('xtick', labelsize=xticksize)
    rc('ytick', labelsize=yticksize)

    if floatViolinOffset is None:
        floatViolinOffset = beforeAfterSpacer / 2
    if contrastYlim is not None:
        contrastYlim = np.array([contrastYlim[0], contrastYlim[1]])
    if swarmYlim is not None:
        swarmYlim = np.array([swarmYlim[0], swarmYlim[1]])

    # Define the palette on ALL levels of the hue column (or of `x`), so a
    # dataframe re-used across plots keeps the same colour for each group.
    hue_column = kwargs['hue'] if 'hue' in kwargs else x
    if 'color' not in kwargs and 'hue' not in kwargs:
        kwargs['color'] = 'k'

    if pal is None:
        hue_levels = data[hue_column].unique()
        pal = dict(zip(hue_levels,
                       sns.color_palette(n_colors=len(hue_levels))))

    # Initialise figure.
    if figsize is None:
        if len(idx) > 2:
            figsize = (12, (12 / np.sqrt(2)))
        else:
            figsize = (6, 6)
    fig = plt.figure(figsize=figsize)

    # 1 row; one column per (before, after) pair.
    gsMain = gridspec.GridSpec(1, len(idx))

    # Set default statfunction
    if statfunction is None:
        statfunction = np.mean

    # Raw bootstrap results and their column names, one per contrast.
    bootResults = list()
    contrastListNames = list()

    # One row per subject, one column per level of `x`; identical for every
    # panel, so it is computed once.
    data_pivot = data.pivot_table(index=idcol, columns=x, values=y)

    # The spacer is fixed at 1 whenever the raw swarm is shown.
    spacer = 1 if showRawData is True else beforeAfterSpacer

    for gsIdx, xlevs in enumerate(idx):
        # Convenient signifiers for the x-position column names.
        befX = str(xlevs[0] + '_x')
        aftX = str(xlevs[1] + '_x')

        # Start plotting!!
        if floatContrast is True:
            ax_raw = fig.add_subplot(gsMain[gsIdx], frame_on=False)
            ax_contrast = ax_raw.twinx()
        else:
            gsSubGridSpec = gridspec.GridSpecFromSubplotSpec(
                2, 1, subplot_spec=gsMain[gsIdx])
            ax_raw = plt.Subplot(fig, gsSubGridSpec[0, 0], frame_on=False)
            ax_contrast = plt.Subplot(fig, gsSubGridSpec[1, 0],
                                      sharex=ax_raw, frame_on=False)

        # Plot raw data as swarmplot or stripplot.
        if showRawData is True:
            swarm_raw = sns.swarmplot(data=data,
                                      x=x, y=y,
                                      order=xlevs,
                                      ax=ax_raw,
                                      palette=pal,
                                      size=rawMarkerSize,
                                      marker=rawMarkerType,
                                      **kwargs)
        else:
            swarm_raw = sns.stripplot(data=data,
                                      x=x, y=y,
                                      order=xlevs,
                                      ax=ax_raw,
                                      palette=pal,
                                      **kwargs)
        swarm_raw.set_ylim(swarmYlim)

        before_points = swarm_raw.collections[0]
        after_points = swarm_raw.collections[1]

        # Get some details about the raw data.
        maxXBefore = max(before_points.get_offsets().T[0])
        minXAfter = min(after_points.get_offsets().T[0])
        xposAfter = maxXBefore + spacer
        xAfterShift = minXAfter - xposAfter

        # Shift the after swarmpoints closer for aesthetic purposes.
        offsetSwarmX(after_points, -xAfterShift)

        # pandas DataFrame of the 'before' group: plotted position, value
        # and face colour of every point.
        before_xy = before_points.get_offsets().T
        before_rgb = before_points.get_facecolors().T
        x1 = pd.DataFrame({befX: pd.Series(before_xy[0]),
                           xlevs[0]: pd.Series(before_xy[1]),
                           '_R_': pd.Series(before_rgb[0]),
                           '_G_': pd.Series(before_rgb[1]),
                           '_B_': pd.Series(before_rgb[2]),
                           })
        # Join the RGB columns into a tuple, then assign to a column.
        x1['_hue_'] = x1[['_R_', '_G_', '_B_']].apply(tuple, axis=1)
        # Plotted points carry no subject id: sort both the points and the
        # pivot table by value and borrow the pivot's index.
        x1 = x1.sort_values(by=xlevs[0])
        x1.index = data_pivot.sort_values(by=xlevs[0]).index

        # pandas DataFrame of the 'after' group.
        after_xy = after_points.get_offsets().T
        x2 = pd.DataFrame({aftX: pd.Series(after_xy[0]),
                           xlevs[1]: pd.Series(after_xy[1])})
        x2 = x2.sort_values(by=xlevs[1])
        x2.index = data_pivot.sort_values(by=xlevs[1]).index

        # Join x1 and x2, on both their indexes.
        plotPoints = x1.merge(x2, left_index=True, right_index=True,
                              how='outer')

        # Add the hue column if hue argument was passed.
        if 'hue' in kwargs:
            h = kwargs['hue']
            plotPoints[h] = data.pivot(index=idcol, columns=x,
                                       values=h)[xlevs[0]]
            swarm_raw.legend(loc=legendLoc,
                             fontsize=legendFontSize,
                             markerscale=legendMarkerScale)

        # Join each 'before' point to its respective 'after' point.
        if showConnections is True:
            for i in plotPoints.index:
                ax_raw.plot([plotPoints.loc[i, befX],
                             plotPoints.loc[i, aftX]],
                            [plotPoints.loc[i, xlevs[0]],
                             plotPoints.loc[i, xlevs[1]]],
                            linestyle='solid',
                            color=plotPoints.loc[i, '_hue_'],
                            linewidth=0.75,
                            alpha=0.75
                            )

        # Hide the raw swarmplot data if so desired.
        if showRawData is False:
            before_points.set_visible(False)
            after_points.set_visible(False)

        if showRawData is True:
            maxSwarmSpan = 0.5
        else:
            maxSwarmSpan = barWidth

        # Plot Summary Bar.
        if summaryBar is True:
            # Group means of the measurement column only.
            means = data.groupby(x, sort=True)[y].mean()

            # Draw summary bar.
            bar_raw = sns.barplot(x=means.index,
                                  y=means.values,
                                  order=xlevs,
                                  ax=ax_raw,
                                  ci=0,
                                  facecolor=summaryBarColor,
                                  alpha=0.25)
            # Draw zero reference line.
            ax_raw.add_artist(Line2D(
                (ax_raw.xaxis.get_view_interval()[0],
                 ax_raw.xaxis.get_view_interval()[1]),
                (0, 0),
                color='black', linewidth=0.75
                )
            )

            # Re-centre each bar on its (possibly shifted) group and give
            # it the swarm span as width.
            for i, bar in enumerate(bar_raw.patches):
                centre = bar.get_x() + bar.get_width() / 2.
                if i == 0:
                    bar.set_x(centre - maxSwarmSpan / 2.)
                else:
                    bar.set_x(centre - xAfterShift - maxSwarmSpan / 2.)
                bar.set_width(maxSwarmSpan)

        # Extents of the plotted points (after the shift above).
        beforeRaw = pd.DataFrame(before_points.get_offsets())
        afterRaw = pd.DataFrame(after_points.get_offsets())
        before_leftx = min(beforeRaw[0])
        after_rightx = max(afterRaw[0])
        # Statistic of the BEFORE group: the zero of the floating contrast
        # axis is aligned to it further down.
        before_stat_summary = statfunction(beforeRaw[1])

        # Calculate the summary difference and CI.
        plotPoints['delta_y'] = plotPoints[xlevs[1]] - plotPoints[xlevs[0]]
        plotPoints['delta_x'] = [0] * np.shape(plotPoints)[0]

        bootsDelta = bootstrap(plotPoints['delta_y'],
                               statfunction=statfunction,
                               smoothboot=smoothboot,
                               reps=reps)
        summDelta = bootsDelta['summary']
        lowDelta = bootsDelta['bca_ci_low']
        highDelta = bootsDelta['bca_ci_high']

        # Set new xpos for delta violin.
        if floatContrast is True:
            if showRawData is False:
                xposPlusViolin = deltaSwarmX = after_rightx + floatViolinOffset
            else:
                xposPlusViolin = deltaSwarmX = after_rightx + maxSwarmSpan
        else:
            xposPlusViolin = xposAfter
        if showRawData is True:
            # With raw data shown, the violin is as wide as the swarm span.
            violinWidth = maxSwarmSpan

        xmaxPlot = xposPlusViolin + violinWidth

        # Plot the summary measure.
        ax_contrast.plot(xposPlusViolin, summDelta,
                         marker=summaryMarkerType,
                         markerfacecolor='k',
                         markersize=summaryMarkerSize,
                         alpha=0.75
                         )

        # Plot the CI.
        ax_contrast.plot([xposPlusViolin, xposPlusViolin],
                         [lowDelta, highDelta],
                         color='k',
                         alpha=0.75,
                         linestyle='solid'
                         )

        # Plot the violin-plot.
        v = ax_contrast.violinplot(bootsDelta['stat_array'], [xposPlusViolin],
                                   widths=violinWidth,
                                   showextrema=False,
                                   showmeans=False)
        halfviolin(v, half='right', color='k')

        # Remove left axes x-axis title.
        ax_raw.set_xlabel("")
        # Remove floating axes y-axis title.
        ax_contrast.set_ylabel("")

        # Set proper x-limits
        ax_raw.set_xlim(before_leftx - spacer / 2, xmaxPlot)
        ax_raw.get_xaxis().set_view_interval(before_leftx - spacer / 2,
                                             after_rightx + spacer / 2)
        ax_contrast.set_xlim(ax_raw.get_xlim())

        # Set the ticks locations for ax_raw.
        ax_raw.get_xaxis().set_ticks((0, xposAfter))

        if floatContrast is True:
            # Make sure they have the same y-limits.
            ax_contrast.set_ylim(ax_raw.get_ylim())

            # Set the tick labels!
            ax_raw.set_xticklabels(xlevs, rotation=tickAngle,
                                   horizontalalignment=tickAlignment)

            # Align the left axes and the floating axes.
            align_yaxis(ax_raw, statfunction(plotPoints[xlevs[0]]),
                        ax_contrast, 0)

            # Add label to floating axes. But on ax_raw!
            ax_raw.text(x=deltaSwarmX,
                        y=ax_raw.get_yaxis().get_view_interval()[0],
                        horizontalalignment='left',
                        s='Difference',
                        fontsize=15)

            # Set reference lines
            # zero line
            ax_contrast.hlines(0,
                               ax_contrast.xaxis.get_majorticklocs()[0],
                               ax_raw.xaxis.get_view_interval()[1],
                               linestyle=contrastZeroLineStyle,
                               linewidth=0.75,
                               color=contrastZeroLineColor)

            # effect size line
            ax_contrast.hlines(summDelta,
                               ax_contrast.xaxis.get_majorticklocs()[1],
                               ax_raw.xaxis.get_view_interval()[1],
                               linestyle=contrastEffectSizeLineStyle,
                               linewidth=0.75,
                               color=contrastEffectSizeLineColor)

            # Align the left axes and the floating axes.
            align_yaxis(ax_raw, before_stat_summary, ax_contrast, 0.)
        else:
            fig.add_subplot(ax_raw)
            fig.add_subplot(ax_contrast)
        ax_contrast.set_ylim(contrastYlim)

        # Calculate p-values.
        # 1-sample t-test: is the mean of the differences different from 0?
        ttestresult = ttest_1samp(plotPoints['delta_y'], popmean=0)[1]
        bootsDelta['ttest_pval'] = ttestresult
        bootResults.append(bootsDelta)
        contrastListNames.append(str(xlevs[1]) + ' v.s. ' + str(xlevs[0]))

    # Turn the collected results into a pandas DataFrame,
    # one column per contrast.
    contrastList = pd.DataFrame(bootResults).T
    contrastList.columns = contrastListNames

    # Axes were added as (raw, contrast) pairs: even positions are raw-data
    # axes, odd positions are contrast axes.
    all_axes = fig.get_axes()

    # Now we iterate thru the contrast axes to normalize all the ylims.
    for j, i in enumerate(range(1, len(all_axes), 2)):
        axx = all_axes[i]
        # Get max and min of the bootstrap distribution.
        lower = np.min(bootResults[j]['stat_array'])
        upper = np.max(bootResults[j]['stat_array'])
        meandiff = bootResults[j]['summary']

        # Make sure we have zero in the limits.
        if lower > 0:
            lower = 0.
        if upper < 0:
            upper = 0.

        # The tick distance on the raw axes becomes the tick distance for
        # the contrast axes.
        rawAxesTicks = all_axes[i - 1].yaxis.get_majorticklocs()
        rawAxesTickDist = rawAxesTicks[1] - rawAxesTicks[0]

        # First re-draw of axis with new tick interval
        axx.yaxis.set_major_locator(MultipleLocator(rawAxesTickDist))
        newticks1 = axx.get_yticks()

        if floatContrast is False:
            if showAllYAxes is False and i >= 2:
                # Only the first panel keeps its contrast y-axis.
                axx.get_yaxis().set_visible(showAllYAxes)
            else:
                newticks2 = _pick_contrast_ticks(newticks1, lower, upper,
                                                 meandiff)
                axx.yaxis.set_major_locator(FixedLocator(locs=newticks2))

                # Draw zero reference line.
                axx.hlines(y=0,
                           xmin=axx.get_xaxis().get_view_interval()[0],
                           xmax=axx.get_xaxis().get_view_interval()[1],
                           linestyle=contrastZeroLineStyle,
                           linewidth=0.75,
                           color=contrastZeroLineColor)

                sns.despine(ax=axx, trim=True,
                            bottom=False, right=True,
                            left=False, top=True)

                # Draw back the lines for the relevant y-axes.
                drawback_y(axx)

                # Draw back the lines for the relevant x-axes.
                drawback_x(axx)

        elif floatContrast is True:
            newticks2 = _pick_contrast_ticks(newticks1, lower, upper,
                                             meandiff)

            # Re-draw the axis.
            axx.yaxis.set_major_locator(FixedLocator(locs=newticks2))

            # Despine and trim the axes.
            sns.despine(ax=axx, trim=True,
                        bottom=False, right=False,
                        left=True, top=True)

    for i in range(0, len(all_axes), 2):
        # Loop through the raw data swarmplots and despine them appropriately.
        raw_axes = all_axes[i]
        if floatContrast is True:
            sns.despine(ax=raw_axes, trim=True, right=True)
        else:
            sns.despine(ax=raw_axes, trim=True, bottom=True, right=True)
            raw_axes.get_xaxis().set_visible(False)

        # Draw back the lines for the relevant y-axes.
        ymin = raw_axes.get_yaxis().get_majorticklocs()[0]
        ymax = raw_axes.get_yaxis().get_majorticklocs()[-1]
        x_left, _ = raw_axes.get_xaxis().get_view_interval()
        raw_axes.add_artist(Line2D((x_left, x_left), (ymin, ymax),
                                   color='black', linewidth=1.5))

    # Zero gaps between plots on the same row, if floatContrast is False
    if (floatContrast is False and showAllYAxes is False):
        gsMain.update(wspace=0)
    else:
        # Tight Layout!
        gsMain.tight_layout(fig)

    # And we're done.
    rcdefaults()  # restore matplotlib defaults.
    sns.set()  # restore seaborn defaults.
    return fig, contrastList
Exemple #7
0
def plot_results(transformation):
    """Summarise the test logs of one transformation type.

    Looks in ``../results`` for sub-directories whose name starts with
    `transformation`, reads ``test.txt`` from each of them, prints
    per-sigma misclassification statistics, saves a box/strip plot of the
    misclassification error to ``<transformation>_result.pdf`` and prints
    per-sigma timing ratios.

    Parameters
    ----------
    transformation : str
        Directory-name prefix. The values 'homography' and 'fundamental'
        additionally select the plot title.
    """
    res_dir = '../results'

    # Immediate sub-directories of res_dir.
    _, dir_sigmas, _ = next(os.walk(res_dir))
    dir_sigmas = [ds for ds in dir_sigmas if ds.find(transformation) == 0]
    # The sigma value is whatever follows the prefix plus one separator
    # character in the directory name.
    sigmas = [float(ds[len(transformation) + 1:]) for ds in dir_sigmas]
    # Order both lists by increasing sigma.
    idx_sigmas = np.argsort(sigmas)
    sigmas = [sigmas[i] for i in idx_sigmas]
    dir_sigmas = [dir_sigmas[i] for i in idx_sigmas]

    # sigma -> list of misclassification percentages
    sigma_miss_err = {}
    # timing label -> sigma -> list of times
    sigma_times = {'PM': {}, 'NMU': {}, 'TOTAL': {}}
    # example name -> list of misclassification percentages
    example_miss_err = {}
    res_files = ['{}/{}/test.txt'.format(res_dir, ds) for ds in dir_sigmas]

    # Very crude parser, do not change console printing output
    # or this will break
    # Fields are located purely by line number modulo 10 and by
    # whitespace-separated token position.
    for s, rf in zip(sigmas, res_files):
        with open(rf, 'r') as file_contents:
            sigma_miss_err[s] = []
            sigma_times['PM'][s] = []
            sigma_times['NMU'][s] = []
            sigma_times['TOTAL'][s] = []
            for i, line in enumerate(file_contents):
                # Parsing stops at the first line starting with 'Statistics'.
                if line.find('Statistics') == 0:
                    break
                if i % 10 == 0:
                    # Example name: the line without its last 5 characters.
                    example = line[:-5]
                if i % 10 == 3:
                    t = float(line.split()[4])
                    sigma_times['PM'][s].append(t)
                if i % 10 == 4:
                    t = float(line.split()[2])
                    sigma_times['NMU'][s].append(t)
                if i % 10 == 7:
                    t = float(line.split()[2])
                    sigma_times['TOTAL'][s].append(t)
                if i % 10 == 8:
                    # Fourth token without its last character, scaled by 100.
                    pr = 100 * float(line.split()[3][:-1])
                    if example not in example_miss_err:
                        example_miss_err[example] = []
                    example_miss_err[example].append(pr)
                    sigma_miss_err[s].append(pr)

    def sort_dict(d):
        # Copy of d as an OrderedDict sorted by key.
        return collections.OrderedDict(sorted(d.items()))

    example_miss_err = sort_dict(example_miss_err)
    sigma_miss_err = sort_dict(sigma_miss_err)
    sigma_times['PM'] = sort_dict(sigma_times['PM'])
    sigma_times['NMU'] = sort_dict(sigma_times['NMU'])
    sigma_times['TOTAL'] = sort_dict(sigma_times['TOTAL'])

    def round2(vals, decimals=2):
        # np.round with a default of two decimals.
        return np.round(vals, decimals=decimals)

    print('Misclassification error')
    for key in sigma_miss_err:
        values = np.array(sigma_miss_err[key])
        # Mean, median and sample standard deviation (ddof=1) per sigma.
        stats = (key, round2(np.mean(values)), round2(np.median(values)),
                 round2(np.std(values, ddof=1)))
        fmt_str = 'sigma: {}\tmean: {}\tmedian: {}\tstd: {}'
        print(fmt_str.format(*stats))
        # print('\t', values)

    with sns.axes_style("whitegrid"):
        # Transposed so that each column holds the values of one sigma.
        values = np.array(list(sigma_miss_err.values())).T
        max_val = values.max()

        plt.figure()
        sns.boxplot(data=values, color='.95', whis=100)
        sns.stripplot(data=values, jitter=True)
        sigmas_text = ['{:.2f}'.format(s) for s in sigmas]
        plt.xticks(range(len(sigmas)), sigmas_text, size='x-large')
        # Keep only the non-negative y ticks.
        yticks = [yt for yt in plt.yticks()[0] if yt >= 0]
        plt.yticks(yticks, size='x-large')
        plt.xlabel(r'$\sigma$', size='x-large')
        plt.ylabel('Misclassification error (%)', size='x-large')
        # Upper limit: the maximum rounded up to the next multiple of 10.
        plt.ylim((-2, 10 * np.ceil(max_val / 10)))
        if transformation == 'homography':
            plt.title('Homographies', size='x-large')
        if transformation == 'fundamental':
            plt.title('Fundamental matrices', size='x-large')
        plt.tight_layout()
        plt.savefig('{}/{}_result.pdf'.format(res_dir, transformation),
                    bbox_inches='tight')

    print('Time')
    for key in sigma_miss_err:
        mean_PM = round2(np.mean(np.array(sigma_times['PM'][key])))
        mean_NMU = round2(np.mean((np.array(sigma_times['NMU'][key]))))
        mean_total = round2(np.mean((np.array(sigma_times['TOTAL'][key]))))
        # Mean total time, then the PM and NMU shares of it.
        stats = (key, mean_total, round2(mean_PM / mean_total),
                 round2(mean_NMU / mean_total))
        fmt_str = 'sigma: {}\tTOTAL: {}\tRATIO PM: {}\tRATIO NMU: {}'
        print(fmt_str.format(*stats))