Example #1
0
def test_get_single_plot():
    """Compare ``tls.get_subplots()`` called with no arguments to a 1x1 layout.

    The expected figure has empty data and a single x/y axis pair, each
    spanning the domain [0.0, 1.0] and anchored to the other.
    """
    no_traces = Data()
    only_x = XAxis(domain=[0.0, 1.0], anchor='y1')
    only_y = YAxis(domain=[0.0, 1.0], anchor='x1')
    expected = Figure(data=no_traces,
                      layout=Layout(xaxis1=only_x, yaxis1=only_y))
    assert tls.get_subplots() == expected
Example #2
0
def test_spacing():
    """Compare a 2x3 grid requested with explicit spacing to a fixed layout.

    ``get_subplots`` is called with 2 and 3 as positional arguments and the
    horizontal/vertical spacing as keywords; the result must equal a figure
    with empty data and the six x/y axis pairs listed below.
    """
    expected = Figure(
        data=Data(),
        layout=Layout(
            xaxis1=XAxis(domain=[0.0, 0.3], anchor='y1'),
            xaxis2=XAxis(domain=[0.35, 0.6499999999999999], anchor='y2'),
            xaxis3=XAxis(domain=[0.7, 1.0], anchor='y3'),
            xaxis4=XAxis(domain=[0.0, 0.3], anchor='y4'),
            xaxis5=XAxis(domain=[0.35, 0.6499999999999999], anchor='y5'),
            xaxis6=XAxis(domain=[0.7, 1.0], anchor='y6'),
            yaxis1=YAxis(domain=[0.0, 0.45], anchor='x1'),
            yaxis2=YAxis(domain=[0.0, 0.45], anchor='x2'),
            yaxis3=YAxis(domain=[0.0, 0.45], anchor='x3'),
            yaxis4=YAxis(domain=[0.55, 1.0], anchor='x4'),
            yaxis5=YAxis(domain=[0.55, 1.0], anchor='x5'),
            yaxis6=YAxis(domain=[0.55, 1.0], anchor='x6')))

    fig = tls.get_subplots(2, 3, horizontal_spacing=.05, vertical_spacing=.1)

    assert fig == expected
Example #3
0
def test_spacing():
    """Compare a 2x3 grid requested with positional spacing to a fixed layout.

    All four arguments (2, 3, .05, .1) are passed positionally.  The expected
    axes are numbered 1..6; axes 1-3 share the first y domain and axes 4-6
    the second, while the x domain repeats every three axes.
    """
    per_row = 3
    # x domain shared by every axis at the same position within a row.
    x_domains = [
        [0.0, 0.3],
        [0.35, 0.6499999999999999],
        [0.7, 1.0],
    ]
    # y domain shared by every axis in the same row.
    y_domains = [
        [0.0, 0.45],
        [0.55, 1.0],
    ]

    axes = {}
    for cell in range(len(y_domains) * per_row):
        row, col = divmod(cell, per_row)
        number = cell + 1
        # Copy the domain so each axis owns its own list, as in a literal.
        axes['xaxis%d' % number] = XAxis(domain=list(x_domains[col]),
                                         anchor='y%d' % number)
        axes['yaxis%d' % number] = YAxis(domain=list(y_domains[row]),
                                         anchor='x%d' % number)

    expected = Figure(data=Data(), layout=Layout(**axes))

    assert expected == tls.get_subplots(2, 3, .05, .1)
Example #4
0
def test_two_row():
    """Compare ``tls.get_subplots(2)`` to a fixed two-axis-pair layout.

    Both x axes span [0.0, 1.0]; the y axes take [0.0, 0.425] and
    [0.575, 1.0] respectively.
    """
    y_domains = [[0.0, 0.425], [0.575, 1.0]]

    axes = {}
    for number, y_domain in enumerate(y_domains, start=1):
        axes['xaxis%d' % number] = XAxis(domain=[0.0, 1.0],
                                         anchor='y%d' % number)
        axes['yaxis%d' % number] = YAxis(domain=y_domain,
                                         anchor='x%d' % number)

    expected = Figure(data=Data(), layout=Layout(**axes))
    assert tls.get_subplots(2) == expected
Example #5
0
def test_get_single_plot():
    """Compare the default ``tls.get_subplots()`` result to a 1x1 layout.

    The expected figure carries no traces and one x/y axis pair covering
    the whole [0.0, 1.0] domain, with the two axes anchored to each other.
    """
    single_cell = Layout(xaxis1=XAxis(domain=[0.0, 1.0], anchor='y1'),
                         yaxis1=YAxis(domain=[0.0, 1.0], anchor='x1'))
    expected = Figure(data=Data(), layout=single_cell)
    assert tls.get_subplots() == expected
Example #6
0
    def _plot_plotly(self, layout):
        """Combine the figures of all sub-plotters into one subplot grid.

        Every plotter in ``self.plots`` renders its own figure from a deep
        copy of ``layout``.  Its traces and axes are then re-keyed to one
        cell of a ``self.rows`` x ``self.columns`` grid, and the cells that
        ``self._get_splts`` reports as unused get blank axes.

        :param layout: layout passed (as a deep copy) to each plotter.
        :return: the combined figure, with ``autosize`` set to False.
        """
        from copy import deepcopy
        import plotly.tools as tls

        # Axis settings that hide every visible part of an unused cell.
        blank_axis = dict(title='',
                          showline=False,
                          showgrid=False,
                          zeroline=False,
                          showticklabels=False)

        # The gap between cells shrinks as the grid gets denser.
        fig = tls.get_subplots(rows=self.rows,
                               columns=self.columns,
                               horizontal_spacing=0.3 / self.columns,
                               vertical_spacing=0.3 / self.rows)
        used_cells, unused_cells = self._get_splts(self.rows, self.columns,
                                                   len(self.plots))

        for cell, plotter in zip(used_cells, self.plots):
            sub_fig = plotter._plot_plotly(deepcopy(layout))

            # Point every trace at the axes of this grid cell.
            for trace in sub_fig['data']:
                trace['xaxis'] = 'x%d' % cell
                trace['yaxis'] = 'y%d' % cell

            x_key = 'xaxis%d' % cell
            y_key = 'yaxis%d' % cell
            sub_layout = sub_fig['layout']

            # Re-register the plotter's axes under their numbered keys.
            sub_layout[x_key] = sub_layout['xaxis']
            sub_layout[y_key] = sub_layout['yaxis']

            # The plotter title is appended to the x-axis label rather than
            # stored as a separate layout entry.
            sub_layout[x_key].update(
                anchor='y%d' % cell,
                title=plotter.xlabel + '<br>%s' % plotter.title)
            sub_layout[y_key].update(title=plotter.ylabel)

            # Drop the un-numbered entries before merging into the grid.
            for stale_key in ('xaxis', 'yaxis', 'title'):
                sub_layout.pop(stale_key)

            fig['data'] += sub_fig['data']
            fig['layout'].update(sub_layout)

        for cell in unused_cells:
            fig['layout']['xaxis{}'.format(cell)].update(blank_axis)
            fig['layout']['yaxis{}'.format(cell)].update(blank_axis)

        fig['layout'].update(autosize=False)
        return fig
Example #7
0
    def _plot_plotly(self, layout):
        """Build one plotly figure holding the plots of all child plotters.

        A ``self.rows`` x ``self.columns`` subplot grid is created first.
        Each plotter in ``self.plots`` then draws itself from a deep copy of
        ``layout`` and its traces and axes are renamed to match the grid cell
        assigned by ``self._get_splts``.  Cells left without a plotter are
        given blank axes.

        :param layout: layout passed (deep-copied) to every child plotter.
        :return: the merged figure; its layout has ``autosize=False``.
        """
        from copy import deepcopy
        import plotly.tools as tls

        # Settings applied to the axes of grid cells that stay empty.
        hidden_axis = dict(title='', showline=False, showgrid=False,
                           zeroline=False, showticklabels=False)

        fig = tls.get_subplots(rows=self.rows,
                               columns=self.columns,
                               horizontal_spacing=0.3 / self.columns,
                               vertical_spacing=0.3 / self.rows)
        occupied, vacant = self._get_splts(self.rows, self.columns,
                                           len(self.plots))

        for slot, child in zip(occupied, self.plots):
            child_fig = child._plot_plotly(deepcopy(layout))

            # Attach each trace of the child to the axes of its slot.
            for series in child_fig['data']:
                series['xaxis'] = 'x%d' % slot
                series['yaxis'] = 'y%d' % slot

            child_layout = child_fig['layout']
            numbered_x = 'xaxis%d' % slot
            numbered_y = 'yaxis%d' % slot

            child_layout[numbered_x] = child_layout['xaxis']
            child_layout[numbered_y] = child_layout['yaxis']

            # The x label carries the child's title on a second line.
            x_title = child.xlabel + '<br>%s' % child.title
            child_layout[numbered_x].update(anchor='y%d' % slot,
                                            title=x_title)
            child_layout[numbered_y].update(title=child.ylabel)

            # Remove the generic keys before merging into the grid layout.
            child_layout.pop('xaxis')
            child_layout.pop('yaxis')
            child_layout.pop('title')

            fig['data'] += child_fig['data']
            fig['layout'].update(child_layout)

        for slot in vacant:
            for axis_name in ('xaxis{}'.format(slot), 'yaxis{}'.format(slot)):
                fig['layout'][axis_name].update(hidden_axis)

        fig['layout'].update(autosize=False)
        return fig
Example #8
0
def test_two_row():
    """Compare ``tls.get_subplots(2)`` to the expected two-pair layout.

    Each x axis covers [0.0, 1.0]; the first y axis covers [0.0, 0.425] and
    the second [0.575, 1.0].
    """
    two_rows = Layout(
        xaxis1=XAxis(domain=[0.0, 1.0], anchor='y1'),
        xaxis2=XAxis(domain=[0.0, 1.0], anchor='y2'),
        yaxis1=YAxis(domain=[0.0, 0.425], anchor='x1'),
        yaxis2=YAxis(domain=[0.575, 1.0], anchor='x2'))
    expected = Figure(data=Data(), layout=two_rows)
    assert tls.get_subplots(2) == expected
Example #9
0
def test_non_integer_rows():
    """Call ``tls.get_subplots`` with a fractional ``rows`` value.

    NOTE(review): presumably the call is expected to raise; whatever
    enforces that is not visible in this excerpt -- confirm.
    """
    fractional_rows = 2.1
    tls.get_subplots(rows=fractional_rows)
)


# Layout carrying only the figure title.
# NOTE(review): in the visible code this `layout` is referenced only by the
# commented-out single-figure variant below; the subplot figure sets the
# same title directly -- confirm whether it is used elsewhere.

layout = Layout(
    title='Raspberry Pi Streaming Temperature Boxes'
)

# Alternative kept for reference: plot a single figure instead of subplots.
#fig = Figure(data=[measured_temp, sample_est], layout=layout)


# Plot on a 2x2 subplot grid: append all twelve traces to the empty grid
# figure, then set the title and turn the legend on.
data = [target, measured_temp, heater_est, air_est, sample_est, heaterPWM, target2, measured_temp2, heater_est2, air_est2, sample_est2, heaterPWM2]
fig = tls.get_subplots(rows=2, columns=2)
fig['data'] += data
fig['layout'].update(title='Raspberry Pi Streaming Temperature Boxes')
fig['layout'].update(showlegend=True)

# Print whatever py.plot returns (Python 2 print statement).
print py.plot(fig, filename='Raspberry Pi Streaming Temperature Boxes')

# Create one stream per entry of `stream_tokens`, opening each as it is made.

stream1 = py.Stream(stream_tokens[0])
stream1.open()
stream2 = py.Stream(stream_tokens[1])
stream2.open()
stream3 = py.Stream(stream_tokens[2])
stream3.open()
stream4 = py.Stream(stream_tokens[3])
                      xanchor='center',
                      align='center',
                      font=Font(size=14),
                      showarrow=False,
                      xref='x{}'.format(sbplt_in),
                      yref='y{}'.format(sbplt_in))


# Generate figure object with subplot layout:

# In[9]:

# One grid row per dataset; one column per classifier plus one extra column.
# NOTE(review): print_grid=True presumably prints the grid layout -- confirm.
figure = tls.get_subplots(rows=len(datasets),
                          columns=len(classifiers) + 1,
                          horizontal_spacing=0.01,
                          vertical_spacing=0.05,
                          print_grid=True)

# Add a few style options:

# In[10]:

# Hide the legend, set the hover mode, and use a fixed width and height
# with autosizing switched off.
figure['layout'].update(showlegend=False,
                        hovermode='closest',
                        autosize=False,
                        width=1472,
                        height=490)

#
Example #12
0
def test_a_lot():
    """Compare a 4x7 grid with custom spacing to a fixed 28-axis-pair layout.

    Axes are numbered 1..28.  The x domain repeats every seven axes and the
    y domain changes every seven axes, so the expected layout is generated
    from one table of x domains and one table of y domains.
    """
    per_row = 7
    # x domain shared by every axis at the same position within a row.
    x_domains = [
        [0.0, 0.05714285714285713],
        [0.15714285714285714, 0.21428571428571427],
        [0.3142857142857143, 0.3714285714285714],
        [0.4714285714285714, 0.5285714285714286],
        [0.6285714285714286, 0.6857142857142857],
        [0.7857142857142857, 0.8428571428571429],
        [0.9428571428571428, 1.0],
    ]
    # y domain shared by every axis in the same row.
    y_domains = [
        [0.0, 0.1375],
        [0.2875, 0.425],
        [0.575, 0.7124999999999999],
        [0.8624999999999999, 1.0],
    ]

    axes = {}
    for cell in range(len(y_domains) * per_row):
        row, col = divmod(cell, per_row)
        number = cell + 1
        # Copy the domain so each axis owns its own list, as in a literal.
        axes['xaxis%d' % number] = XAxis(domain=list(x_domains[col]),
                                         anchor='y%d' % number)
        axes['yaxis%d' % number] = YAxis(domain=list(y_domains[row]),
                                         anchor='x%d' % number)

    expected = Figure(data=Data(), layout=Layout(**axes))

    fig = tls.get_subplots(4, 7, horizontal_spacing=0.1, vertical_spacing=0.15)

    assert fig == expected
Example #13
0
def test_wrong_kwarg():
    """Call ``tls.get_subplots`` with a keyword argument named ``stuff``.

    NOTE(review): presumably the call is expected to raise because the
    keyword is unknown; whatever enforces that is not visible in this
    excerpt -- confirm.
    """
    unknown_kwargs = {'stuff': 'no gonna work'}
    tls.get_subplots(**unknown_kwargs)
        xanchor='center',
        align='center',
        font= Font(size=14),
        showarrow=False,
        xref= 'x{}'.format(sbplt_in),
        yref= 'y{}'.format(sbplt_in))


# Generate figure object with subplot layout:

# In[9]:

# One grid row per dataset; one column per classifier plus one extra column.
# NOTE(review): print_grid=True presumably prints the grid layout -- confirm.
figure = tls.get_subplots(
    rows=len(datasets),
    columns=len(classifiers)+1,
    horizontal_spacing=0.01,
    vertical_spacing=0.05,
    print_grid=True)


# Add a few style options:

# In[10]:

# Hide the legend, set the hover mode, and use a fixed width and height
# with autosizing switched off.
figure['layout'].update(showlegend=False,
                        hovermode='closest',
                        autosize=False,
                        width=1472,
                        height=490)
Example #15
0
def test_default_spacing():
    """Compare a 6x5 grid with default spacing to a fixed 30-axis-pair layout.

    Axes are numbered 1..30.  The x domain repeats every five axes and the
    y domain changes every five axes, so the expected layout is generated
    from one table of x domains and one table of y domains.
    """
    per_row = 5
    # x domain shared by every axis at the same position within a row.
    x_domains = [
        [0.0, 0.16799999999999998],
        [0.208, 0.376],
        [0.416, 0.584],
        [0.624, 0.792],
        [0.832, 1.0],
    ]
    # y domain shared by every axis in the same row.
    y_domains = [
        [0.0, 0.125],
        [0.175, 0.3],
        [0.35, 0.475],
        [0.5249999999999999, 0.6499999999999999],
        [0.7, 0.825],
        [0.875, 1.0],
    ]

    axes = {}
    for cell in range(len(y_domains) * per_row):
        row, col = divmod(cell, per_row)
        number = cell + 1
        # Copy the domain so each axis owns its own list, as in a literal.
        axes['xaxis%d' % number] = XAxis(domain=list(x_domains[col]),
                                         anchor='y%d' % number)
        axes['yaxis%d' % number] = YAxis(domain=list(y_domains[row]),
                                         anchor='x%d' % number)

    expected = Figure(data=Data(), layout=Layout(**axes))

    fig = tls.get_subplots(rows=6, columns=5)

    assert fig == expected
Example #16
0
def test_non_integer_columns():
    """Call ``tls.get_subplots`` with a fractional ``columns`` value.

    The divisor is written as a float so the argument is two-thirds under
    both Python 2 and Python 3.  With ``2/3`` Python 2 performs integer
    division and passes ``0``, which is an integer and so does not exercise
    the non-integer case this test is named for.

    NOTE(review): presumably the call is expected to raise; whatever
    enforces that is not visible in this excerpt -- confirm.
    """
    tls.get_subplots(columns=2 / 3.0)
Example #17
0
def test_non_integer_rows():
    """Call ``tls.get_subplots`` with ``rows`` set to the float 2.1.

    NOTE(review): presumably the call is expected to raise; whatever
    enforces that is not visible in this excerpt -- confirm.
    """
    bad_kwargs = dict(rows=2.1)
    tls.get_subplots(**bad_kwargs)
Example #18
0
# Show the existing `data` traces inline.
py.iplot(data, filename='apple stock moving average')

# Plot the same traces again without opening a browser, keep what py.plot
# returns, and print it (Python 2 print statement).
first_plot_url = py.plot(data, filename='apple stock moving average', auto_open=False,)
print first_plot_url

# Collect one price series per ticker between date1 and date2.
tickers = ['AAPL', 'GE', 'IBM', 'KO', 'MSFT', 'PEP']
prices = []
for ticker in tickers:
    quotes = quotes_historical_yahoo(ticker, date1, date2)
    # Keep field 1 of every quote.
    # NOTE(review): presumably the opening price -- confirm against the
    # quotes_historical_yahoo return format.
    prices.append( [q[1] for q in quotes] )

# Rows of `prices` are tickers; transpose so each ticker becomes a column.
df = pd.DataFrame( prices ).transpose()
df.columns = tickers
df.head()

# 6x6 grid of subplots, matching the six tickers above.
fig = plotly_tools.get_subplots(rows=6, columns=6, print_grid=True, horizontal_spacing= 0.05, vertical_spacing= 0.05)

    """Kernel Density Estimation with Scipy"""
    # From https://jakevdp.github.io/blog/2013/12/01/kernel-density-estimation/
    # Note that scipy weights its bandwidth by the covariance of the
    # input data.  To make the results comparable to the other methods,
    # we divide the bandwidth by the sample standard deviation here.
    kde = gaussian_kde(x, bw_method=bandwidth / x.std(ddof=1), **kwargs)
    return kde.evaluate(x_grid)

subplots = range(1,37)
sp_index = 0
data = []
for i in range(1,7):
    x_ticker = df.columns[i-1]
    for j in range(1,7):
def _palette_to_rgb(palette):
    """Scale a palette of 0-1 float triples to integer channels in 0-255."""
    # int(p * 256) yields 256 for p == 1.0, which is not a valid channel
    # value, so the result is capped at 255.
    return [[min(255, int(p * 256)) for p in color] for color in palette]


def scatter_matrix(df, marker_color=None,
                   marker_text=None,
                   figsize=(1024, 1024),
                   title='',
                   outfile=None,
                   cmap='deep'):
    """
    Plot pairwise relationships between all columns in the dataframe
    <df>

    Each diagonal cell shows a histogram of one column overlaid with its
    kernel density estimate; every other cell is a scatter plot of one
    column against another.  The figure is displayed with ``ol.iplot`` and,
    if <outfile> is given, also written there with ``ol.plot``.

    Parameters
    ----------
    df: pandas dataframe
        Data to visualize.  The dataframe passed in is not modified.

    marker_color: can be a) a string indicating a column in <df> (that
            column supplies the colors and is left out of the matrix);
            b) an array-like of values, each corresponding to a row in <df>
            (strings are treated as categories and mapped onto <cmap>);
            c) a callable that takes the dataframe and returns the colors;
            or d) None, for which each variable gets its own color

    marker_text: can be a) a string indicating a column in <df>; b) an
            iterable of texts, each corresponding to a row in <df>; c) a
            callable that takes <df> and returns the texts; or d) None,
            for which every scatter trace gets the text
            "<column>
vs
<column>"

    figsize: tuple
        The (width, height) of the figure

    title: str
        The figure title

    outfile: filepath str
        If provided, output to an HTML file at provided location

    cmap: str
        The name of the colormap used (i.e. colorbrewer)

    Raises
    ------
    TypeError
        If <marker_color> is none of the supported kinds.

    Example
    -------
    from sklearn.datasets import make_classification
    N_FEATURES = 4
    X, y = make_classification(n_samples=100, n_clusters_per_class=1, n_classes=4, n_features=N_FEATURES)

    df = pd.DataFrame(X, columns=['feature_%d' % f for f in range(N_FEATURES)])
    df['class'] = y

    scatter_matrix(df, marker_color='class', title='Scatter Matrix')

    """
    # Resolve the texts.  `texts` stays None when no marker_text is given so
    # that the per-pair default text is used in the loop below.
    texts = None
    if callable(marker_text):
        texts = marker_text(df)
    elif isinstance(marker_text, str):
        texts = df[marker_text].values.tolist()
    elif hasattr(marker_text, '__iter__'):
        texts = marker_text

    # Decide once whether custom texts exist.  A plain truth test would be
    # ambiguous for numpy arrays and pandas Series.
    use_custom_text = texts is not None and (
        not hasattr(texts, '__len__') or len(texts) > 0)

    if isinstance(marker_color, str):
        # Take the colors from the named column, then continue with a copy
        # that lacks it, so the caller's dataframe is left untouched.
        color_column = marker_color
        marker_color = df[color_column].values.tolist()
        df = df.drop(color_column, axis=1)

    columns = df.columns
    n_columns = len(columns)

    alpha = .5

    colors = None
    marker_rgba = None
    if marker_color is None:
        # One palette entry per variable.  The palette has one extra entry
        # because the rows below index it from 1.
        palette = sns.color_palette(cmap, n_columns + 1)
        marker_rgba = [rgb + [alpha] for rgb in _palette_to_rgb(palette)]
    elif callable(marker_color):  # function provided
        colors = marker_color(df)
    elif hasattr(marker_color, '__iter__'):
        if isinstance(marker_color[0], str):  # array-like of strings (i.e. categories)
            unique_vals = np.unique(marker_color).tolist()
            palette = sns.color_palette(cmap, len(unique_vals))
            rgb_by_value = dict(zip(unique_vals, _palette_to_rgb(palette)))
            colors = [
                'rgba(' + ','.join(str(c) for c in rgb_by_value[value])
                + ',%1.2f)' % alpha
                for value in marker_color
            ]
        else:
            colors = marker_color
    else:
        raise TypeError(
            'marker_color must be None, a column name, a callable or an '
            'iterable, got %r' % (marker_color,))

    # setup subplots
    fig = get_subplots(rows=n_columns, columns=n_columns,
                       horizontal_spacing=0.05,
                       vertical_spacing=0.05)

    data = []
    axis_number = 0  # subplot axes are numbered 1..n_columns**2 in loop order
    for i, x_column in enumerate(columns, start=1):

        if marker_color is None:
            row_color = 'rgba(' + ','.join(str(v) for v in marker_rgba[i]) + ')'
            scatter_color = row_color
        else:
            row_color = 'gray'
            scatter_color = colors

        for j, y_column in enumerate(columns, start=1):
            axis_number += 1

            if i == j:  # plot histogram and kde along diagonal
                x = df[x_column]
                x_grid = np.linspace(x.min(), x.max(), 100)
                # `.values` replaces `as_matrix()`, which newer pandas
                # releases no longer provide.
                sub_plot = [
                    go.Histogram(x=x, histnorm='probability density',
                                 marker=go.Marker(color=row_color)),
                    go.Scatter(x=x_grid, y=kde(x.values, x_grid),
                               line=go.Line(width=2, color='black')),
                ]
            else:  # scatter plot
                sub_plot = [
                    go.Scatter(y=df[x_column], x=df[y_column],
                               mode='markers',
                               marker=go.Marker(size=6,
                                                color=scatter_color,
                                                colorscale=cmap)),  # colorscale gets ignored if rgba() provided
                ]

            # attach each trace to its subplot and set its text
            for pt in sub_plot:
                pt.update(xaxis='x{}'.format(axis_number),
                          yaxis='y{}'.format(axis_number),
                          name='{0}'.format(x_column))

                if i != j:
                    if use_custom_text:
                        pt.update(text=texts)
                    else:
                        pt.update(text='{0}<br>vs<br>{1}'.format(y_column, x_column))

            data += sub_plot

    # add x and y labels: x axes 1..n and every n-th y axis starting at 1
    for offset, col in enumerate(columns):
        fig['layout']['xaxis{}'.format(offset + 1)].update(title=col)
        fig['layout']['yaxis{}'.format(offset * n_columns + 1)].update(title=col)

    # Remove legend by updating 'layout' key
    fig['layout'].update(showlegend=False, height=figsize[1], width=figsize[0], title=title)
    fig['data'] = go.Data(data)
    ol.iplot(fig, show_link=False)

    # write figure to HTML file
    if outfile:
        print('Exporting copy of figure to %s...' % outfile)
        ol.plot(fig, auto_open=False, filename=outfile)
Example #20
0
def test_default_spacing():
    """Compare a 6x5 grid with default spacing to a fixed 30-axis-pair layout.

    Axes are numbered 1..30.  Axes whose numbers differ by a multiple of
    five share an x domain, and each consecutive run of five axes shares a
    y domain, so the expected layout is generated from two small tables
    instead of sixty literal axis definitions.
    """
    axes_per_row = 5
    # x domains, indexed by (axis number - 1) modulo five.
    x_domain_for_column = (
        (0.0, 0.16799999999999998),
        (0.208, 0.376),
        (0.416, 0.584),
        (0.624, 0.792),
        (0.832, 1.0),
    )
    # y domains, indexed by (axis number - 1) floor-divided by five.
    y_domain_for_row = (
        (0.0, 0.125),
        (0.175, 0.3),
        (0.35, 0.475),
        (0.5249999999999999, 0.6499999999999999),
        (0.7, 0.825),
        (0.875, 1.0),
    )

    layout_kwargs = {}
    total_axes = len(y_domain_for_row) * axes_per_row
    for number in range(1, total_axes + 1):
        row, col = divmod(number - 1, axes_per_row)
        # Domains are handed over as fresh lists, as in a literal layout.
        layout_kwargs['xaxis%d' % number] = XAxis(
            domain=list(x_domain_for_column[col]),
            anchor='y%d' % number)
        layout_kwargs['yaxis%d' % number] = YAxis(
            domain=list(y_domain_for_row[row]),
            anchor='x%d' % number)

    expected = Figure(data=Data(), layout=Layout(**layout_kwargs))

    fig = tls.get_subplots(rows=6, columns=5)

    assert fig == expected
Example #21
0
def test_a_lot():
    """Check the layout get_subplots builds for a 4-row by 7-column grid.

    The expected figure holds 28 x/y axis pairs.  Subplot ``n`` uses
    ``xaxis<n>`` anchored to ``y<n>`` and ``yaxis<n>`` anchored to
    ``x<n>``.  The x-domain repeats every 7 subplots, while the
    y-domain changes once every 7 subplots.
    """
    n_cols = 7

    # x-domains, indexed by (n - 1) % 7
    x_domains = [
        [0.0, 0.05714285714285713],
        [0.15714285714285714, 0.21428571428571427],
        [0.3142857142857143, 0.3714285714285714],
        [0.4714285714285714, 0.5285714285714286],
        [0.6285714285714286, 0.6857142857142857],
        [0.7857142857142857, 0.8428571428571429],
        [0.9428571428571428, 1.0],
    ]
    # y-domains, indexed by (n - 1) // 7
    y_domains = [
        [0.0, 0.1375],
        [0.2875, 0.425],
        [0.575, 0.7124999999999999],
        [0.8624999999999999, 1.0],
    ]

    axes = {}
    for n in range(1, len(x_domains) * len(y_domains) + 1):
        row, col = divmod(n - 1, n_cols)
        # list(...) gives every axis its own domain list, exactly as
        # separate literals would.
        axes['xaxis{}'.format(n)] = XAxis(domain=list(x_domains[col]),
                                          anchor='y{}'.format(n))
        axes['yaxis{}'.format(n)] = YAxis(domain=list(y_domains[row]),
                                          anchor='x{}'.format(n))

    expected = Figure(data=Data(), layout=Layout(**axes))

    actual = tls.get_subplots(4,
                              7,
                              horizontal_spacing=0.1,
                              vertical_spacing=0.15)

    assert actual == expected
def scatter_matrix(
        df,
        marker_color=None,
        marker_text=None,
        figsize=(1024, 1024),
        title='',
        outfile=None,
        cmap='deep',
):
    """
    Plot pairwise relationships between all columns in the dataframe
    <df>

    Each diagonal cell shows a histogram of one column together with
    the curve returned by ``kde``; each off-diagonal cell is a scatter
    plot of one column against another.  The figure is always rendered
    through ``ol.iplot`` and, if <outfile> is given, also exported.

    Parameters
    ----------
    df: pandas dataframe
        Data to visualize. The caller's dataframe is not modified.

    marker_color: can be a) a string indicating a column in <df> (that
            column provides the colors and is left out of the plotted
            variables); b) an array-like of values, each corresponding
            to a row in <df> (strings are treated as categories and
            mapped onto <cmap>); c) a callable taking <df> and
            returning such values; or d) None, for which each variable
            gets its own color

    marker_text: can be a) a string indicating a column in <df>; b) an
            iterable of hover texts, each corresponding to a row in
            <df>; c) a callable taking <df> and returning such texts;
            or d) None, for which every scatter point is labelled
            "<column><br>vs<br><row>"

    title: str
        The figure title

    figsize: tuple
        The (width, height) of the figure

    cmap: str
        The name of the colormap used (i.e. colorbrewer)

    outfile: filepath str
        If provided, output to an HTML file at provided location


    Example
    -------
    from sklearn.datasets import make_classification
    N_FEATURES = 4
    X, y = make_classification(n_samples=100, n_clusters_per_class=1, n_classes=4, n_features=N_FEATURES)

    df = pd.DataFrame(X, columns=['feature_%d' % f for f in range(N_FEATURES)])
    df['class'] = y

    scatter_matrix(df, marker_color='class', title='Scatter Matrix')

    """
    alpha = .5

    # Resolve the per-point hover text.  ``texts`` stays None when no
    # marker_text was given, so the default label is used further down.
    texts = None
    if callable(marker_text):
        texts = marker_text(df)
    elif isinstance(marker_text, str):
        texts = df[marker_text].values.tolist()
    elif hasattr(marker_text, '__iter__'):
        texts = marker_text
    try:
        has_texts = texts is not None and len(texts) > 0
    except TypeError:  # object without len(): fall back to truthiness
        has_texts = bool(texts)

    if isinstance(marker_color, str):
        color_column = marker_color
        marker_color = df[color_column].values.tolist()
        # drop() returns a new frame, leaving the caller's data intact
        df = df.drop(color_column, axis=1)

    columns = list(df.columns)
    n_columns = len(columns)

    colors = None
    marker_rgba = None
    if marker_color is None:
        # One palette entry per variable.  The palette is one entry
        # longer than needed because rows index it from 1.
        rgb = _palette_to_rgb(sns.color_palette(cmap, n_columns + 1))
        marker_rgba = [channels + [alpha] for channels in rgb]
    elif callable(marker_color):  # function provided
        colors = marker_color(df)
    elif hasattr(marker_color, '__iter__'):
        # Peek at the first element without positional indexing so that
        # empty inputs and non-default pandas indexes do not raise.
        first = next(iter(marker_color), None)
        if isinstance(first, str):
            # array-like of strings (i.e. categories)
            unique_vals = np.unique(marker_color).tolist()
            rgb = _palette_to_rgb(
                sns.color_palette(cmap, len(unique_vals)))
            rgb_by_value = dict(zip(unique_vals, rgb))
            colors = [
                'rgba(' + ','.join(str(c) for c in rgb_by_value[value]) +
                ',%1.2f)' % alpha
                for value in marker_color
            ]
        else:
            colors = marker_color
    else:
        # Any other value is handed to plotly unchanged.
        colors = marker_color

    # setup subplots
    fig = get_subplots(rows=n_columns,
                       columns=n_columns,
                       horizontal_spacing=0.05,
                       vertical_spacing=0.05)

    data = []
    subplot_number = 0  # subplots are numbered 1..n_columns**2, row-major
    for i, row_column in enumerate(columns, start=1):

        if marker_color is None:
            row_color = 'rgba(' + ','.join(
                str(v) for v in marker_rgba[i]) + ')'
            scatter_color = row_color
        else:
            row_color = 'gray'
            scatter_color = colors

        for j, col_column in enumerate(columns, start=1):
            subplot_number += 1

            if i == j:  # plot histogram and kde along diagonal
                x = df[row_column]
                x_grid = np.linspace(x.min(), x.max(), 100)
                sub_plot = [
                    go.Histogram(x=x,
                                 histnorm='probability density',
                                 marker=go.Marker(color=row_color)),
                    # .values replaces Series.as_matrix(), which newer
                    # pandas releases no longer provide.
                    go.Scatter(x=x_grid,
                               y=kde(x.values, x_grid),
                               line=go.Line(width=2, color='black')),
                ]
            else:  # scatter plot
                sub_plot = [
                    go.Scatter(y=df[row_column],
                               x=df[col_column],
                               mode='markers',
                               marker=go.Marker(size=6,
                                                color=scatter_color,
                                                colorscale=cmap))
                ]  # colorscale gets ignored if rgba() provided

            # bind every trace to its subplot and set the hover text
            for pt in sub_plot:
                pt.update(xaxis='x{}'.format(subplot_number),
                          yaxis='y{}'.format(subplot_number),
                          name='{0}'.format(row_column))

                if i != j:
                    if has_texts:
                        pt.update(text=texts)
                    else:
                        pt.update(text='{0}<br>vs<br>{1}'.format(
                            col_column, row_column))

            data += sub_plot

    # add x and y labels: x titles go on the first n_columns subplots,
    # y titles on the first subplot of every row
    for offset, col in enumerate(columns):
        fig['layout']['xaxis{}'.format(offset + 1)].update(title=col)
        fig['layout']['yaxis{}'.format(offset * n_columns + 1)].update(
            title=col)

    # Remove legend by updating 'layout' key
    fig['layout'].update(showlegend=False,
                         height=figsize[1],
                         width=figsize[0],
                         title=title)
    fig['data'] = go.Data(data)
    ol.iplot(fig, show_link=False)

    # write figure to HTML file
    if outfile:
        print('Exporting copy of figure to %s...' % outfile)
        ol.plot(fig, auto_open=False, filename=outfile)


def _palette_to_rgb(palette):
    """Convert a palette of 0-1 float channels to integer channel lists.

    Each channel is scaled by 256 and capped at 255 so that a channel
    of exactly 1.0 does not produce the out-of-range value 256.
    """
    return [[min(255, int(p * 256)) for p in color] for color in palette]
Example #23
0
def test_non_integer_columns():
    """Call get_subplots with the result of ``2 / 3`` as column count.

    NOTE(review): presumably this call is expected to raise; no
    expected-exception wrapper is visible here -- confirm.
    """
    fractional_columns = 2 / 3
    tls.get_subplots(columns=fractional_columns)
Example #24
0
import plotly.plotly as py
from plotly.graph_objs import *

py.sign_in('>>>username<<<', '>>>api_key<<<')

import plotly.tools as tls

# Four identical bar traces, one per cell of the 2x2 grid.  Trace k is
# bound to the axis pair xk / yk.
trace1, trace2, trace3, trace4 = (
    Bar(y=[1, 2, 3], xaxis='x{}'.format(k), yaxis='y{}'.format(k))
    for k in range(1, 5)
)
data = Data([trace1, trace2, trace3, trace4])

# Start from the empty 2x2 subplot figure, then add the traces and a title.
fig = tls.get_subplots(rows=2, columns=2)
fig['data'] += data
fig['layout'].update(title='i <3 subplots')

plot_url = py.plot(fig, filename='>>>filename<<<')
Example #25
0
def test_wrong_kwarg():
    """Call get_subplots with a keyword argument named ``stuff``.

    NOTE(review): presumably this call is expected to raise; no
    expected-exception wrapper is visible here -- confirm.
    """
    bogus_kwargs = {'stuff': 'no gonna work'}
    tls.get_subplots(**bogus_kwargs)