예제 #1
0
    def test_default_dendrogram(self):
        """Compare the default dendrogram of a 4x4 matrix with fixed values.

        The figure produced by ``tls.FigureFactory.create_dendrogram`` must
        contain exactly three line traces, and each trace as well as the
        layout must equal the hand-written expectation below.
        """
        observations = np.array(
            [[1, 2, 3, 4], [1, 1, 3, 4], [1, 2, 1, 4], [1, 2, 3, 1]])
        dendro = tls.FigureFactory.create_dendrogram(X=observations)

        def branch(xs, ys, color):
            # The expected traces differ only in coordinates and color.
            return go.Scatter(x=np.array(xs),
                              y=np.array(ys),
                              marker=go.Marker(color=color),
                              mode='lines',
                              xaxis='x',
                              yaxis='y')

        green = 'rgb(61,153,112)'
        blue = 'rgb(0,116,217)'
        expected_traces = [
            branch([25., 25., 35., 35.],
                   [0., 1., 1., 0.],
                   green),
            branch([15., 15., 30., 30.],
                   [0., 2.23606798, 2.23606798, 1.],
                   green),
            branch([5., 5., 22.5, 22.5],
                   [0., 3.60555128, 3.60555128, 2.23606798],
                   blue),
        ]

        # Settings that the expected x and y axes have in common.
        shared_axis = dict(mirror='allticks',
                           rangemode='tozero',
                           showgrid=False,
                           showline=True,
                           showticklabels=True,
                           ticks='outside',
                           type='linear',
                           zeroline=False)
        expected_xaxis = go.XAxis(tickmode='array',
                                  ticktext=np.array(['3', '2', '0', '1']),
                                  tickvals=[5.0, 15.0, 25.0, 35.0],
                                  **shared_axis)
        expected_yaxis = go.YAxis(**shared_axis)

        expected_dendro = go.Figure(
            data=go.Data(expected_traces),
            layout=go.Layout(autosize=False,
                             height='100%',
                             hovermode='closest',
                             showlegend=False,
                             width='100%',
                             xaxis=expected_xaxis,
                             yaxis=expected_yaxis))

        self.assertEqual(len(dendro['data']), 3)

        # Trace-by-trace comparison is easier to debug than comparing the
        # whole data list at once.
        for index in range(3):
            self.assert_dict_equal(dendro['data'][index],
                                   expected_dendro['data'][index])

        self.assert_dict_equal(dendro['layout'], expected_dendro['layout'])
예제 #2
0
def create_ohlc(open, high, low, close, dates=None, direction="both", **kwargs):
    """
    **deprecated**, use instead the plotly.graph_objects trace
    :class:`plotly.graph_objects.Ohlc`

    Builds an OHLC figure from four equally long value sequences.

    :param (list) open: opening values
    :param (list) high: high values
    :param (list) low: low values
    :param (list) close: closing values
    :param (list) dates: list of datetime objects. Default: None
    :param (string) direction: one of 'increasing', 'decreasing' or 'both'.
        'increasing' keeps only the units whose close value is greater
        than the corresponding open value; 'decreasing' keeps only the
        units whose close value is less than or equal to the
        corresponding open value; 'both' returns both groups.
        Default: 'both'
    :param kwargs: forwarded to ``validate_ohlc`` and to the
        ``make_increasing_ohlc`` / ``make_decreasing_ohlc`` helpers.
        They describe other attributes of the ohlc Scatter trace such as
        the color or the legend name. For more information on valid
        kwargs call help(plotly.graph_objs.Scatter)

    :rtype: the ohlc chart as a ``graph_objs.Figure``.

    Example 1: Simple OHLC chart from a Pandas DataFrame

    >>> from plotly.figure_factory import create_ohlc
    >>> from datetime import datetime

    >>> import pandas as pd
    >>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
    >>> fig = create_ohlc(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'], dates=df.index)
    >>> fig.show()
    """
    # Dates take part in the length check only when they were supplied.
    series = [open, high, low, close]
    if dates is not None:
        series.append(dates)
    utils.validate_equal_length(*series)
    validate_ohlc(open, high, low, close, direction, **kwargs)

    # Any direction other than the two explicit ones yields both traces,
    # increasing first.
    traces = []
    if direction != "decreasing":
        traces.append(
            make_increasing_ohlc(open, high, low, close, dates, **kwargs))
    if direction != "increasing":
        traces.append(
            make_decreasing_ohlc(open, high, low, close, dates, **kwargs))

    return graph_objs.Figure(
        data=traces,
        layout=graph_objs.Layout(xaxis=dict(zeroline=False),
                                 hovermode="closest"),
    )
예제 #3
0
def _iplot3d_surface(array, colorscale):
    """
    Draw ``array`` as a 600x600 3D surface using plotly's offline notebook
    mode.

    :param array: z values handed to ``go.Surface``.
    :param colorscale: colorscale handed to ``go.Surface``.
    """
    surface = go.Surface(z=array, colorscale=colorscale)
    figure_layout = go.Layout(
        scene={"aspectmode": "data"},
        margin={"l": 65, "r": 50, "b": 65, "t": 90},
        height=600,
        width=600,
        autosize=False,
    )

    # Notebook mode is initialised before the figure is built and shown.
    offline.init_notebook_mode(connected=True)
    offline.iplot(go.Figure(data=[surface], layout=figure_layout))
예제 #4
0
def _iplot3d(las, max_points, point_size, dim, colorscale):
    """
    Plots the 3d point cloud in a compatible version for Jupyter notebooks.

    Clouds with more than ``max_points`` points are randomly down sampled
    (with replacement) to ``max_points`` points before plotting; smaller
    clouds are plotted in full.

    :param las: point cloud object. This function reads
        ``las.header.count``, ``las.count`` and ``las.points`` (the latter
        is indexed by column name and via ``.x``/``.y``/``.z``).
    :param max_points: largest number of points drawn without down
        sampling, and the sample size used when down sampling.
    :param point_size: marker size of the plotted points.
    :param dim: name of the ``las.points`` column used to color the points.
    :param colorscale: plotly colorscale applied to the marker colors.
    :return: ``False`` when not running inside a Jupyter kernel, otherwise
        ``None`` after the figure has been displayed.
    """
    # Only the IPython lookup is guarded: ``get_ipython`` is undefined
    # outside IPython, and a NameError raised by the plotting code below
    # should not be silently swallowed.
    try:
        cfg = get_ipython().config
    except NameError:
        return False

    if 'jupyter' not in cfg['IPKernelApp']['connection_file']:
        print("This function can only be used within a Jupyter notebook.")
        return False

    points = las.points
    x, y, z = points.x, points.y, points.z
    color_var = points[dim].values

    if las.header.count > max_points:
        print("Point cloud too large, down sampling for plot performance.")
        # Honour the caller's threshold instead of a fixed sample size.
        rand = np.random.randint(0, las.count, int(max_points))
        x, y, z = x.iloc[rand], y.iloc[rand], z.iloc[rand]
        color_var = color_var[rand]

    trace1 = go.Scatter3d(x=x,
                          y=y,
                          z=z,
                          mode='markers',
                          marker=dict(size=point_size,
                                      color=color_var,
                                      colorscale=colorscale,
                                      opacity=1))

    layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0),
                       scene=dict(aspectmode="data"))
    offline.init_notebook_mode(connected=True)
    fig = go.Figure(data=[trace1], layout=layout)
    offline.iplot(fig)
예제 #5
0
def create_candlestick(open,
                       high,
                       low,
                       close,
                       dates=None,
                       direction="both",
                       **kwargs):
    """
    BETA function that creates a candlestick chart

    :param (list) open: opening values
    :param (list) high: high values
    :param (list) low: low values
    :param (list) close: closing values
    :param (list) dates: list of datetime objects. Default: None
    :param (string) direction: direction can be 'increasing', 'decreasing',
        or 'both'. When the direction is 'increasing', the returned figure
        consists of all candlesticks where the close value is greater than
        the corresponding open value, and when the direction is
        'decreasing', the returned figure consists of all candlesticks
        where the close value is less than or equal to the corresponding
        open value. When the direction is 'both', both increasing and
        decreasing candlesticks are returned. Default: 'both'
    :param kwargs: forwarded to ``validate_ohlc`` and to the
        ``make_increasing_candle`` / ``make_decreasing_candle`` helpers.
        These kwargs describe other attributes about the trace such as
        the color or the legend name. For more information on valid
        kwargs call help(plotly.graph_objs.Scatter)

    :rtype: returns the candlestick chart as a ``graph_objs.Figure``.

    Example 1: Simple candlestick chart from a Pandas DataFrame

    >>> from plotly.figure_factory import create_candlestick
    >>> from datetime import datetime

    >>> # pandas.io.data was removed from pandas; its replacement is the
    >>> # separate pandas-datareader package.
    >>> import pandas_datareader.data as web

    >>> df = web.DataReader("aapl", 'yahoo', datetime(2007, 10, 1), datetime(2009, 4, 1))
    >>> fig = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index)
    >>> fig.show()

    Example 2: Add text and annotations to the candlestick chart

    >>> fig = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index)
    >>> # Update the fig - all options here: https://plot.ly/python/reference/#Layout
    >>> fig['layout'].update({
    ...     'title': 'The Great Recession',
    ...     'yaxis': {'title': 'AAPL Stock'},
    ...     'shapes': [{
    ...         'x0': '2007-12-01', 'x1': '2007-12-01',
    ...         'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper',
    ...         'line': {'color': 'rgb(30,30,30)', 'width': 1}
    ...     }],
    ...     'annotations': [{
    ...         'x': '2007-12-01', 'y': 0.05, 'xref': 'x', 'yref': 'paper',
    ...         'showarrow': False, 'xanchor': 'left',
    ...         'text': 'Official start of the recession'
    ...     }]
    ... })
    >>> fig.show()

    Example 3: Customize the candlestick colors

    >>> from plotly.figure_factory import create_candlestick
    >>> from plotly.graph_objs import Line, Marker
    >>> from datetime import datetime

    >>> import pandas_datareader.data as web

    >>> df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1), datetime(2009, 4, 1))

    >>> # Make increasing candlesticks and customize their color and name
    >>> fig_increasing = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index,
    ...     direction='increasing', name='AAPL',
    ...     marker=Marker(color='rgb(150, 200, 250)'),
    ...     line=Line(color='rgb(150, 200, 250)'))

    >>> # Make decreasing candlesticks and customize their color and name
    >>> fig_decreasing = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index,
    ...     direction='decreasing',
    ...     marker=Marker(color='rgb(128, 128, 128)'),
    ...     line=Line(color='rgb(128, 128, 128)'))

    >>> # Initialize the figure
    >>> fig = fig_increasing

    >>> # Add decreasing data with .extend()
    >>> fig['data'].extend(fig_decreasing['data'])
    >>> fig.show()

    Example 4: Candlestick chart with datetime objects

    >>> from plotly.figure_factory import create_candlestick

    >>> from datetime import datetime

    >>> # Add data
    >>> open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
    >>> high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
    >>> low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
    >>> close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
    >>> dates = [datetime(year=2013, month=10, day=10),
    ...          datetime(year=2013, month=11, day=10),
    ...          datetime(year=2013, month=12, day=10),
    ...          datetime(year=2014, month=1, day=10),
    ...          datetime(year=2014, month=2, day=10)]

    >>> # Create candlestick
    >>> fig = create_candlestick(open_data, high_data,
    ...     low_data, close_data, dates=dates)
    >>> fig.show()

    """
    if dates is not None:
        utils.validate_equal_length(open, high, low, close, dates)
    else:
        utils.validate_equal_length(open, high, low, close)
    validate_ohlc(open, high, low, close, direction, **kwargs)

    # Compare by value: ``is`` tests object identity, which is not
    # guaranteed for equal strings (e.g. ones built at runtime) and would
    # silently fall through to the 'both' branch.
    if direction == "increasing":
        candle_incr_data = make_increasing_candle(open, high, low, close,
                                                  dates, **kwargs)
        data = candle_incr_data
    elif direction == "decreasing":
        candle_decr_data = make_decreasing_candle(open, high, low, close,
                                                  dates, **kwargs)
        data = candle_decr_data
    else:
        candle_incr_data = make_increasing_candle(open, high, low, close,
                                                  dates, **kwargs)
        candle_decr_data = make_decreasing_candle(open, high, low, close,
                                                  dates, **kwargs)
        data = candle_incr_data + candle_decr_data

    layout = graph_objs.Layout()
    return graph_objs.Figure(data=data, layout=layout)
def create_ternary_contour(coordinates, values, pole_labels=None,
                           tooltip_mode='proportions', width=500, height=500,
                           ncontours=None,
                           showscale=False,
                           coloring=None,
                           colorscale=None,
                           plot_bgcolor='rgb(240,240,240)',
                           title=None):
    """
    Ternary contour plot.

    Parameters
    ----------

    coordinates : list or ndarray
        Barycentric coordinates of shape (2, N) or (3, N) where N is the
        number of data points. The sum of the 3 coordinates is expected
        to be 1 for all data points.
    values : array-like
        Data points of field to be represented as contours.
    pole_labels : list of str or None, default None
        Names of the three poles of the triangle. ``None`` stands for
        ``['a', 'b', 'c']``.
    tooltip_mode : str, 'proportions' or 'percents'
        Coordinates inside the ternary plot can be displayed either as
        proportions (adding up to 1) or as percents (adding up to 100).
    width : int
        Figure width.
    height : int
        Figure height.
    ncontours : int or None
        Number of contours to display (determined automatically if None).
    showscale : bool, default False
        If True, a colorbar showing the color scale is displayed.
    coloring : None or 'lines'
        How to display contour. Filled contours if None, lines if ``lines``.
    colorscale : None or array-like
        colorscale of the contours. If None, the result of ``_pl_deep()``
        is used.
    plot_bgcolor :
        color of figure background
    title : str or None
        Title of ternary plot

    Returns
    -------
    fig : go.Figure
        Figure holding the contour, tick and side traces, in that order.

    Raises
    ------
    ImportError
        If the module-level name ``interpolate`` is None.

    Examples
    ========

    Example 1: ternary contour plot with filled contours

    # Define coordinates
    a, b = np.mgrid[0:1:20j, 0:1:20j]
    mask = a + b <= 1
    a = a[mask].ravel()
    b = b[mask].ravel()
    c = 1 - a - b
    # Values to be displayed as contours
    z = a * b * c
    fig = ff.create_ternary_contour(np.stack((a, b, c)), z)

    It is also possible to give only two barycentric coordinates for each
    point, since the sum of the three coordinates is one:

    fig = ff.create_ternary_contour(np.stack((a, b)), z)

    Example 2: ternary contour plot with line contours

    fig = ff.create_ternary_contour(np.stack((a, b)), z, coloring='lines')
    """
    if interpolate is None:
        raise ImportError("""\
The create_ternary_contour figure factory requires the scipy package""")

    # A fresh list per call: a list literal in the signature would be one
    # object shared by every call and is handed to a helper below.
    if pole_labels is None:
        pole_labels = ['a', 'b', 'c']

    grid_z, gr_x, gr_y, tooltip = _compute_grid(coordinates, values,
                                                tooltip_mode)

    x_ticks, y_ticks, posx, posy = _cart_coord_ticks(t=0.01)

    layout = _ternary_layout(pole_labels=pole_labels,
                             width=width, height=height, title=title,
                             plot_bgcolor=plot_bgcolor)

    # Tick labels are always requested with proportions=True here,
    # whatever the tooltip_mode.
    annotations = _set_ticklabels(layout['annotations'], posx, posy,
                                  proportions=True)
    if colorscale is None:
        colorscale = _pl_deep()

    contour_trace = _contour_trace(gr_x, gr_y, grid_z, tooltip,
                                   ncontours=ncontours,
                                   showscale=showscale,
                                   colorscale=colorscale,
                                   coloring=coloring)
    side_trace, tick_trace = _styling_traces_ternary(x_ticks, y_ticks)
    fig = go.Figure(data=[contour_trace, tick_trace, side_trace],
                    layout=layout)
    fig.layout.annotations = annotations
    return fig
예제 #7
0
def create_2d_density(x, y, colorscale='Earth', ncontours=20,
                      hist_color=(0, 0, 0.5), point_color=(0, 0, 0.5),
                      point_size=2, title='2D Density Plot',
                      height=600, width=600):
    """
    Returns figure for a 2D density plot

    The figure combines a scatter of the points, a 2D histogram contour
    and one marginal histogram per axis.

    :param (list|array) x: x-axis data for plot generation
    :param (list|array) y: y-axis data for plot generation
    :param (str|tuple|list) colorscale: either a plotly scale name, an rgb
        or hex color, a color tuple or a list or tuple of colors. An rgb
        color is of the form 'rgb(x, y, z)' where x, y, z belong to the
        interval [0, 255] and a color tuple is a tuple of the form
        (a, b, c) where a, b and c belong to [0, 1]. If colorscale is a
        list, it must contain the valid color types aforementioned as its
        members.
    :param (int) ncontours: the number of 2D contours to draw on the plot
    :param (str|tuple) hist_color: the color of the plotted histograms
    :param (str|tuple) point_color: the color of the scatter points
    :param (int|float) point_size: the size of the scatter points
    :param (str) title: set the title for the plot
    :param (float) height: the height of the chart
    :param (float) width: the width of the chart

    :raises (PlotlyError): if 'x' or 'y' holds a non-numeric element, or
        if 'x' and 'y' differ in length.

    Example 1: Simple 2D Density Plot
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_2d_density

    import numpy as np

    # Make data points
    t = np.linspace(-1,1.2,2000)
    x = (t**3)+(0.3*np.random.randn(2000))
    y = (t**6)+(0.3*np.random.randn(2000))

    # Create a figure
    fig = create_2d_density(x, y)

    # Plot the data
    py.iplot(fig, filename='simple-2d-density')
    ```

    Example 2: Using Parameters
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_2d_density

    import numpy as np

    # Make data points
    t = np.linspace(-1,1.2,2000)
    x = (t**3)+(0.3*np.random.randn(2000))
    y = (t**6)+(0.3*np.random.randn(2000))

    # Create custom colorscale
    colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)',
                  (1, 1, 0.2), (0.98,0.98,0.98)]

    # Create a figure
    fig = create_2d_density(
        x, y, colorscale=colorscale,
        hist_color='rgb(255, 237, 222)', point_size=3)

    # Plot the data
    py.iplot(fig, filename='use-parameters')
    ```
    """

    def _plain_axis(start, stop):
        # Axis restricted to [start, stop] of the plot area, without grid
        # or zero line.
        return {'domain': [start, stop], 'showgrid': False, 'zeroline': False}

    # Reject non-numeric data first (x is inspected before y) ...
    for values in (x, y):
        if any(not isinstance(item, Number) for item in values):
            raise exceptions.PlotlyError(
                "All elements of your 'x' and 'y' lists must be numbers."
            )

    # ... then mismatched lengths.
    if len(x) != len(y):
        raise exceptions.PlotlyError(
            "Both lists 'x' and 'y' must be the same length."
        )

    colorscale = make_linear_colorscale(
        utils.validate_colors(colorscale, 'rgb'))

    # Only the first entry of each validated color result is used below.
    hist_colors = utils.validate_colors(hist_color, 'rgb')
    point_colors = utils.validate_colors(point_color, 'rgb')

    points = graph_objs.Scatter(
        x=x, y=y, mode='markers', name='points',
        marker={'color': point_colors[0], 'size': point_size, 'opacity': 0.4},
    )
    density = graph_objs.Histogram2dContour(
        x=x, y=y, name='density', ncontours=ncontours,
        colorscale=colorscale, reversescale=True, showscale=False,
    )
    x_marginal = graph_objs.Histogram(
        x=x, name='x density',
        marker={'color': hist_colors[0]}, yaxis='y2',
    )
    y_marginal = graph_objs.Histogram(
        y=y, name='y density',
        marker={'color': hist_colors[0]}, xaxis='x2',
    )

    layout = graph_objs.Layout(
        showlegend=False,
        autosize=False,
        title=title,
        height=height,
        width=width,
        xaxis=_plain_axis(0, 0.85),
        yaxis=_plain_axis(0, 0.85),
        margin={'t': 50},
        hovermode='closest',
        bargap=0,
        xaxis2=_plain_axis(0.85, 1),
        yaxis2=_plain_axis(0.85, 1),
    )

    return graph_objs.Figure(
        data=[points, density, x_marginal, y_marginal], layout=layout)
예제 #8
0
def create_annotated_heatmap(z,
                             x=None,
                             y=None,
                             annotation_text=None,
                             colorscale="Plasma",
                             font_colors=None,
                             showscale=False,
                             reversescale=False,
                             **kwargs):
    """
    BETA function that creates annotated heatmaps

    This function adds annotations to each cell of the heatmap.

    :param (list[list]|ndarray) z: z matrix to create heatmap.
    :param (list|ndarray) x: x axis labels.
    :param (list|ndarray) y: y axis labels.
    :param (list[list]|ndarray) annotation_text: Text strings for
        annotations. Should have the same dimensions as the z matrix. If no
        text is added, the values of the z matrix are annotated. Default =
        z matrix values.
    :param (list|str) colorscale: heatmap colorscale.
    :param (list) font_colors: List of two color strings: [min_text_color,
        max_text_color] where min_text_color is applied to annotations for
        heatmap values < (max_value - min_value)/2. If font_colors is not
        defined, the colors are defined logically as black or white
        depending on the heatmap's colorscale.
    :param (bool) showscale: Display colorscale. Default = False
    :param (bool) reversescale: Reverse colorscale. Default = False
    :param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
        These kwargs describe other attributes about the annotated Heatmap
        trace such as the colorscale. For more information on valid kwargs
        call help(plotly.graph_objs.Heatmap)

    :rtype: returns the annotated heatmap as a ``graph_objs.Figure``.

    Example 1: Simple annotated heatmap with default configuration
    ```
    import plotly.plotly as py
    import plotly.figure_factory as FF

    z = [[0.300000, 0.00000, 0.65, 0.300000],
         [1, 0.100005, 0.45, 0.4300],
         [0.300000, 0.00000, 0.65, 0.300000],
         [1, 0.100005, 0.45, 0.00000]]

    figure = FF.create_annotated_heatmap(z)
    py.iplot(figure)
    ```
    """

    # Avoiding mutables in the call signature
    font_colors = font_colors if font_colors is not None else []
    validate_annotated_heatmap(z, x, y, annotation_text)

    # validate colorscale
    colorscale_validator = ColorscaleValidator()
    colorscale = colorscale_validator.validate_coerce(colorscale)

    annotations = _AnnotatedHeatmap(z, x, y, annotation_text, colorscale,
                                    font_colors, reversescale,
                                    **kwargs).make_annotations()

    # Labels count as "given" when they are a non-empty sequence. The
    # explicit None/len test matches plain truthiness for lists but also
    # works for numpy arrays, whose truth value is ambiguous and raises
    # ValueError.
    has_labels = (x is not None and len(x) > 0) or \
        (y is not None and len(y) > 0)

    if has_labels:
        trace = dict(type="heatmap",
                     z=z,
                     x=x,
                     y=y,
                     colorscale=colorscale,
                     showscale=showscale,
                     reversescale=reversescale,
                     **kwargs)
        layout = dict(
            annotations=annotations,
            xaxis=dict(ticks="", dtick=1, side="top",
                       gridcolor="rgb(0, 0, 0)"),
            yaxis=dict(ticks="", dtick=1, ticksuffix="  "),
        )
    else:
        # Without labels the tick labels are hidden on both axes.
        trace = dict(type="heatmap",
                     z=z,
                     colorscale=colorscale,
                     showscale=showscale,
                     reversescale=reversescale,
                     **kwargs)
        layout = dict(
            annotations=annotations,
            xaxis=dict(ticks="",
                       side="top",
                       gridcolor="rgb(0, 0, 0)",
                       showticklabels=False),
            yaxis=dict(ticks="", ticksuffix="  ", showticklabels=False),
        )

    data = [trace]

    return graph_objs.Figure(data=data, layout=layout)
예제 #9
0
def create_candlestick(open,
                       high,
                       low,
                       close,
                       dates=None,
                       direction="both",
                       **kwargs):
    """
    **deprecated**, use instead the plotly.graph_objects trace
    :class:`plotly.graph_objects.Candlestick`

    Builds a candlestick figure from four equally long value sequences.

    :param (list) open: opening values
    :param (list) high: high values
    :param (list) low: low values
    :param (list) close: closing values
    :param (list) dates: list of datetime objects. Default: None
    :param (string) direction: one of 'increasing', 'decreasing' or 'both'.
        'increasing' keeps only the candlesticks whose close value is
        greater than the corresponding open value; 'decreasing' keeps only
        the candlesticks whose close value is less than or equal to the
        corresponding open value; 'both' returns both groups.
        Default: 'both'
    :param kwargs: forwarded to ``validate_ohlc`` and to the
        ``make_increasing_candle`` / ``make_decreasing_candle`` helpers.
        They describe other attributes of the trace such as the color or
        the legend name. For more information on valid kwargs call
        help(plotly.graph_objs.Scatter)

    :rtype: the candlestick chart as a ``graph_objs.Figure``.

    Example 1: Simple candlestick chart from a Pandas DataFrame

    >>> from plotly.figure_factory import create_candlestick
    >>> from datetime import datetime
    >>> import pandas as pd

    >>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
    >>> fig = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
    ...                          dates=df.index)
    >>> fig.show()

    Example 2: Customize the candlestick colors

    >>> from plotly.figure_factory import create_candlestick
    >>> from plotly.graph_objs import Line, Marker
    >>> from datetime import datetime

    >>> import pandas as pd
    >>> df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')

    >>> # Make increasing candlesticks and customize their color and name
    >>> fig_increasing = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
    ...     dates=df.index,
    ...     direction='increasing', name='AAPL',
    ...     marker=Marker(color='rgb(150, 200, 250)'),
    ...     line=Line(color='rgb(150, 200, 250)'))

    >>> # Make decreasing candlesticks and customize their color and name
    >>> fig_decreasing = create_candlestick(df['AAPL.Open'], df['AAPL.High'], df['AAPL.Low'], df['AAPL.Close'],
    ...     dates=df.index,
    ...     direction='decreasing',
    ...     marker=Marker(color='rgb(128, 128, 128)'),
    ...     line=Line(color='rgb(128, 128, 128)'))

    >>> # Initialize the figure
    >>> fig = fig_increasing

    >>> # Add decreasing data with .extend()
    >>> fig.add_trace(fig_decreasing['data']) # doctest: +SKIP
    >>> fig.show()

    Example 3: Candlestick chart with datetime objects

    >>> from plotly.figure_factory import create_candlestick

    >>> from datetime import datetime

    >>> # Add data
    >>> open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
    >>> high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
    >>> low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
    >>> close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
    >>> dates = [datetime(year=2013, month=10, day=10),
    ...          datetime(year=2013, month=11, day=10),
    ...          datetime(year=2013, month=12, day=10),
    ...          datetime(year=2014, month=1, day=10),
    ...          datetime(year=2014, month=2, day=10)]

    >>> # Create ohlc
    >>> fig = create_candlestick(open_data, high_data,
    ...     low_data, close_data, dates=dates)
    >>> fig.show()
    """
    # Dates take part in the length check only when they were supplied.
    series = [open, high, low, close]
    if dates is not None:
        series.append(dates)
    utils.validate_equal_length(*series)
    validate_ohlc(open, high, low, close, direction, **kwargs)

    ohlc = (open, high, low, close, dates)
    if direction == "increasing":
        traces = make_increasing_candle(*ohlc, **kwargs)
    elif direction == "decreasing":
        traces = make_decreasing_candle(*ohlc, **kwargs)
    else:
        # Any other direction yields both groups, increasing first.
        traces = (make_increasing_candle(*ohlc, **kwargs)
                  + make_decreasing_candle(*ohlc, **kwargs))

    return graph_objs.Figure(data=traces, layout=graph_objs.Layout())
예제 #10
0
def create_streamline(x,
                      y,
                      u,
                      v,
                      density=1,
                      angle=math.pi / 9,
                      arrow_scale=.09,
                      **kwargs):
    """
    Build a streamline figure for the vector field (u, v) sampled on (x, y).

    :param (list|ndarray) x: 1 dimensional, evenly spaced list or array
    :param (list|ndarray) y: 1 dimensional, evenly spaced list or array
    :param (ndarray) u: 2 dimensional array
    :param (ndarray) v: 2 dimensional array
    :param (float|int) density: controls the density of streamlines in
        the plot. This is multiplied by 30 to scale similarly to other
        available streamline functions such as matplotlib.
        Default = 1
    :param (angle in radians) angle: angle of arrowhead. Default = pi/9
    :param (float in [0,1]) arrow_scale: value to scale length of arrowhead
        Default = .09
    :param kwargs: kwargs passed through plotly.graph_objs.Scatter
        for more information on valid kwargs call
        help(plotly.graph_objs.Scatter)

    :rtype (graph_objs.Figure): figure holding a single 'lines' Scatter
        trace (the streamline coordinates followed by the arrowhead
        coordinates) and a layout with hovermode='closest'.

    Example 1: Plot simple streamline and increase arrow size
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_streamline

    import numpy as np

    # Add data
    x = np.linspace(-3, 3, 100)
    y = np.linspace(-3, 3, 100)
    Y, X = np.meshgrid(x, y)
    u = -1 - X**2 + Y
    v = 1 + X - Y**2
    u = u.T  # Transpose
    v = v.T  # Transpose

    # Create streamline
    fig = create_streamline(x, y, u, v, arrow_scale=.1)

    # Plot
    py.plot(fig, filename='streamline')
    ```

    Example 2: from nbviewer.ipython.org/github/barbagroup/AeroPython
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_streamline
    from plotly.graph_objs import Scatter, Marker

    import numpy as np

    # Add data
    N = 50
    x_start, x_end = -2.0, 2.0
    y_start, y_end = -1.0, 1.0
    x = np.linspace(x_start, x_end, N)
    y = np.linspace(y_start, y_end, N)
    X, Y = np.meshgrid(x, y)
    ss = 5.0
    x_s, y_s = -1.0, 0.0

    # Compute the velocity field on the mesh grid
    u_s = ss/(2*np.pi) * (X-x_s)/((X-x_s)**2 + (Y-y_s)**2)
    v_s = ss/(2*np.pi) * (Y-y_s)/((X-x_s)**2 + (Y-y_s)**2)

    # Create streamline
    fig = create_streamline(x, y, u_s, v_s, density=2, name='streamline')

    # Add source point
    point = Scatter(x=[x_s], y=[y_s], mode='markers',
                    marker=Marker(size=14), name='source point')

    # Plot
    fig['data'].append(point)
    py.plot(fig, filename='streamline')
    ```
    """
    # Input checks run in the same order as before so the first failing
    # validator is the one that reports.
    utils.validate_equal_length(x, y)
    utils.validate_equal_length(u, v)
    validate_streamline(x, y)
    utils.validate_positive_scalars(density=density, arrow_scale=arrow_scale)

    # A separate _Streamline is built for the lines and for the arrowheads,
    # each from the same positional arguments.
    streamline_args = (x, y, u, v, density, angle, arrow_scale)
    line_x, line_y = _Streamline(*streamline_args).sum_streamlines()
    head_x, head_y = _Streamline(*streamline_args).get_streamline_arrows()

    # Lines and arrowheads share one trace: their coordinates are joined.
    trace = graph_objs.Scatter(x=line_x + head_x,
                               y=line_y + head_y,
                               mode='lines',
                               **kwargs)

    return graph_objs.Figure(data=[trace],
                             layout=graph_objs.Layout(hovermode='closest'))
예제 #11
0
    def test_dendrogram_colorscale(self):
        """Check that a user-supplied colorscale drives the trace colors.

        A greyscale palette is passed to ``create_dendrogram`` and the three
        resulting traces are compared against hand-written expectations
        whose marker colors come from that palette.
        """
        X = np.array([[1, 2, 3, 4], [1, 1, 3, 4], [1, 2, 1, 4], [1, 2, 3, 1]])
        greyscale = [
            'rgb(0,0,0)',  # black
            'rgb(105,105,105)',  # dim grey
            'rgb(128,128,128)',  # grey
            'rgb(169,169,169)',  # dark grey
            'rgb(192,192,192)',  # silver
            'rgb(211,211,211)',  # light grey
            'rgb(220,220,220)',  # gainsboro
            'rgb(245,245,245)',  # white smoke
        ]

        dendro = tls.FigureFactory.create_dendrogram(X, colorscale=greyscale)

        # NOTE: the expected layout is constructed here but only the traces
        # are compared by the assertions at the end of this test.
        expected_dendro = go.Figure(
            data=go.Data([
                go.Scatter(x=np.array([25., 25., 35., 35.]),
                           y=np.array([0., 1., 1., 0.]),
                           marker=go.Marker(color='rgb(128,128,128)'),
                           mode='lines',
                           xaxis='x',
                           yaxis='y'),
                go.Scatter(x=np.array([15., 15., 30., 30.]),
                           y=np.array([0., 2.23606798, 2.23606798, 1.]),
                           marker=go.Marker(color='rgb(128,128,128)'),
                           mode='lines',
                           xaxis='x',
                           yaxis='y'),
                go.Scatter(x=np.array([5., 5., 22.5, 22.5]),
                           y=np.array([0., 3.60555128, 3.60555128,
                                       2.23606798]),
                           marker=go.Marker(color='rgb(0,0,0)'),
                           mode='lines',
                           xaxis='x',
                           yaxis='y')
            ]),
            layout=go.Layout(autosize=False,
                             height='100%',
                             hovermode='closest',
                             showlegend=False,
                             width='100%',
                             xaxis=go.XAxis(mirror='allticks',
                                            rangemode='tozero',
                                            showgrid=False,
                                            showline=True,
                                            showticklabels=True,
                                            tickmode='array',
                                            ticks='outside',
                                            ticktext=np.array(
                                                ['3', '2', '0', '1']),
                                            tickvals=[5.0, 15.0, 25.0, 35.0],
                                            type='linear',
                                            zeroline=False),
                             yaxis=go.YAxis(mirror='allticks',
                                            rangemode='tozero',
                                            showgrid=False,
                                            showline=True,
                                            showticklabels=True,
                                            ticks='outside',
                                            type='linear',
                                            zeroline=False)))

        self.assertEqual(len(dendro['data']), 3)

        # this is actually a bit clearer when debugging tests.
        self.assert_dict_equal(dendro['data'][0], expected_dendro['data'][0])
        self.assert_dict_equal(dendro['data'][1], expected_dendro['data'][1])
        self.assert_dict_equal(dendro['data'][2], expected_dendro['data'][2])
예제 #12
0
    def test_dendrogram_random_matrix(self):
        """Dendrogram of a random 5x5 matrix with labels.

        The data is random, so the trace coordinates cannot be compared to
        fixed values; they are removed from the traces and only checked to
        be pairwise different. Everything else (trace styling and layout)
        is compared against a fixed expectation.
        """

        def line_trace(color):
            # Expected trace without x/y, which are popped before comparing.
            return go.Scatter(marker=go.Marker(color=color),
                              mode='lines',
                              xaxis='x',
                              yaxis='y')

        def assert_pairwise_different(arrays):
            # Every ordered pair of distinct entries must not be allclose.
            for i, left in enumerate(arrays):
                for j, right in enumerate(arrays):
                    if i != j:
                        self.assertFalse(np.allclose(left, right))

        # create a random uncorrelated matrix
        X = np.random.rand(5, 5)

        # variable 2 is correlated with all the other variables
        X[2, :] = sum(X, 0)

        names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
        dendro = tls.FigureFactory.create_dendrogram(X, labels=names)

        # Settings common to the expected x and y axes.
        shared_axis = dict(mirror='allticks',
                           rangemode='tozero',
                           showgrid=False,
                           showline=True,
                           showticklabels=True,
                           ticks='outside',
                           type='linear',
                           zeroline=False)

        expected_dendro = go.Figure(
            data=go.Data([
                line_trace('rgb(61,153,112)'),
                line_trace('rgb(61,153,112)'),
                line_trace('rgb(61,153,112)'),
                line_trace('rgb(0,116,217)')
            ]),
            layout=go.Layout(autosize=False,
                             height='100%',
                             hovermode='closest',
                             showlegend=False,
                             width='100%',
                             xaxis=go.XAxis(
                                 tickmode='array',
                                 tickvals=[5.0, 15.0, 25.0, 35.0, 45.0],
                                 **shared_axis),
                             yaxis=go.YAxis(**shared_axis)))

        self.assertEqual(len(dendro['data']), 4)

        # it's random, so we can only check that the values aren't equal
        y_vals = [dendro['data'][k].pop('y') for k in range(4)]
        assert_pairwise_different(y_vals)

        x_vals = [dendro['data'][k].pop('x') for k in range(4)]
        assert_pairwise_different(x_vals)

        # we also need to check the ticktext manually
        xaxis_ticktext = dendro['layout']['xaxis'].pop('ticktext')
        self.assertEqual(xaxis_ticktext[0], 'John')

        # compare trace by trace: a bit clearer when debugging tests.
        for k in range(4):
            self.assert_dict_equal(dendro['data'][k],
                                   expected_dendro['data'][k])

        self.assert_dict_equal(dendro['layout'], expected_dendro['layout'])
def create_quiver(x,
                  y,
                  u,
                  v,
                  scale=.1,
                  arrow_scale=.3,
                  angle=math.pi / 9,
                  **kwargs):
    """
    Build a quiver (arrow field) figure.

    :param (list|ndarray) x: x coordinates of the arrow locations
    :param (list|ndarray) y: y coordinates of the arrow locations
    :param (list|ndarray) u: x components of the arrow vectors
    :param (list|ndarray) v: y components of the arrow vectors
    :param (float in [0,1]) scale: scales size of the arrows (ideally to
        avoid overlap). Default = .1
    :param (float in [0,1]) arrow_scale: value multiplied to length of barb
        to get length of arrowhead. Default = .3
    :param (angle in radians) angle: angle of arrowhead. Default = pi/9
    :param kwargs: kwargs passed through plotly.graph_objs.Scatter
        for more information on valid kwargs call
        help(plotly.graph_objs.Scatter)

    :rtype (graph_objs.Figure): figure holding a single 'lines' Scatter
        trace (barb coordinates followed by arrowhead coordinates) and a
        layout with hovermode='closest'.

    Example 1: Trivial Quiver
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_quiver

    # 1 Arrow from (0,0) to (1,1)
    fig = create_quiver(x=[0], y=[0], u=[1], v=[1], scale=1)

    py.plot(fig, filename='quiver')
    ```

    Example 2: Quiver plot using meshgrid
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_quiver

    import numpy as np

    # Add data
    x, y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2))
    u = np.cos(x)*y
    v = np.sin(x)*y

    # Create quiver
    fig = create_quiver(x, y, u, v)

    # Plot
    py.plot(fig, filename='quiver')
    ```

    Example 3: Styling the quiver plot
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_quiver
    from plotly.graph_objs import Line
    import numpy as np
    import math

    # Add data
    x, y = np.meshgrid(np.arange(-math.pi, math.pi, .5),
                       np.arange(-math.pi, math.pi, .5))
    u = np.cos(x)*y
    v = np.sin(x)*y

    # Create quiver
    fig = create_quiver(x, y, u, v, scale=.2, arrow_scale=.3, angle=math.pi/6,
                        name='Wind Velocity', line=Line(width=1))

    # Add title to layout
    fig['layout'].update(title='Quiver Plot')

    # Plot
    py.plot(fig, filename='quiver')
    ```
    """
    utils.validate_equal_length(x, y, u, v)
    utils.validate_positive_scalars(arrow_scale=arrow_scale, scale=scale)

    # A separate _Quiver is built for the barbs and for the arrowheads,
    # each from the same positional arguments.
    quiver_args = (x, y, u, v, scale, arrow_scale, angle)
    shaft_x, shaft_y = _Quiver(*quiver_args).get_barbs()
    head_x, head_y = _Quiver(*quiver_args).get_quiver_arrows()

    # Barbs and arrowheads share one trace: their coordinates are joined.
    trace = graph_objs.Scatter(x=shaft_x + head_x,
                               y=shaft_y + head_y,
                               mode='lines',
                               **kwargs)

    return graph_objs.Figure(data=[trace],
                             layout=graph_objs.Layout(hovermode='closest'))
예제 #14
0
def create_ohlc(open,
                high,
                low,
                close,
                dates=None,
                direction="both",
                **kwargs):
    """
    BETA function that creates an ohlc chart

    :param (list) open: opening values
    :param (list) high: high values
    :param (list) low: low values
    :param (list) close: closing values
    :param (list) dates: list of datetime objects. Default: None
    :param (string) direction: direction can be 'increasing', 'decreasing',
        or 'both'. When the direction is 'increasing', the returned figure
        consists of all units where the close value is greater than the
        corresponding open value, and when the direction is 'decreasing',
        the returned figure consists of all units where the close value is
        less than or equal to the corresponding open value. When the
        direction is 'both', both increasing and decreasing units are
        returned. Default: 'both'
    :param kwargs: kwargs passed through plotly.graph_objs.Scatter.
        These kwargs describe other attributes about the ohlc Scatter trace
        such as the color or the legend name. For more information on valid
        kwargs call help(plotly.graph_objs.Scatter)

    :rtype (graph_objs.Figure): figure holding one trace per requested
        direction (increasing first, then decreasing) and a layout with
        hovermode='closest' and the x-axis zero line disabled.

    Example 1: Simple OHLC chart from a Pandas DataFrame

    >>> from plotly.figure_factory import create_ohlc
    >>> from datetime import datetime

    >>> import pandas.io.data as web

    >>> df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15),
    ...                     datetime(2008, 10, 15))
    >>> fig = create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index)
    >>> fig.show()

    Example 2: Add text and annotations to the OHLC chart

    >>> from plotly.figure_factory import create_ohlc
    >>> from datetime import datetime

    >>> import pandas.io.data as web

    >>> df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15),
    ...                     datetime(2008, 10, 15))
    >>> fig = create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index)

    >>> # Update the fig - options here: https://plot.ly/python/reference/#layout
    >>> fig['layout'].update({
    ...     'title': 'the great recession',
    ...     'yaxis': {'title': 'aapl stock'},
    ...     'shapes': [{
    ...         'x0': '2008-09-15', 'x1': '2008-09-15', 'type': 'line',
    ...         'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper',
    ...         'line': {'color': 'rgb(40,40,40)', 'width': 0.5}
    ...     }],
    ...     'annotations': [{
    ...         'text': "the fall of lehman brothers",
    ...         'x': '2008-09-15', 'y': 1.02,
    ...         'xref': 'x', 'yref': 'paper',
    ...         'showarrow': False, 'xanchor': 'left'
    ...     }]
    ... })
    >>> fig.show()

    Example 3: Customize the OHLC colors

    >>> from plotly.figure_factory import create_ohlc
    >>> from plotly.graph_objs import Line, Marker
    >>> from datetime import datetime

    >>> import pandas.io.data as web

    >>> df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1),
    ...                     datetime(2009, 4, 1))

    >>> # Make increasing ohlc sticks and customize their color and name
    >>> fig_increasing = create_ohlc(df.Open, df.High, df.Low, df.Close,
    ...                              dates=df.index, direction='increasing',
    ...                              name='AAPL',
    ...                              line=Line(color='rgb(150, 200, 250)'))

    >>> # Make decreasing ohlc sticks and customize their color and name
    >>> fig_decreasing = create_ohlc(df.Open, df.High, df.Low, df.Close,
    ...                              dates=df.index, direction='decreasing',
    ...                              line=Line(color='rgb(128, 128, 128)'))

    >>> # Initialize the figure
    >>> fig = fig_increasing

    >>> # Add decreasing data with .extend()
    >>> fig['data'].extend(fig_decreasing['data'])
    >>> fig.show()


    Example 4: OHLC chart with datetime objects

    >>> from plotly.figure_factory import create_ohlc

    >>> from datetime import datetime

    >>> # Add data
    >>> open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
    >>> high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
    >>> low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
    >>> close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
    >>> dates = [datetime(year=2013, month=10, day=10),
    ...          datetime(year=2013, month=11, day=10),
    ...          datetime(year=2013, month=12, day=10),
    ...          datetime(year=2014, month=1, day=10),
    ...          datetime(year=2014, month=2, day=10)]

    >>> # Create ohlc
    >>> fig = create_ohlc(open_data, high_data, low_data, close_data, dates=dates)
    >>> fig.show()
    """
    # dates is optional; include it in the length check only when supplied.
    if dates is not None:
        utils.validate_equal_length(open, high, low, close, dates)
    else:
        utils.validate_equal_length(open, high, low, close)
    validate_ohlc(open, high, low, close, direction, **kwargs)

    # Compare by value (==), not identity (is): an equal string that is a
    # different object (e.g. built at runtime or read from a file) would
    # otherwise fall through to the 'both' branch.
    if direction == "increasing":
        ohlc_incr = make_increasing_ohlc(open, high, low, close, dates,
                                         **kwargs)
        data = [ohlc_incr]
    elif direction == "decreasing":
        ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates,
                                         **kwargs)
        data = [ohlc_decr]
    else:
        ohlc_incr = make_increasing_ohlc(open, high, low, close, dates,
                                         **kwargs)
        ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates,
                                         **kwargs)
        data = [ohlc_incr, ohlc_decr]

    layout = graph_objs.Layout(xaxis=dict(zeroline=False), hovermode="closest")

    return graph_objs.Figure(data=data, layout=layout)
예제 #15
0
def get_subplots(rows=1, columns=1, print_grid=False, **kwargs):
    """Return a Figure whose layout holds a rows x columns subplot grid.

    Deprecated: every call emits a warning that points to
    tools.make_subplots.

    Subplots are numbered row by row starting from the bottom row; subplot
    k uses the layout keys 'xaxis{k}' / 'yaxis{k}' (anchored to 'y{k}' /
    'x{k}').

    Example 1:
    # stack two subplots vertically
    fig = tools.get_subplots(rows=2)
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x1', yaxis='y1')]
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')]

    Example 2:
    # print out string showing the subplot grid you've put in the layout
    fig = tools.get_subplots(rows=3, columns=2, print_grid=True)

    Keywords arguments with constant defaults:

    rows (kwarg, int greater than 0, default=1):
        Number of rows, evenly spaced vertically on the figure.

    columns (kwarg, int greater than 0, default=1):
        Number of columns, evenly spaced horizontally on the figure.

    print_grid (kwarg, True | False, default=False):
        If True, prints a tab-delimited string representation
        of your plot grid.

    Keyword arguments with variable defaults:

    horizontal_spacing (kwarg, float in [0,1], default=0.2 / columns):
        Space between subplot columns. Applied to all columns.

    vertical_spacing (kwarg, float in [0,1], default=0.3 / rows):
        Space between subplot rows. Applied to all rows.

    Raises Exception if 'rows' or 'columns' is not an int greater than 0,
    or if any keyword argument other than 'horizontal_spacing' /
    'vertical_spacing' is passed.

    """
    warnings.warn(
        "tools.get_subplots is deprecated. "
        "Please use tools.make_subplots instead."
    )

    # Throw exception for non-integer rows and columns
    if not isinstance(rows, int) or rows <= 0:
        raise Exception("Keyword argument 'rows' "
                        "must be an int greater than 0")
    if not isinstance(columns, int) or columns <= 0:
        raise Exception("Keyword argument 'columns' "
                        "must be an int greater than 0")

    # Throw exception if non-valid kwarg is sent
    VALID_KWARGS = ['horizontal_spacing', 'vertical_spacing']
    for key in kwargs:
        if key not in VALID_KWARGS:
            raise Exception("Invalid keyword argument: '{0}'".format(key))

    # Set 'horizontal_spacing' / 'vertical_spacing' w.r.t. rows / columns
    if 'horizontal_spacing' in kwargs:
        horizontal_spacing = float(kwargs['horizontal_spacing'])
    else:
        horizontal_spacing = 0.2 / columns
    if 'vertical_spacing' in kwargs:
        vertical_spacing = float(kwargs['vertical_spacing'])
    else:
        vertical_spacing = 0.3 / rows

    # Imported lazily, and only once the arguments are known to be valid.
    # TODO: protected until #282
    from plotly.graph_objs import graph_objs

    fig = dict(layout=graph_objs.Layout())  # will return this at the end
    plot_width = (1 - horizontal_spacing * (columns - 1)) / columns
    plot_height = (1 - vertical_spacing * (rows - 1)) / rows
    for row in range(rows):
        # Row 0 starts at y=0, i.e. it is the bottom row of the figure.
        y_start = (plot_height + vertical_spacing) * row
        y_end = y_start + plot_height
        for col in range(columns):
            plot_num = row * columns + col + 1
            x_start = (plot_width + horizontal_spacing) * col
            x_end = x_start + plot_width

            fig['layout']['xaxis{0}'.format(plot_num)] = dict(
                domain=[x_start, x_end], anchor='y{0}'.format(plot_num))
            fig['layout']['yaxis{0}'.format(plot_num)] = dict(
                domain=[y_start, y_end], anchor='x{0}'.format(plot_num))

    if print_grid:
        print("This is the format of your plot grid!")
        grid_lines = []
        for row in range(rows):
            first = row * columns + 1
            grid_lines.append("".join(
                "[{0}]\t".format(num)
                for num in range(first, first + columns)))
        # Printed top-down, so the bottom row (lowest numbers) comes last.
        print("".join(line + "\n" for line in reversed(grid_lines)))

    return graph_objs.Figure(fig)  # forces us to validate what we just did...
예제 #16
0
def create_distplot(
    hist_data,
    group_labels,
    bin_size=1.0,
    curve_type="kde",
    colors=None,
    rug_text=None,
    histnorm=DEFAULT_HISTNORM,
    show_hist=True,
    show_curve=True,
    show_rug=True,
):
    """
    Create a distplot similar to seaborn.distplot;
    **this function is deprecated**, use instead :mod:`plotly.express`
    functions, for example

    >>> import plotly.express as px
    >>> tips = px.data.tips()
    >>> fig = px.histogram(tips, x="total_bill", y="tip", color="sex", marginal="rug",
    ...                    hover_data=tips.columns)
    >>> fig.show()


    A distplot combines any of three components: (1) a histogram,
    (2) a curve, either (a) a kernel density estimate or (b) a normal
    curve, and (3) a rug plot. Several data sets can be drawn in the
    same plot.

    :param (list[list]) hist_data: Use list of lists to plot multiple data
        sets on the same plot.
    :param (list[str]) group_labels: Names for each data set.
    :param (list[float]|float) bin_size: Size of histogram bins. A single
        number is applied to every data set. Default = 1.
    :param (str) curve_type: 'kde' or 'normal'. Default = 'kde'
    :param (list[str]) colors: Colors for traces.
    :param (list[list]) rug_text: Hovertext values for the rug plot.
    :param (str) histnorm: 'probability density' or 'probability'
        Default = 'probability density'
    :param (bool) show_hist: Add histogram to distplot? Default = True
    :param (bool) show_curve: Add curve to distplot? Default = True
    :param (bool) show_rug: Add rug to distplot? Default = True
    :return (graph_objs.Figure): the distplot figure.

    Example 1: Simple distplot of 1 data set

    >>> from plotly.figure_factory import create_distplot

    >>> hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5,
    ...               3.5, 4.1, 4.4, 4.5, 4.5,
    ...               5.0, 5.0, 5.2, 5.5, 5.5,
    ...               5.5, 5.5, 5.5, 6.1, 7.0]]
    >>> group_labels = ['distplot example']
    >>> fig = create_distplot(hist_data, group_labels)
    >>> fig.show()


    Example 2: Two data sets and added rug text

    >>> from plotly.figure_factory import create_distplot
    >>> # Add histogram data
    >>> hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6,
    ...            -0.9, -0.07, 1.95, 0.9, -0.2,
    ...            -0.5, 0.3, 0.4, -0.37, 0.6]
    >>> hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59,
    ...            1.0, 0.8, 1.7, 0.5, 0.8,
    ...            -0.3, 1.2, 0.56, 0.3, 2.2]

    >>> # Group data together
    >>> hist_data = [hist1_x, hist2_x]

    >>> group_labels = ['2012', '2013']

    >>> # Add text
    >>> rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1',
    ...       'f1', 'g1', 'h1', 'i1', 'j1',
    ...       'k1', 'l1', 'm1', 'n1', 'o1']

    >>> rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2',
    ...       'f2', 'g2', 'h2', 'i2', 'j2',
    ...       'k2', 'l2', 'm2', 'n2', 'o2']

    >>> # Group text together
    >>> rug_text_all = [rug_text_1, rug_text_2]

    >>> # Create distplot
    >>> fig = create_distplot(
    ...     hist_data, group_labels, rug_text=rug_text_all, bin_size=.2)

    >>> # Add title
    >>> fig.update_layout(title='Dist Plot') # doctest: +SKIP
    >>> fig.show()


    Example 3: Plot with normal curve and hide rug plot

    >>> from plotly.figure_factory import create_distplot
    >>> import numpy as np

    >>> x1 = np.random.randn(190)
    >>> x2 = np.random.randn(200)+1
    >>> x3 = np.random.randn(200)-1
    >>> x4 = np.random.randn(210)+2

    >>> hist_data = [x1, x2, x3, x4]
    >>> group_labels = ['2012', '2013', '2014', '2015']

    >>> fig = create_distplot(
    ...     hist_data, group_labels, curve_type='normal',
    ...     show_rug=False, bin_size=.4)


    Example 4: Distplot with Pandas

    >>> from plotly.figure_factory import create_distplot
    >>> import numpy as np
    >>> import pandas as pd

    >>> df = pd.DataFrame({'2012': np.random.randn(200),
    ...                    '2013': np.random.randn(200)+1})
    >>> fig = create_distplot([df[c] for c in df.columns], df.columns)
    >>> fig.show()
    """
    colors = [] if colors is None else colors
    rug_text = [] if rug_text is None else rug_text

    validate_distplot(hist_data, curve_type)
    utils.validate_equal_length(hist_data, group_labels)

    # A scalar bin size is broadcast to one entry per data set.
    if isinstance(bin_size, (float, int)):
        bin_size = [bin_size] * len(hist_data)

    def new_distplot():
        # Each component is produced by its own freshly built _Distplot.
        return _Distplot(
            hist_data,
            histnorm,
            group_labels,
            bin_size,
            curve_type,
            colors,
            rug_text,
            show_hist,
            show_curve,
        )

    # One list of traces per enabled component, flattened at the end.
    trace_groups = []

    if show_hist:
        trace_groups.append(new_distplot().make_hist())

    if show_curve:
        builder = new_distplot()
        if curve_type == "normal":
            trace_groups.append(builder.make_normal())
        else:
            trace_groups.append(builder.make_kde())

    # Layout settings that do not depend on whether the rug is drawn.
    layout_kwargs = dict(
        barmode="overlay",
        hovermode="closest",
        legend=dict(traceorder="reversed"),
        xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False),
    )

    if show_rug:
        trace_groups.append(new_distplot().make_rug())
        # The main plot takes the upper part; the rug gets its own y-axis
        # in the strip below.
        layout_kwargs["yaxis1"] = dict(
            domain=[0.35, 1], anchor="free", position=0.0
        )
        layout_kwargs["yaxis2"] = dict(
            domain=[0, 0.25], anchor="x1", dtick=1, showticklabels=False
        )
    else:
        # Without a rug the main plot spans the full height.
        layout_kwargs["yaxis1"] = dict(
            domain=[0.0, 1], anchor="free", position=0.0
        )

    layout = graph_objs.Layout(**layout_kwargs)

    return graph_objs.Figure(data=sum(trace_groups, []), layout=layout)
예제 #17
0
def make_subplots(rows=1, cols=1,
                  shared_xaxes=False, shared_yaxes=False,
                  start_cell='top-left', print_grid=None,
                  **kwargs):
    """Return an instance of plotly.graph_objs.Figure
    with the subplots domain set in 'layout'.

    Example 1:
    # stack two subplots vertically
    fig = tools.make_subplots(rows=2)

    This is the format of your plot grid:
    [ (1,1) x1,y1 ]
    [ (2,1) x2,y2 ]

    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])]
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')]

    # or see Figure.append_trace

    Example 2:
    # subplots with shared x axes
    fig = tools.make_subplots(rows=2, shared_xaxes=True)

    This is the format of your plot grid:
    [ (1,1) x1,y1 ]
    [ (2,1) x1,y2 ]


    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])]
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], yaxis='y2')]

    Example 3:
    # irregular subplot layout (more examples below under 'specs')
    fig = tools.make_subplots(rows=2, cols=2,
                              specs=[[{}, {}],
                                     [{'colspan': 2}, None]])

    This is the format of your plot grid!
    [ (1,1) x1,y1 ]  [ (1,2) x2,y2 ]
    [ (2,1) x3,y3           -      ]

    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])]
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')]
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x3', yaxis='y3')]

    Example 4:
    # insets
    fig = tools.make_subplots(insets=[{'cell': (1,1), 'l': 0.7, 'b': 0.3}])

    This is the format of your plot grid!
    [ (1,1) x1,y1 ]

    With insets:
    [ x2,y2 ] over [ (1,1) x1,y1 ]

    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])]
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')]

    Example 5:
    # include subplot titles
    fig = tools.make_subplots(rows=2, subplot_titles=('Plot 1','Plot 2'))

    This is the format of your plot grid:
    [ (1,1) x1,y1 ]
    [ (2,1) x2,y2 ]

    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])]
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')]

    Example 6:
    # Include subplot title on one plot (but not all)
    fig = tools.make_subplots(insets=[{'cell': (1,1), 'l': 0.7, 'b': 0.3}],
                              subplot_titles=('','Inset'))

    This is the format of your plot grid!
    [ (1,1) x1,y1 ]

    With insets:
    [ x2,y2 ] over [ (1,1) x1,y1 ]

    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])]
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')]

    Keywords arguments with constant defaults:

    rows (kwarg, int greater than 0, default=1):
        Number of rows in the subplot grid.

    cols (kwarg, int greater than 0, default=1):
        Number of columns in the subplot grid.

    shared_xaxes (kwarg, boolean or list, default=False)
        Assign shared x axes.
        If True, subplots in the same grid column have one common
        shared x-axis at the bottom of the grid.

        To assign shared x axes per subplot grid cell (see 'specs'),
        send list (or list of lists, one list per shared x axis)
        of cell index tuples.

    shared_yaxes (kwarg, boolean or list, default=False)
        Assign shared y axes.
        If True, subplots in the same grid row have one common
        shared y-axis on the left-hand side of the grid.

        To assign shared y axes per subplot grid cell (see 'specs'),
        send list (or list of lists, one list per shared y axis)
        of cell index tuples.

    start_cell (kwarg, 'bottom-left' or 'top-left', default='top-left')
        Choose the starting cell in the subplot grid used to set the
        domains of the subplots.

    print_grid (kwarg, boolean, default=True):
        If True, prints a tab-delimited string representation of
        your plot grid. (The signature default of None is treated as True.)

    Keyword arguments with variable defaults:

    horizontal_spacing (kwarg, float in [0,1], default=0.2 / cols):
        Space between subplot columns.
        Applies to all columns (use 'specs' subplot-dependents spacing)

    vertical_spacing (kwarg, float in [0,1], default=0.3 / rows):
        Space between subplot rows.
        Applies to all rows (use 'specs' subplot-dependents spacing)
        When 'subplot_titles' is passed the default is 0.5 / rows.

    subplot_titles (kwarg, list of strings, default=empty list):
        Title of each subplot.
        "" can be included in the list if no subplot title is desired in
        that space so that the titles are properly indexed.

    specs (kwarg, list of lists of dictionaries):
        Subplot specifications.

        ex1: specs=[[{}, {}], [{'colspan': 2}, None]]

        ex2: specs=[[{'rowspan': 2}, {}], [None, {}]]

        - Indices of the outer list correspond to subplot grid rows
          starting from the top if start_cell='top-left', or from the
          bottom if start_cell='bottom-left'. The number of rows in
          'specs' must be equal to 'rows'.

        - Indices of the inner lists correspond to subplot grid columns
          starting from the left. The number of columns in 'specs'
          must be equal to 'cols'.

        - Each item in the 'specs' list corresponds to one subplot
          in a subplot grid. (N.B. The subplot grid has exactly 'rows'
          times 'cols' cells.)

        - Use None for a blank subplot cell (or to move past a col/row
          span).

        - Note that specs[0][0] has the specs of the 'start_cell' subplot.

        - Each item in 'specs' is a dictionary.
            The available keys are:

            * is_3d (boolean, default=False): flag for 3d scenes
            * colspan (int, default=1): number of subplot columns
                for this subplot to span.
            * rowspan (int, default=1): number of subplot rows
                for this subplot to span.
            * l (float, default=0.0): padding left of cell
            * r (float, default=0.0): padding right of cell
            * t (float, default=0.0): padding top of cell
            * b (float, default=0.0): padding bottom of cell

        - Use 'horizontal_spacing' and 'vertical_spacing' to adjust
          the spacing in between the subplots.

        - Missing keys are filled with their defaults directly in the
          dictionaries that were passed in.

    insets (kwarg, list of dictionaries):
        Inset specifications.

        - Each item in 'insets' is a dictionary.
            The available keys are:

            * cell (tuple, default=(1,1)): (row, col) index of the
                subplot cell to overlay inset axes onto.
            * is_3d (boolean, default=False): flag for 3d scenes
            * l (float, default=0.0): padding left of inset
                  in fraction of cell width
            * w (float or 'to_end', default='to_end') inset width
                  in fraction of cell width ('to_end': to cell right edge)
            * b (float, default=0.0): padding bottom of inset
                  in fraction of cell height
            * h (float or 'to_end', default='to_end') inset height
                  in fraction of cell height ('to_end': to cell top edge)

    column_width (kwarg, list of numbers)
        Column_width specifications

        - Functions similarly to `column_width` of `plotly.graph_objs.Table`.
          Specify a list that contains numbers where the amount of numbers in
          the list is equal to `cols`.

        - The numbers in the list indicate the proportions that each column
          domains take across the full horizontal domain excluding padding.

        - For example, if column_width=[3, 1], horizontal_spacing=0, and
          cols=2, the domains for each column would be [0, 0.75] and
          [0.75, 1]

    row_width (kwargs, list of numbers)
        Row_width specifications

        - Functions similarly to `column_width`. Specify a list that contains
          numbers where the amount of numbers in the list is equal to `rows`.

        - The numbers in the list indicate the proportions that each row
          domains take along the full vertical domain excluding padding.

        - For example, if row_width=[3, 1], vertical_spacing=0, and
          rows=2, the domains for each row from bottom to top would be
          [0, 0.75] and [0.75, 1]

    Returns a Figure whose layout holds the axis / scene domains. The grid
    reference and its string representation are stored on the figure as
    '_grid_ref' and '_grid_str'.

    Raises Exception when an argument or keyword argument is invalid
    (non-positive or non-int rows/cols, unknown start_cell, unknown kwarg,
    malformed 'specs' / 'insets' / 'column_width' / 'row_width', spans or
    inset cells that do not fit the grid).
    """

    # When the v4 behavior has been opted into, defer entirely to the
    # newer implementation.
    from _plotly_future_ import _future_flags
    if 'v4_subplots' in _future_flags:
        import plotly.subplots
        return plotly.subplots.make_subplots(
            rows=rows,
            cols=cols,
            shared_xaxes=shared_xaxes,
            shared_yaxes=shared_yaxes,
            start_cell=start_cell,
            print_grid=print_grid,
            **kwargs
        )

    # Handle default print_grid
    if print_grid is None:
        print_grid = True

    # TODO: protected until #282
    from plotly.graph_objs import graph_objs

    # Throw exception for non-integer rows and cols
    if not isinstance(rows, int) or rows <= 0:
        raise Exception("Keyword argument 'rows' "
                        "must be an int greater than 0")
    if not isinstance(cols, int) or cols <= 0:
        raise Exception("Keyword argument 'cols' "
                        "must be an int greater than 0")

    # Direction in which spec rows / cols map onto increasing domains
    START_CELL_all = {
        'bottom-left': {
            # 'natural' setup where x & y domains increase monotonically
            'col_dir': 1,
            'row_dir': 1
        },
        'top-left': {
            # 'default' setup visually matching the 'specs' list of lists
            'col_dir': 1,
            'row_dir': -1
        }
        # TODO maybe add 'bottom-right' and 'top-right'
    }

    # Throw exception for invalid 'start_cell' values
    try:
        START_CELL = START_CELL_all[start_cell]
    except KeyError:
        raise Exception("Invalid 'start_cell' value")

    # Throw exception if non-valid kwarg is sent
    VALID_KWARGS = ['horizontal_spacing', 'vertical_spacing',
                    'specs', 'insets', 'subplot_titles', 'column_width',
                    'row_width']
    for key in kwargs:
        if key not in VALID_KWARGS:
            raise Exception("Invalid keyword argument: '{0}'".format(key))

    # Set 'subplot_titles'
    subplot_titles = kwargs.get('subplot_titles', [""] * rows * cols)

    # Set 'horizontal_spacing' / 'vertical_spacing' w.r.t. rows / cols
    if 'horizontal_spacing' in kwargs:
        horizontal_spacing = float(kwargs['horizontal_spacing'])
    else:
        horizontal_spacing = 0.2 / cols

    if 'vertical_spacing' in kwargs:
        vertical_spacing = float(kwargs['vertical_spacing'])
    elif 'subplot_titles' in kwargs:
        # leave extra room between rows for the titles
        vertical_spacing = 0.5 / rows
    else:
        vertical_spacing = 0.3 / rows

    # Sanitize 'specs' (must be a list of lists)
    exception_msg = "Keyword argument 'specs' must be a list of lists"
    if 'specs' in kwargs:
        specs = kwargs['specs']
        if not isinstance(specs, list):
            raise Exception(exception_msg)
        for spec_row in specs:
            if not isinstance(spec_row, list):
                raise Exception(exception_msg)
    else:
        # default 'specs': one plain 2d subplot per grid cell
        specs = [[{} for c in range(cols)] for r in range(rows)]

    # Throw exception if specs is over or under specified
    if len(specs) != rows:
        raise Exception("The number of rows in 'specs' "
                        "must be equal to 'rows'")
    for spec_row in specs:
        if len(spec_row) != cols:
            raise Exception("The number of columns in 'specs' "
                            "must be equal to 'cols'")

    # Sanitize 'insets'
    if 'insets' in kwargs:
        insets = kwargs['insets']
        if not isinstance(insets, list):
            raise Exception("Keyword argument 'insets' must be a list")
    else:
        insets = False

    # Throw exception if non-valid key / fill in defaults.
    # N.B. the dictionaries in 'arg' are completed in place.
    def _check_keys_and_fill(name, arg, defaults):
        def _checks(item, defaults):
            if item is None:
                return
            if not isinstance(item, dict):
                raise Exception("Items in keyword argument '{name}' must be "
                                "dictionaries or None".format(name=name))
            for k in item.keys():
                if k not in defaults.keys():
                    raise Exception("Invalid key '{k}' in keyword "
                                    "argument '{name}'".format(k=k, name=name))
            for k in defaults.keys():
                if k not in item.keys():
                    item[k] = defaults[k]
        for arg_i in arg:
            if isinstance(arg_i, list):
                for arg_ii in arg_i:
                    _checks(arg_ii, defaults)
            elif isinstance(arg_i, dict):
                _checks(arg_i, defaults)

    # Default spec key-values
    SPEC_defaults = dict(
        is_3d=False,
        colspan=1,
        rowspan=1,
        l=0.0,
        r=0.0,
        b=0.0,
        t=0.0
        # TODO add support for 'w' and 'h'
    )
    _check_keys_and_fill('specs', specs, SPEC_defaults)

    # Default inset key-values
    if insets:
        INSET_defaults = dict(
            cell=(1, 1),
            is_3d=False,
            l=0.0,
            w='to_end',
            b=0.0,
            h='to_end'
        )
        _check_keys_and_fill('insets', insets, INSET_defaults)

    # set column widths (with 'column_width')
    if 'column_width' in kwargs:
        column_width = kwargs['column_width']
        if not isinstance(column_width, list) or len(column_width) != cols:
            raise Exception(
                "Keyword argument 'column_width' must be a list with {} "
                "numbers in it, the number of subplot cols.".format(cols)
            )
    else:
        column_width = None

    # Horizontal room left once the gaps between columns are removed
    usable_width = 1. - horizontal_spacing * (cols - 1)
    if column_width:
        cum_sum = float(sum(column_width))
        widths = [usable_width * (w / cum_sum) for w in column_width]
    else:
        widths = [usable_width / cols] * cols

    # set row heights (with 'row_width')
    if 'row_width' in kwargs:
        row_width = kwargs['row_width']
        if not isinstance(row_width, list) or len(row_width) != rows:
            raise Exception(
                "Keyword argument 'row_width' must be a list with {} "
                "numbers in it, the number of subplot rows.".format(rows)
            )
    else:
        row_width = None

    # Vertical room left once the gaps between rows are removed
    usable_height = 1. - vertical_spacing * (rows - 1)
    if row_width:
        cum_sum = float(sum(row_width))
        heights = [usable_height * (h / cum_sum) for h in row_width]
    else:
        heights = [usable_height / rows] * rows

    # Built row/col sequence using 'row_dir' and 'col_dir'
    COL_DIR = START_CELL['col_dir']
    ROW_DIR = START_CELL['row_dir']
    col_seq = range(cols)[::COL_DIR]
    row_seq = range(rows)[::ROW_DIR]

    # [grid] Build subplot grid (coord tuple of cell).
    # grid[r][c] is the (x, y) lower-left corner of the cell described by
    # specs[r][c]; 'heights' / 'widths' are indexed from the bottom / left.
    grid = [
        [
            (
                (sum(widths[:c]) + c * horizontal_spacing),
                (sum(heights[:r]) + r * vertical_spacing)
            ) for c in col_seq
        ] for r in row_seq
    ]

    # Height of the grid row that spec row 'r' is drawn in. 'heights' is
    # ordered bottom-to-top, so the index only needs flipping when the
    # spec rows run top-to-bottom.
    def _row_height(r):
        return heights[r] if ROW_DIR > 0 else heights[-1 - r]

    # [grid_ref] Initialize the grid and insets' axis-reference lists
    grid_ref = [[None for c in range(cols)] for r in range(rows)]
    insets_ref = [None] * len(insets) if insets else None

    layout = graph_objs.Layout()  # init layout object

    # Function handling logic around 2d axis labels
    # Returns 'x{}' | 'y{}'
    def _get_label(x_or_y, r, c, cnt, shared_axes):
        # Default label (given strictly by cnt)
        label = "{x_or_y}{cnt}".format(x_or_y=x_or_y, cnt=cnt)

        if isinstance(shared_axes, bool):
            if shared_axes:
                if x_or_y == 'x':
                    label = "{x_or_y}{c}".format(x_or_y=x_or_y, c=c + 1)
                if x_or_y == 'y':
                    label = "{x_or_y}{r}".format(x_or_y=x_or_y, r=r + 1)

        if isinstance(shared_axes, list):
            # A flat list of cell tuples describes a single shared axis;
            # an empty list shares nothing.
            if shared_axes and isinstance(shared_axes[0], tuple):
                shared_axes = [shared_axes]  # TODO put this elsewhere
            for shared_axis in shared_axes:
                if (r + 1, c + 1) in shared_axis:
                    label = {
                        'x': "x{0}".format(shared_axis[0][1]),
                        'y': "y{0}".format(shared_axis[0][0])
                    }[x_or_y]

        return label

    # Row in grid of anchor row if shared_xaxes=True
    ANCHOR_ROW = 0 if ROW_DIR > 0 else rows - 1

    # Function handling logic around 2d axis anchors
    # Return 'x{}' | 'y{}' | 'free' | False
    def _get_anchors(r, c, x_cnt, y_cnt, shared_xaxes, shared_yaxes):
        # Default anchors (give strictly by cnt)
        x_anchor = "y{y_cnt}".format(y_cnt=y_cnt)
        y_anchor = "x{x_cnt}".format(x_cnt=x_cnt)

        if isinstance(shared_xaxes, bool):
            if shared_xaxes:
                if r != ANCHOR_ROW:
                    x_anchor = False
                    y_anchor = 'free'
                    if shared_yaxes and c != 0:  # TODO covers all cases?
                        y_anchor = False
                    return x_anchor, y_anchor

        elif isinstance(shared_xaxes, list):
            if shared_xaxes and isinstance(shared_xaxes[0], tuple):
                shared_xaxes = [shared_xaxes]  # TODO put this elsewhere
            for shared_xaxis in shared_xaxes:
                if (r + 1, c + 1) in shared_xaxis[1:]:
                    x_anchor = False
                    y_anchor = 'free'  # TODO covers all cases?

        if isinstance(shared_yaxes, bool):
            if shared_yaxes:
                if c != 0:
                    y_anchor = False
                    x_anchor = 'free'
                    if shared_xaxes and r != ANCHOR_ROW:  # TODO all cases?
                        x_anchor = False
                    return x_anchor, y_anchor

        elif isinstance(shared_yaxes, list):
            if shared_yaxes and isinstance(shared_yaxes[0], tuple):
                shared_yaxes = [shared_yaxes]  # TODO put this elsewhere
            for shared_yaxis in shared_yaxes:
                if (r + 1, c + 1) in shared_yaxis[1:]:
                    y_anchor = False
                    x_anchor = 'free'  # TODO covers all cases?

        return x_anchor, y_anchor

    # Domains in the order they were added (x, y, x, y, ... when no axes
    # are shared); used to position the subplot titles
    list_of_domains = []

    # Function pasting x/y domains in layout object (2d case)
    def _add_domain(layout, x_or_y, label, domain, anchor, position):
        name = label[0] + 'axis' + label[1:]

        # Clamp domain elements between [0, 1].
        # This is only needed to combat numerical precision errors
        # See GH1031
        axis = {'domain': [max(0.0, domain[0]), min(1.0, domain[1])]}
        if anchor:
            axis['anchor'] = anchor
        if isinstance(position, float):
            axis['position'] = position
        layout[name] = axis
        list_of_domains.append(domain)  # added for subplot titles

    # Function pasting x/y domains in layout object (3d case)
    def _add_domain_is_3d(layout, s_label, x_domain, y_domain):
        scene = dict(
            domain={'x': [max(0.0, x_domain[0]), min(1.0, x_domain[1])],
                    'y': [max(0.0, y_domain[0]), min(1.0, y_domain[1])]})
        layout[s_label] = scene

    x_cnt = y_cnt = s_cnt = 1  # subplot axis/scene counters

    # Loop through specs -- (r, c) <-> (row, col)
    for r, spec_row in enumerate(specs):
        for c, spec in enumerate(spec_row):

            if spec is None:  # skip over None cells
                continue

            c_spanned = c + spec['colspan'] - 1  # get spanned c
            r_spanned = r + spec['rowspan'] - 1  # get spanned r

            # Throw exception if 'colspan' | 'rowspan' is too large for grid
            if c_spanned >= cols:
                raise Exception("Some 'colspan' value is too large for "
                                "this subplot grid.")
            if r_spanned >= rows:
                raise Exception("Some 'rowspan' value is too large for "
                                "this subplot grid.")

            # Get x domain using grid and colspan. The right edge is the
            # right edge of the LAST spanned column, so that column's own
            # width must be used (columns may differ via 'column_width').
            x_s = grid[r][c][0] + spec['l']
            x_e = grid[r][c_spanned][0] + widths[c_spanned] - spec['r']
            x_domain = [x_s, x_e]

            # Get y domain (dep. on row_dir) using grid & r_spanned. The
            # top edge belongs to whichever spanned row is drawn highest.
            if ROW_DIR > 0:
                y_s = grid[r][c][1] + spec['b']
                y_e = (grid[r_spanned][c][1] + _row_height(r_spanned) -
                       spec['t'])
            else:
                y_s = grid[r_spanned][c][1] + spec['b']
                y_e = grid[r][c][1] + _row_height(r) - spec['t']
            y_domain = [y_s, y_e]

            if spec['is_3d']:

                # Add scene to layout
                s_label = 'scene{0}'.format(s_cnt)
                _add_domain_is_3d(layout, s_label, x_domain, y_domain)
                grid_ref[r][c] = (s_label, )
                s_cnt += 1

            else:

                # Get axis label and anchor
                x_label = _get_label('x', r, c, x_cnt, shared_xaxes)
                y_label = _get_label('y', r, c, y_cnt, shared_yaxes)
                x_anchor, y_anchor = _get_anchors(r, c,
                                                  x_cnt, y_cnt,
                                                  shared_xaxes,
                                                  shared_yaxes)

                # Add a xaxis to layout (N.B anchor == False -> no axis)
                if x_anchor:
                    if x_anchor == 'free':
                        x_position = y_domain[0]
                    else:
                        x_position = False
                    _add_domain(layout, 'x', x_label, x_domain,
                                x_anchor, x_position)
                    x_cnt += 1

                # Add a yaxis to layout (N.B anchor == False -> no axis)
                if y_anchor:
                    if y_anchor == 'free':
                        y_position = x_domain[0]
                    else:
                        y_position = False
                    _add_domain(layout, 'y', y_label, y_domain,
                                y_anchor, y_position)
                    y_cnt += 1

                grid_ref[r][c] = (x_label, y_label)  # fill in ref

    # Loop through insets
    if insets:
        for i_inset, inset in enumerate(insets):

            r = inset['cell'][0] - 1
            c = inset['cell'][1] - 1

            # Throw exception if r | c is out of range
            if not (0 <= r < rows):
                raise Exception("Some 'cell' row value is out of range. "
                                "Note: the starting cell is (1, 1)")
            if not (0 <= c < cols):
                raise Exception("Some 'cell' col value is out of range. "
                                "Note: the starting cell is (1, 1)")

            cell_width = widths[c]
            cell_height = _row_height(r)

            # Get inset x domain using grid
            x_s = grid[r][c][0] + inset['l'] * cell_width
            if inset['w'] == 'to_end':
                x_e = grid[r][c][0] + cell_width
            else:
                x_e = x_s + inset['w'] * cell_width
            x_domain = [x_s, x_e]

            # Get inset y domain using grid
            y_s = grid[r][c][1] + inset['b'] * cell_height
            if inset['h'] == 'to_end':
                y_e = grid[r][c][1] + cell_height
            else:
                y_e = y_s + inset['h'] * cell_height
            y_domain = [y_s, y_e]

            if inset['is_3d']:

                # Add scene to layout
                s_label = 'scene{0}'.format(s_cnt)
                _add_domain_is_3d(layout, s_label, x_domain, y_domain)
                insets_ref[i_inset] = (s_label, )
                s_cnt += 1

            else:

                # Get axis label and anchor
                x_label = _get_label('x', False, False, x_cnt, False)
                y_label = _get_label('y', False, False, y_cnt, False)
                x_anchor, y_anchor = _get_anchors(r, c,
                                                  x_cnt, y_cnt,
                                                  False, False)

                # Add a xaxis to layout (N.B insets always have anchors)
                _add_domain(layout, 'x', x_label, x_domain, x_anchor, False)
                x_cnt += 1

                # Add a yaxis to layout (N.B insets always have anchors)
                _add_domain(layout, 'y', y_label, y_domain, y_anchor, False)
                y_cnt += 1

                insets_ref[i_inset] = (x_label, y_label)  # fill in ref

    # [grid_str] Set the grid's string representation
    sp = "  "            # space between cell
    s_str = "[ "         # cell start string
    e_str = " ]"         # cell end string
    colspan_str = '       -'     # colspan string
    rowspan_str = '       |'     # rowspan string
    empty_str = '    (empty) '   # empty cell string

    # Init grid_str with intro message
    grid_str = "This is the format of your plot grid:\n"

    # Init tmp list of lists of strings (sorta like 'grid_ref' but w/ strings)
    _tmp = [['' for c in range(cols)] for r in range(rows)]

    # Define cell string as function of (r, c) and grid_ref
    def _get_cell_str(r, c, ref):
        return '({r},{c}) {ref}'.format(r=r + 1, c=c + 1, ref=','.join(ref))

    # Find max len of _cell_str, and define a padding function
    cell_len = max([len(_get_cell_str(r, c, ref))
                    for r, row_ref in enumerate(grid_ref)
                    for c, ref in enumerate(row_ref)
                    if ref]) + len(s_str) + len(e_str)

    def _pad(s, cell_len=cell_len):
        return ' ' * (cell_len - len(s))

    # Loop through specs, fill in _tmp
    for r, spec_row in enumerate(specs):
        for c, spec in enumerate(spec_row):

            ref = grid_ref[r][c]
            if ref is None:
                # Keep any span marker an earlier cell already wrote here
                if _tmp[r][c] == '':
                    _tmp[r][c] = empty_str + _pad(empty_str)
                continue

            cell_str = s_str + _get_cell_str(r, c, ref)

            if spec['colspan'] > 1:
                for cc in range(1, spec['colspan'] - 1):
                    _tmp[r][c + cc] = colspan_str + _pad(colspan_str)
                _tmp[r][c + spec['colspan'] - 1] = (
                    colspan_str + _pad(colspan_str + e_str)) + e_str
            else:
                cell_str += e_str

            if spec['rowspan'] > 1:
                for rr in range(1, spec['rowspan'] - 1):
                    _tmp[r + rr][c] = rowspan_str + _pad(rowspan_str)
                for cc in range(spec['colspan']):
                    _tmp[r + spec['rowspan'] - 1][c + cc] = (
                        rowspan_str + _pad(rowspan_str))

            _tmp[r][c] = cell_str + _pad(cell_str)

    # Append grid_str using data from _tmp in the correct order
    for r in row_seq[::-1]:
        grid_str += sp.join(_tmp[r]) + '\n'

    # Append grid_str to include insets info
    if insets:
        grid_str += "\nWith insets:\n"
        for i_inset, inset in enumerate(insets):

            r = inset['cell'][0] - 1
            c = inset['cell'][1] - 1
            ref = grid_ref[r][c]

            grid_str += (
                s_str + ','.join(insets_ref[i_inset]) + e_str +
                ' over ' +
                s_str + _get_cell_str(r, c, ref) + e_str + '\n'
            )

    # Add subplot titles

    # If shared_axes is False (default) use list_of_domains
    # This is used for insets and irregular layouts
    if not shared_xaxes and not shared_yaxes:
        x_dom = list_of_domains[::2]
        y_dom = list_of_domains[1::2]
        # Titles sit centred above each subplot
        subtitle_pos_x = [sum(x_domain) / 2 for x_domain in x_dom]
        subtitle_pos_y = [y_domain[1] for y_domain in y_dom]

    # If shared_axes is True the domain of each subplot is not returned so
    # the title position must be calculated for each subplot
    else:
        layout_keys = list(layout.to_plotly_json().keys())
        x_dom_vals = [k for k in layout_keys if 'xaxis' in k]
        y_dom_vals = [k for k in layout_keys if 'yaxis' in k]

        # sort xaxis and yaxis layout keys by their numeric suffix
        # ('xaxis' without a number sorts as 0)
        digits_re = re.compile(r'\d+')

        def key_func(m):
            match = digits_re.search(m)
            return int(match.group(0)) if match else 0

        xaxies_labels_sorted = sorted(x_dom_vals, key=key_func)
        yaxies_labels_sorted = sorted(y_dom_vals, key=key_func)

        x_dom = [layout[k]['domain'] for k in xaxies_labels_sorted]
        y_dom = [layout[k]['domain'] for k in yaxies_labels_sorted]

        # One x centre per x axis, repeated for every row
        subtitle_pos_x = [sum(x_domain) / 2 for x_domain in x_dom] * rows

        # One top edge per y axis, ordered from the top of the figure down
        y_tops = [y_domain[1] for y_domain in y_dom]
        if shared_yaxes:
            subtitle_pos_y = sorted(y_tops * cols, reverse=True)
        else:
            subtitle_pos_y = sorted(y_tops, reverse=True) * cols

    plot_titles = []
    for index, title in enumerate(subplot_titles):
        if not title:
            # "" is a placeholder keeping later titles correctly indexed
            continue
        plot_titles.append({'y': subtitle_pos_y[index],
                            'xref': 'paper',
                            'x': subtitle_pos_x[index],
                            'yref': 'paper',
                            'text': title,
                            'showarrow': False,
                            'font': dict(size=16),
                            'xanchor': 'center',
                            'yanchor': 'bottom'
                            })

    # Only touch the layout when there is at least one title to show
    if plot_titles:
        layout['annotations'] = plot_titles

    if print_grid:
        print(grid_str)

    fig = graph_objs.Figure(layout=layout)

    fig.__dict__['_grid_ref'] = grid_ref
    fig.__dict__['_grid_str'] = grid_str

    return fig
def create_trisurf(x,
                   y,
                   z,
                   simplices,
                   colormap=None,
                   show_colorbar=True,
                   scale=None,
                   color_func=None,
                   title='Trisurf Plot',
                   plot_edges=True,
                   showbackground=True,
                   backgroundcolor='rgb(230, 230, 230)',
                   gridcolor='rgb(255, 255, 255)',
                   zerolinecolor='rgb(255, 255, 255)',
                   edges_color='rgb(50, 50, 50)',
                   height=800,
                   width=800,
                   aspectratio=None):
    """
    Build the figure for a triangulated surface (trisurf) plot.

    The colormap is validated and converted by the ``colors`` helpers, the
    traces are produced by ``trisurf`` and the result is wrapped in a
    ``graph_objs.Figure`` whose 3D scene shares one styling dict across the
    x, y and z axes.

    :param (array) x: data values of x in a 1D array
    :param (array) y: data values of y in a 1D array
    :param (array) z: data values of z in a 1D array
    :param (array) simplices: an array of shape (ntri, 3) where ntri is
        the number of triangles in the triangulation. Each row of the
        array contains the indices of the vertices of one triangle
    :param (str|tuple|list) colormap: either a plotly scale name, an rgb
        or hex color, a color tuple or a list of colors. An rgb color is
        of the form 'rgb(x, y, z)' where x, y, z belong to the interval
        [0, 255] and a color tuple is a tuple of the form (a, b, c) where
        a, b and c belong to [0, 1]. If colormap is a list, its members
        must be of the valid color types listed above
    :param (bool) show_colorbar: determines if the colorbar is visible
    :param (list|array) scale: sets the scale values to be used if a
        non-linearly interpolated colormap is desired. If left as None, a
        linear interpolation between the colors will be executed
    :param (function|list) color_func: the parameter that determines the
        coloring of the surface. Takes either a function with 3 arguments
        x, y, z or a list/array of color values the same length as
        simplices. If None, coloring will only depend on the z axis
    :param (str) title: title of the plot
    :param (bool) plot_edges: determines if the triangles on the trisurf
        are visible
    :param (bool) showbackground: makes the background of the plot visible
    :param (str) backgroundcolor: color of the background. Takes a string
        of the form 'rgb(x,y,z)' where x, y, z are between 0 and 255
        inclusive
    :param (str) gridcolor: color of the gridlines besides the axes. Takes
        a string of the form 'rgb(x,y,z)' where x, y, z are between 0 and
        255 inclusive
    :param (str) zerolinecolor: color of the axes. Takes a string of the
        form 'rgb(x,y,z)' where x, y, z are between 0 and 255 inclusive
    :param (str) edges_color: color of the edges, if plot_edges is True
    :param (int|float) height: the height of the plot (in pixels)
    :param (int|float) width: the width of the plot (in pixels)
    :param (dict) aspectratio: a dictionary of the aspect ratio values for
        the x, y and z axes; the keys 'x', 'y' and 'z' are all read and
        take (int|float) values. When None, a 1:1:1 ratio is used

    :return: a graph_objs.Figure holding the trisurf data and its layout

    Example 1: Sphere
    ```
    # Necessary imports for trisurf
    import numpy as np
    from scipy.spatial import Delaunay

    import plotly.plotly as py
    from plotly.figure_factory import create_trisurf

    # Make data for the plot
    u = np.linspace(0, 2 * np.pi, 20)
    v = np.linspace(0, np.pi, 20)
    u, v = np.meshgrid(u, v)
    u = u.flatten()
    v = v.flatten()

    x = np.sin(v) * np.cos(u)
    y = np.sin(v) * np.sin(u)
    z = np.cos(v)

    points2D = np.vstack([u, v]).T
    tri = Delaunay(points2D)
    simplices = tri.simplices

    # Create the figure
    fig1 = create_trisurf(x=x, y=y, z=z, colormap="Rainbow",
                          simplices=simplices)
    # Plot the data
    py.iplot(fig1, filename='trisurf-plot-sphere')
    ```

    Example 2: Torus
    ```
    # Necessary imports for trisurf
    import numpy as np
    from scipy.spatial import Delaunay

    import plotly.plotly as py
    from plotly.figure_factory import create_trisurf

    # Make data for the plot
    u = np.linspace(0, 2 * np.pi, 20)
    v = np.linspace(0, 2 * np.pi, 20)
    u, v = np.meshgrid(u, v)
    u = u.flatten()
    v = v.flatten()

    x = (3 + np.cos(v)) * np.cos(u)
    y = (3 + np.cos(v)) * np.sin(u)
    z = np.sin(v)

    points2D = np.vstack([u, v]).T
    tri = Delaunay(points2D)
    simplices = tri.simplices

    # Create the figure
    fig1 = create_trisurf(x=x, y=y, z=z, colormap="Viridis",
                          simplices=simplices)
    # Plot the data
    py.iplot(fig1, filename='trisurf-plot-torus')
    ```

    Example 3: Mobius Band
    ```
    # Necessary imports for trisurf
    import numpy as np
    from scipy.spatial import Delaunay

    import plotly.plotly as py
    from plotly.figure_factory import create_trisurf

    # Make data for the plot
    u = np.linspace(0, 2 * np.pi, 24)
    v = np.linspace(-1, 1, 8)
    u, v = np.meshgrid(u, v)
    u = u.flatten()
    v = v.flatten()

    tp = 1 + 0.5 * v * np.cos(u / 2.)
    x = tp * np.cos(u)
    y = tp * np.sin(u)
    z = 0.5 * v * np.sin(u / 2.)

    points2D = np.vstack([u, v]).T
    tri = Delaunay(points2D)
    simplices = tri.simplices

    # Create the figure
    fig1 = create_trisurf(x=x, y=y, z=z,
                          colormap=[(0.2, 0.4, 0.6), (1, 1, 1)],
                          simplices=simplices)
    # Plot the data
    py.iplot(fig1, filename='trisurf-plot-mobius-band')
    ```

    Example 4: Using a Custom Colormap Function with Light Cone
    ```
    # Necessary imports for trisurf
    import numpy as np
    from scipy.spatial import Delaunay

    import plotly.plotly as py
    from plotly.figure_factory import create_trisurf

    # Make data for the plot
    u = np.linspace(-np.pi, np.pi, 30)
    v = np.linspace(-np.pi, np.pi, 30)
    u, v = np.meshgrid(u, v)
    u = u.flatten()
    v = v.flatten()

    x = u
    y = u * np.cos(v)
    z = u * np.sin(v)

    points2D = np.vstack([u, v]).T
    tri = Delaunay(points2D)
    simplices = tri.simplices

    # Define the distance function
    def dist_origin(x, y, z):
        return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2)

    # Create the figure
    fig1 = create_trisurf(x=x, y=y, z=z,
                          colormap=['#FFFFFF', '#E4FFFE',
                                    '#A4F6F9', '#FF99FE',
                                    '#BA52ED'],
                          scale=[0, 0.6, 0.71, 0.89, 1],
                          simplices=simplices,
                          color_func=dist_origin)
    # Plot the data
    py.iplot(fig1, filename='trisurf-plot-custom-coloring')
    ```

    Example 5: Enter color_func as a list of colors
    ```
    # Necessary imports for trisurf
    import numpy as np
    from scipy.spatial import Delaunay
    import random

    import plotly.plotly as py
    from plotly.figure_factory import create_trisurf

    # Make data for the plot
    u = np.linspace(-np.pi, np.pi, 30)
    v = np.linspace(-np.pi, np.pi, 30)
    u, v = np.meshgrid(u, v)
    u = u.flatten()
    v = v.flatten()

    x = u
    y = u * np.cos(v)
    z = u * np.sin(v)

    points2D = np.vstack([u, v]).T
    tri = Delaunay(points2D)
    simplices = tri.simplices

    # One randomly chosen color per triangle
    colors = []
    color_choices = ['rgb(0, 0, 0)', '#6c4774', '#d6c7dd']

    for index in range(len(simplices)):
        colors.append(random.choice(color_choices))

    fig = create_trisurf(
        x, y, z, simplices,
        color_func=colors,
        show_colorbar=True,
        edges_color='rgb(2, 85, 180)',
        title=' Modern Art'
    )

    py.iplot(fig, filename="trisurf-plot-modern-art")
    ```
    """
    # A missing aspect ratio means equal proportions on all three axes
    ratio = {'x': 1, 'y': 1, 'z': 1} if aspectratio is None else aspectratio

    # Check the colormap first, then have it (and the scale) converted with
    # the 'tuple' colortype requested
    colors.validate_colors(colormap)
    colormap, scale = colors.convert_colors_to_same_type(
        colormap, colortype='tuple', return_default_colors=True, scale=scale)

    surface_data = trisurf(x, y, z, simplices,
                           show_colorbar=show_colorbar,
                           color_func=color_func,
                           colormap=colormap,
                           scale=scale,
                           edges_color=edges_color,
                           plot_edges=plot_edges)

    # The same styling dict is handed to each of the three scene axes
    axis_style = {
        'showbackground': showbackground,
        'backgroundcolor': backgroundcolor,
        'gridcolor': gridcolor,
        'zerolinecolor': zerolinecolor,
    }
    scene = graph_objs.Scene(
        xaxis=graph_objs.XAxis(axis_style),
        yaxis=graph_objs.YAxis(axis_style),
        zaxis=graph_objs.ZAxis(axis_style),
        aspectratio={'x': ratio['x'], 'y': ratio['y'], 'z': ratio['z']},
    )
    figure_layout = graph_objs.Layout(title=title,
                                      width=width,
                                      height=height,
                                      scene=scene)

    return graph_objs.Figure(data=surface_data, layout=figure_layout)
예제 #19
0
def create_candlestick(open, high, low, close, dates=None, direction='both',
                       **kwargs):
    """
    BETA function that creates a candlestick chart

    :param (list) open: opening values
    :param (list) high: high values
    :param (list) low: low values
    :param (list) close: closing values
    :param (list) dates: list of datetime objects. Default: None
    :param (string) direction: direction can be 'increasing', 'decreasing',
        or 'both'. When the direction is 'increasing', the returned figure
        consists of all candlesticks where the close value is greater than
        the corresponding open value, and when the direction is
        'decreasing', the returned figure consists of all candlesticks
        where the close value is less than or equal to the corresponding
        open value. When the direction is 'both', both increasing and
        decreasing candlesticks are returned. The value is compared by
        equality; any value other than 'increasing' or 'decreasing' is
        treated as 'both'. Default: 'both'
    :param kwargs: forwarded unchanged to make_increasing_candle() and
        make_decreasing_candle(). Per the original documentation these
        describe other attributes of the underlying Scatter traces, such
        as the color or the legend name. For more information on valid
        kwargs call help(plotly.graph_objs.Scatter)

    :rtype (graph_objs.Figure): figure holding the candlestick traces and
        an empty layout.

    Example 1: Simple candlestick chart from a Pandas DataFrame
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_candlestick
    from datetime import datetime

    import pandas.io.data as web

    df = web.DataReader("aapl", 'yahoo', datetime(2007, 10, 1),
                        datetime(2009, 4, 1))
    fig = create_candlestick(df.Open, df.High, df.Low, df.Close,
                             dates=df.index)
    py.plot(fig, filename='finance/aapl-candlestick', validate=False)
    ```

    Example 2: Add text and annotations to the candlestick chart
    ```
    fig = create_candlestick(df.Open, df.High, df.Low, df.Close,
                             dates=df.index)
    # Update the fig - all options here:
    # https://plot.ly/python/reference/#Layout
    fig['layout'].update({
        'title': 'The Great Recession',
        'yaxis': {'title': 'AAPL Stock'},
        'shapes': [{
            'x0': '2007-12-01', 'x1': '2007-12-01',
            'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper',
            'line': {'color': 'rgb(30,30,30)', 'width': 1}
        }],
        'annotations': [{
            'x': '2007-12-01', 'y': 0.05, 'xref': 'x', 'yref': 'paper',
            'showarrow': False, 'xanchor': 'left',
            'text': 'Official start of the recession'
        }]
    })
    py.plot(fig, filename='finance/aapl-recession-candlestick',
            validate=False)
    ```

    Example 3: Customize the candlestick colors
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_candlestick
    from plotly.graph_objs import Line, Marker
    from datetime import datetime

    import pandas.io.data as web

    df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1),
                        datetime(2009, 4, 1))

    # Make increasing candlesticks and customize their color and name
    fig_increasing = create_candlestick(
        df.Open, df.High, df.Low, df.Close, dates=df.index,
        direction='increasing', name='AAPL',
        marker=Marker(color='rgb(150, 200, 250)'),
        line=Line(color='rgb(150, 200, 250)'))

    # Make decreasing candlesticks and customize their color and name
    fig_decreasing = create_candlestick(
        df.Open, df.High, df.Low, df.Close, dates=df.index,
        direction='decreasing',
        marker=Marker(color='rgb(128, 128, 128)'),
        line=Line(color='rgb(128, 128, 128)'))

    # Initialize the figure
    fig = fig_increasing

    # Add decreasing data with .extend()
    fig['data'].extend(fig_decreasing['data'])

    py.iplot(fig, filename='finance/aapl-candlestick-custom', validate=False)
    ```

    Example 4: Candlestick chart with datetime objects
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_candlestick

    from datetime import datetime

    # Add data
    open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
    high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
    low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
    close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
    dates = [datetime(year=2013, month=10, day=10),
             datetime(year=2013, month=11, day=10),
             datetime(year=2013, month=12, day=10),
             datetime(year=2014, month=1, day=10),
             datetime(year=2014, month=2, day=10)]

    # Create ohlc
    fig = create_candlestick(open_data, high_data,
                             low_data, close_data, dates=dates)

    py.iplot(fig, filename='finance/simple-candlestick', validate=False)
    ```
    """
    # NOTE(review): input validation is disabled in this copy of the
    # function, so the series are handed to the helpers unchecked. The
    # original calls are kept below for reference - confirm whether they
    # should be re-enabled.
    # if dates is not None:
    #     utils.validate_equal_length(open, high, low, close, dates)
    # else:
    #     utils.validate_equal_length(open, high, low, close)
    # validate_ohlc(open, high, low, close, direction, **kwargs)

    # Compare by value: an identity test (`is`) only matches strings that
    # happen to be the same interned object, so an equal string built at
    # runtime would otherwise silently fall through to the 'both' branch.
    if direction == 'increasing':
        data = make_increasing_candle(open, high, low, close,
                                      dates, **kwargs)
    elif direction == 'decreasing':
        data = make_decreasing_candle(open, high, low, close,
                                      dates, **kwargs)
    else:
        # 'both' (and any unrecognised value): increasing traces first,
        # then the decreasing ones
        candle_incr_data = make_increasing_candle(open, high, low, close,
                                                  dates, **kwargs)
        candle_decr_data = make_decreasing_candle(open, high, low, close,
                                                  dates, **kwargs)
        data = candle_incr_data + candle_decr_data

    layout = graph_objs.Layout()
    return graph_objs.Figure(data=data, layout=layout)
예제 #20
0
def gantt(chart,
          colors,
          title,
          bar_width,
          showgrid_x,
          showgrid_y,
          height,
          width,
          tasks=None,
          task_names=None,
          data=None,
          group_tasks=False):
    """
    Assemble the Gantt figure; see create_gantt() for the public docs.

    :param chart: records to plot; each one is read through the keys
        'Start', 'Finish' and 'Task', plus the optional 'Description'
        (assumed to be a sequence of dicts - confirm against the caller)
    :param colors: indexable collection of fill colors, reused from the
        start once every entry has been handed out
    :param title: title placed in the layout
    :param bar_width: half of a bar's height; each bar spans
        row - bar_width to row + bar_width on the y axis
    :param showgrid_x: value for the x axis 'showgrid' setting
    :param showgrid_y: value for the y axis 'showgrid' setting
    :param height: layout height
    :param width: layout width
    :param tasks: optional list that receives the rectangle shapes; it is
        extended and its items are modified in place
    :param task_names: optional list that receives the y tick labels; it
        is extended (and reversed when group_tasks is set) in place
    :param data: optional list that receives one hover trace per task
    :param group_tasks: when True, tasks sharing a name are drawn on the
        same row instead of one row per task

    :return: graph_objs.Figure built from the hover traces and the layout
    """
    tasks = [] if tasks is None else tasks
    task_names = [] if task_names is None else task_names
    data = [] if data is None else data

    # Every shape starts from these rectangle attributes
    shape_template = {
        'type': 'rect',
        'xref': 'x',
        'yref': 'y',
        'opacity': 1,
        'line': {
            'width': 0,
        }
    }

    # Turn each chart record into the beginnings of a shape
    for record in chart:
        shape = dict(x0=record['Start'],
                     x1=record['Finish'],
                     name=record['Task'])
        if 'Description' in record:
            shape['description'] = record['Description']
        tasks.append(shape)

    # Collect the row labels: every name when ungrouped, otherwise only
    # the first occurrence of each name
    for shape in tasks:
        label = shape['name']
        if not group_tasks or label not in task_names:
            task_names.append(label)
    # Reversing puts the tasks that were inserted first at the top of a
    # grouped chart
    if group_tasks:
        task_names.reverse()

    color_index = 0
    for position, shape in enumerate(tasks):
        label = shape.pop('name')
        shape.update(shape_template)

        # Grouped tasks share the row of their name; otherwise each task
        # gets a row of its own
        row = task_names.index(label) if group_tasks else position
        shape['y0'] = row - bar_width
        shape['y1'] = row + bar_width

        # Start over once the colors have all been used
        if color_index >= len(colors):
            color_index = 0
        shape['fillcolor'] = colors[color_index]
        color_index += 1

        # A white line across the bar carries the hover text and lets the
        # axes autorange
        hover_trace = dict(x=[shape['x0'], shape['x1']],
                           y=[row, row],
                           name='',
                           marker={'color': 'white'})
        if "description" in shape:
            hover_trace['text'] = shape.pop('description')
        data.append(hover_trace)

    range_buttons = [
        dict(count=7, label='1w', step='day', stepmode='backward'),
        dict(count=1, label='1m', step='month', stepmode='backward'),
        dict(count=6, label='6m', step='month', stepmode='backward'),
        dict(count=1, label='YTD', step='year', stepmode='todate'),
        dict(count=1, label='1y', step='year', stepmode='backward'),
        dict(step='all'),
    ]
    y_axis = dict(
        showgrid=showgrid_y,
        ticktext=task_names,
        tickvals=list(range(len(task_names))),
        range=[-1, len(task_names) + 1],
        autorange=False,
        zeroline=False,
    )
    x_axis = dict(
        showgrid=showgrid_x,
        zeroline=False,
        rangeselector=dict(buttons=range_buttons),
        type='date')
    layout = dict(
        title=title,
        showlegend=False,
        height=height,
        width=width,
        shapes=tasks,
        hovermode='closest',
        yaxis=y_axis,
        xaxis=x_axis)

    return graph_objs.Figure(data=data, layout=layout)
예제 #21
0
def create_table(table_text,
                 colorscale=None,
                 font_colors=None,
                 index=False,
                 index_title="",
                 annotation_offset=0.45,
                 height_constant=30,
                 hoverinfo="none",
                 **kwargs):
    """
    BETA function that creates data tables

    The table is drawn as a heatmap trace (one cell per table entry) with
    the cell text supplied as layout annotations.

    :param (pandas.Dataframe | list[list]) table_text: data for the table.
    :param (str|list[list]) colorscale: Colorscale for the table where the
        color at value 0 is the header color, .5 is the first table color
        and 1 is the second table color. (Set .5 and 1 to the same color
        to avoid the striped table effect).
        Default=[[0, '#00083e'], [.5, '#ededee'], [1, '#ffffff']]
    :param (list) font_colors: Color for fonts in the table. Can be a
        single color, three colors, or a color for each row in the table.
        Default=['#ffffff', '#000000', '#000000']
    :param (bool) index: Create a (header-colored) index column from the
        Pandas dataframe index, or from list[0] of each list in
        table_text. Default=False.
    :param (string) index_title: Title for the index column. Default=''.
    :param (float) annotation_offset: forwarded to _Table; presumably the
        offset applied when positioning the cell annotations - confirm
        against _Table. Default=0.45.
    :param (int) height_constant: Constant multiplied by the number of
        rows to create the table height (50 is added on top).
        Default=30.
    :param (string) hoverinfo: value stored under 'hoverinfo' on the
        heatmap trace. Default='none'.
    :param kwargs: forwarded to _Table and merged into the heatmap trace.
        These kwargs describe other attributes of the annotated Heatmap
        trace. For more information on valid kwargs call
        help(plotly.graph_objs.Heatmap)

    :return: graph_objs.Figure holding the heatmap trace and its layout

    Example 1: Simple Plotly Table

    >>> from plotly.figure_factory import create_table

    >>> text = [['Country', 'Year', 'Population'],
    ...         ['US', 2000, 282200000],
    ...         ['Canada', 2000, 27790000],
    ...         ['US', 2010, 309000000],
    ...         ['Canada', 2010, 34000000]]

    >>> table = create_table(text)
    >>> table.show()

    Example 2: Table with Custom Coloring

    >>> from plotly.figure_factory import create_table

    >>> text = [['Country', 'Year', 'Population'],
    ...         ['US', 2000, 282200000],
    ...         ['Canada', 2000, 27790000],
    ...         ['US', 2010, 309000000],
    ...         ['Canada', 2010, 34000000]]

    >>> table = create_table(text,
    ...                      colorscale=[[0, '#000000'],
    ...                                  [.5, '#80beff'],
    ...                                  [1, '#cce5ff']],
    ...                      font_colors=['#ffffff', '#000000',
    ...                                 '#000000'])
    >>> table.show()

    Example 3: Simple Plotly Table with Pandas

    >>> from plotly.figure_factory import create_table
    >>> import pandas as pd

    >>> df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\\t')
    >>> df_p = df[0:25]

    >>> table_simple = create_table(df_p)
    >>> table_simple.show()
    """
    # The list defaults are created per call rather than in the signature
    if colorscale is None:
        colorscale = [[0, "#00083e"], [0.5, "#ededee"], [1, "#ffffff"]]
    if font_colors is None:
        font_colors = ["#ffffff", "#000000", "#000000"]

    validate_table(table_text, font_colors)

    # A separate _Table is built for the matrix and for the annotations
    table_args = (table_text, colorscale, font_colors, index,
                  index_title, annotation_offset)
    table_matrix = _Table(*table_args, **kwargs).get_table_matrix()
    annotations = _Table(*table_args, **kwargs).make_table_annotations()

    heatmap = dict(type="heatmap",
                   z=table_matrix,
                   opacity=0.75,
                   colorscale=colorscale,
                   showscale=False,
                   hoverinfo=hoverinfo,
                   **kwargs)

    # Both axes hide their tick labels and draw grid lines on the cell
    # borders; the y axis is reversed so the first row sits at the top
    y_axis = dict(
        autorange="reversed",
        zeroline=False,
        gridwidth=2,
        ticks="",
        dtick=1,
        tick0=0.5,
        showticklabels=False,
    )
    x_axis = dict(
        zeroline=False,
        gridwidth=2,
        ticks="",
        dtick=1,
        tick0=-0.5,
        showticklabels=False,
    )
    layout = dict(
        annotations=annotations,
        height=len(table_matrix) * height_constant + 50,
        margin=dict(t=0, b=0, r=0, l=0),
        yaxis=y_axis,
        xaxis=x_axis,
    )
    return graph_objs.Figure(data=[heatmap], layout=layout)
예제 #22
0
파일: tools.py 프로젝트: bmswgnp/python-api
def make_subplots(rows=1, cols=1,
                  shared_xaxes=False, shared_yaxes=False,
                  start_cell='top-left', print_grid=True,
                  **kwargs):
    """Return an instance of plotly.graph_objs.Figure
    with the subplots domain set in 'layout'.

    Example 1:
    # stack two subplots vertically
    fig = tools.make_subplots(rows=2)

    This is the format of your plot grid:
    [ (1,1) x1,y1 ]
    [ (2,1) x2,y2 ]

    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])]
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')]

    # or see Figure.append_trace

    Example 2:
    # subplots with shared x axes
    fig = tools.make_subplots(rows=2, shared_xaxes=True)

    This is the format of your plot grid:
    [ (1,1) x1,y1 ]
    [ (2,1) x1,y2 ]


    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])]
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], yaxis='y2')]

    Example 3:
    # irregular subplot layout (more examples below under 'specs')
    fig = tools.make_subplots(rows=2, cols=2,
                              specs=[[{}, {}],
                                     [{'colspan': 2}, None]])

    This is the format of your plot grid!
    [ (1,1) x1,y1 ]  [ (1,2) x2,y2 ]
    [ (2,1) x3,y3           -      ]

    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])]
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')]
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x3', yaxis='y3')]

    Example 4:
    # insets
    fig = tools.make_subplots(insets=[{'cell': (1,1), 'l': 0.7, 'b': 0.3}])

    This is the format of your plot grid!
    [ (1,1) x1,y1 ]

    With insets:
    [ x2,y2 ] over [ (1,1) x1,y1 ]

    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])]
    fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')]

    Keywords arguments with constant defaults:

    rows (kwarg, int greater than 0, default=1):
        Number of rows in the subplot grid.

    cols (kwarg, int greater than 0, default=1):
        Number of columns in the subplot grid.

    shared_xaxes (kwarg, boolean or list, default=False)
        Assign shared x axes.
        If True, subplots in the same grid column have one common
        shared x-axis at the bottom of the grid.

        To assign shared x axes per subplot grid cell (see 'specs'),
        send list (or list of lists, one list per shared x axis)
        of cell index tuples.

    shared_yaxes (kwarg, boolean or list, default=False)
        Assign shared y axes.
        If True, subplots in the same grid row have one common
        shared y-axis on the left-hand side of the grid.

        To assign shared y axes per subplot grid cell (see 'specs'),
        send list (or list of lists, one list per shared y axis)
        of cell index tuples.

    start_cell (kwarg, 'bottom-left' or 'top-left', default='top-left')
        Choose the starting cell in the subplot grid used to set the
        domains of the subplots.

    print_grid (kwarg, boolean, default=True):
        If True, prints a tab-delimited string representation of
        your plot grid.

    Keyword arguments with variable defaults:

    horizontal_spacing (kwarg, float in [0,1], default=0.2 / cols):
        Space between subplot columns.
        Applies to all columns (use 'specs' subplot-dependents spacing)

    vertical_spacing (kwarg, float in [0,1], default=0.3 / rows):
        Space between subplot rows.
        Applies to all rows (use 'specs' subplot-dependents spacing)

    specs (kwarg, list of lists of dictionaries):
        Subplot specifications.

        ex1: specs=[[{}, {}], [{'colspan': 2}, None]]

        ex2: specs=[[{'rowspan': 2}, {}], [None, {}]]

        - Indices of the outer list correspond to subplot grid rows
          starting from the bottom. The number of rows in 'specs'
          must be equal to 'rows'.

        - Indices of the inner lists correspond to subplot grid columns
          starting from the left. The number of columns in 'specs'
          must be equal to 'cols'.

        - Each item in the 'specs' list corresponds to one subplot
          in a subplot grid. (N.B. The subplot grid has exactly 'rows'
          times 'cols' cells.)

        - Use None for a blank subplot cell (or to move past a col/row span).

        - Note that specs[0][0] has the specs of the 'start_cell' subplot.

        - Each item in 'specs' is a dictionary.
            The available keys are:

            * is_3d (boolean, default=False): flag for 3d scenes
            * colspan (int, default=1): number of subplot columns
                for this subplot to span.
            * rowspan (int, default=1): number of subplot rows
                for this subplot to span.
            * l (float, default=0.0): padding left of cell
            * r (float, default=0.0): padding right of cell
            * t (float, default=0.0): padding top of cell
            * b (float, default=0.0): padding bottom of cell

        - Use 'horizontal_spacing' and 'vertical_spacing' to adjust
          the spacing in between the subplots.

    insets (kwarg, list of dictionaries):
        Inset specifications.

        - Each item in 'insets' is a dictionary.
            The available keys are:

            * cell (tuple, default=(1,1)): (row, col) index of the
                subplot cell to overlay inset axes onto.
            * is_3d (boolean, default=False): flag for 3d scenes
            * l (float, default=0.0): padding left of inset
                  in fraction of cell width
            * w (float or 'to_end', default='to_end') inset width
                  in fraction of cell width ('to_end': to cell right edge)
            * b (float, default=0.0): padding bottom of inset
                  in fraction of cell height
            * h (float or 'to_end', default='to_end') inset height
                  in fraction of cell height ('to_end': to cell top edge)
    """

    # Throw exception for non-integer rows and cols
    if not isinstance(rows, int) or rows <= 0:
        raise Exception("Keyword argument 'rows' "
                        "must be an int greater than 0")
    if not isinstance(cols, int) or cols <= 0:
        raise Exception("Keyword argument 'cols' "
                        "must be an int greater than 0")

    # Dictionary of things start_cell
    START_CELL_all = {
        'bottom-left': {
            # 'natural' setup where x & y domains increase monotonically
            'col_dir': 1,
            'row_dir': 1
        },
        'top-left': {
            # 'default' setup visually matching the 'specs' list of lists
            'col_dir': 1,
            'row_dir': -1
        }
        # TODO maybe add 'bottom-right' and 'top-right'
    }

    # Throw exception for invalid 'start_cell' values
    try:
        START_CELL = START_CELL_all[start_cell]
    except KeyError:
        raise Exception("Invalid 'start_cell' value")

    # Throw exception if non-valid kwarg is sent
    VALID_KWARGS = ['horizontal_spacing', 'vertical_spacing',
                    'specs', 'insets']
    for key in kwargs.keys():
        if key not in VALID_KWARGS:
            raise Exception("Invalid keyword argument: '{0}'".format(key))

    # Set 'horizontal_spacing' / 'vertical_spacing' w.r.t. rows / cols
    try:
        horizontal_spacing = float(kwargs['horizontal_spacing'])
    except KeyError:
        horizontal_spacing = 0.2 / cols
    try:
        vertical_spacing = float(kwargs['vertical_spacing'])
    except KeyError:
        vertical_spacing = 0.3 / rows

    # Sanitize 'specs' (must be a list of lists)
    exception_msg = "Keyword argument 'specs' must be a list of lists"
    try:
        specs = kwargs['specs']
        if not isinstance(specs, list):
            raise Exception(exception_msg)
        else:
            for spec_row in specs:
                if not isinstance(spec_row, list):
                    raise Exception(exception_msg)
    except KeyError:
        specs = [[{}
                 for c in range(cols)]
                 for r in range(rows)]     # default 'specs'

    # Throw exception if specs is over or under specified
    if len(specs) != rows:
        raise Exception("The number of rows in 'specs' "
                        "must be equal to 'rows'")
    for r, spec_row in enumerate(specs):
        if len(spec_row) != cols:
            raise Exception("The number of columns in 'specs' "
                            "must be equal to 'cols'")

    # Sanitize 'insets'
    try:
        insets = kwargs['insets']
        if not isinstance(insets, list):
            raise Exception("Keyword argument 'insets' must be a list")
    except KeyError:
        insets = False

    # Throw exception if non-valid key / fill in defaults
    def _check_keys_and_fill(name, arg, defaults):
        def _checks(item, defaults):
            if item is None:
                return
            if not isinstance(item, dict):
                raise Exception("Items in keyword argument '{name}' must be "
                                "dictionaries or None".format(name=name))
            for k in item.keys():
                if k not in defaults.keys():
                    raise Exception("Invalid key '{k}' in keyword "
                                    "argument '{name}'".format(k=k, name=name))
            for k in defaults.keys():
                if k not in item.keys():
                    item[k] = defaults[k]
        for arg_i in arg:
            if isinstance(arg_i, list):
                for arg_ii in arg_i:
                    _checks(arg_ii, defaults)
            elif isinstance(arg_i, dict):
                _checks(arg_i, defaults)

    # Default spec key-values
    SPEC_defaults = dict(
        is_3d=False,
        colspan=1,
        rowspan=1,
        l=0.0,
        r=0.0,
        b=0.0,
        t=0.0
        # TODO add support for 'w' and 'h'
    )
    _check_keys_and_fill('specs', specs, SPEC_defaults)

    # Default inset key-values
    if insets:
        INSET_defaults = dict(
            cell=(1, 1),
            is_3d=False,
            l=0.0,
            w='to_end',
            b=0.0,
            h='to_end'
        )
        _check_keys_and_fill('insets', insets, INSET_defaults)

    # Set width & height of each subplot cell (excluding padding)
    width = (1. - horizontal_spacing * (cols - 1)) / cols
    height = (1. - vertical_spacing * (rows - 1)) / rows

    # Built row/col sequence using 'row_dir' and 'col_dir'
    COL_DIR = START_CELL['col_dir']
    ROW_DIR = START_CELL['row_dir']
    col_seq = range(cols)[::COL_DIR]
    row_seq = range(rows)[::ROW_DIR]

    # [grid] Build subplot grid (coord tuple of cell)
    grid = [[((width + horizontal_spacing) * c,
              (height + vertical_spacing) * r)
            for c in col_seq]
            for r in row_seq]

    # [grid_ref] Initialize the grid and insets' axis-reference lists
    grid_ref = [[None for c in range(cols)] for r in range(rows)]
    insets_ref = [None for inset in range(len(insets))] if insets else None

    layout = graph_objs.Layout()  # init layout object

    # Function handling logic around 2d axis labels
    # Returns 'x{}' | 'y{}'
    def _get_label(x_or_y, r, c, cnt, shared_axes):
        # Default label (given strictly by cnt)
        label = "{x_or_y}{cnt}".format(x_or_y=x_or_y, cnt=cnt)

        if isinstance(shared_axes, bool):
            if shared_axes:
                if x_or_y == 'x':
                    label = "{x_or_y}{c}".format(x_or_y=x_or_y, c=c+1)
                if x_or_y == 'y':
                    label = "{x_or_y}{r}".format(x_or_y=x_or_y, r=r+1)

        if isinstance(shared_axes, list):
            if isinstance(shared_axes[0], tuple):
                shared_axes = [shared_axes]  # TODO put this elsewhere
            for shared_axis in shared_axes:
                if (r+1, c+1) in shared_axis:
                    label = {
                        'x': "x{0}".format(shared_axis[0][1]),
                        'y': "y{0}".format(shared_axis[0][0])
                    }[x_or_y]

        return label

    # Row in grid of anchor row if shared_xaxes=True
    ANCHOR_ROW = 0 if ROW_DIR > 0 else rows - 1

    # Function handling logic around 2d axis anchors
    # Return 'x{}' | 'y{}' | 'free' | False
    def _get_anchors(r, c, x_cnt, y_cnt, shared_xaxes, shared_yaxes):
        # Default anchors (give strictly by cnt)
        x_anchor = "y{y_cnt}".format(y_cnt=y_cnt)
        y_anchor = "x{x_cnt}".format(x_cnt=x_cnt)

        if isinstance(shared_xaxes, bool):
            if shared_xaxes:
                if r != ANCHOR_ROW:
                    x_anchor = False
                    y_anchor = 'free'
                    if shared_yaxes and c != 0:  # TODO covers all cases?
                        y_anchor = False
                    return x_anchor, y_anchor

        elif isinstance(shared_xaxes, list):
            if isinstance(shared_xaxes[0], tuple):
                shared_xaxes = [shared_xaxes]  # TODO put this elsewhere
            for shared_xaxis in shared_xaxes:
                if (r+1, c+1) in shared_xaxis[1:]:
                    x_anchor = False
                    y_anchor = 'free'  # TODO covers all cases?

        if isinstance(shared_yaxes, bool):
            if shared_yaxes:
                if c != 0:
                    y_anchor = False
                    x_anchor = 'free'
                    if shared_xaxes and r != ANCHOR_ROW:  # TODO all cases?
                        x_anchor = False
                    return x_anchor, y_anchor

        elif isinstance(shared_yaxes, list):
            if isinstance(shared_yaxes[0], tuple):
                shared_yaxes = [shared_yaxes]  # TODO put this elsewhere
            for shared_yaxis in shared_yaxes:
                if (r+1, c+1) in shared_yaxis[1:]:
                    y_anchor = False
                    x_anchor = 'free'  # TODO covers all cases?

        return x_anchor, y_anchor

    # Function pasting x/y domains in layout object (2d case)
    def _add_domain(layout, x_or_y, label, domain, anchor, position):
        name = label[0] + 'axis' + label[1:]
        graph_obj = '{X_or_Y}Axis'.format(X_or_Y=x_or_y.upper())
        axis = getattr(graph_objs, graph_obj)(domain=domain)
        if anchor:
            axis['anchor'] = anchor
        if isinstance(position, float):
            axis['position'] = position
        layout[name] = axis

    # Function pasting x/y domains in layout object (3d case)
    def _add_domain_is_3d(layout, s_label, x_domain, y_domain):
        scene = graph_objs.Scene(domain={'x': x_domain, 'y': y_domain})
        layout[s_label] = scene

    x_cnt = y_cnt = s_cnt = 1  # subplot axis/scene counters

    # Loop through specs -- (r, c) <-> (row, col)
    for r, spec_row in enumerate(specs):
        for c, spec in enumerate(spec_row):

            if spec is None:  # skip over None cells
                continue

            c_spanned = c + spec['colspan'] - 1  # get spanned c
            r_spanned = r + spec['rowspan'] - 1  # get spanned r

            # Throw exception if 'colspan' | 'rowspan' is too large for grid
            if c_spanned >= cols:
                raise Exception("Some 'colspan' value is too large for "
                                "this subplot grid.")
            if r_spanned >= rows:
                raise Exception("Some 'rowspan' value is too large for "
                                "this subplot grid.")

            # Get x domain using grid and colspan
            x_s = grid[r][c][0] + spec['l']
            x_e = grid[r][c_spanned][0] + width - spec['r']
            x_domain = [x_s, x_e]

            # Get y domain (dep. on row_dir) using grid & r_spanned
            if ROW_DIR > 0:
                y_s = grid[r][c][1] + spec['b']
                y_e = grid[r_spanned][c][1] + height - spec['t']
            else:
                y_s = grid[r_spanned][c][1] + spec['b']
                y_e = grid[r][c][1] + height - spec['t']
            y_domain = [y_s, y_e]

            if spec['is_3d']:

                # Add scene to layout
                s_label = 'scene{0}'.format(s_cnt)
                _add_domain_is_3d(layout, s_label, x_domain, y_domain)
                grid_ref[r][c] = (s_label, )
                s_cnt += 1

            else:

                # Get axis label and anchor
                x_label = _get_label('x', r, c, x_cnt, shared_xaxes)
                y_label = _get_label('y', r, c, y_cnt, shared_yaxes)
                x_anchor, y_anchor = _get_anchors(r, c,
                                                  x_cnt, y_cnt,
                                                  shared_xaxes,
                                                  shared_yaxes)

                # Add a xaxis to layout (N.B anchor == False -> no axis)
                if x_anchor:
                    if x_anchor == 'free':
                        x_position = y_domain[0]
                    else:
                        x_position = False
                    _add_domain(layout, 'x', x_label, x_domain,
                                x_anchor, x_position)
                    x_cnt += 1

                # Add a yaxis to layout (N.B anchor == False -> no axis)
                if y_anchor:
                    if y_anchor == 'free':
                        y_position = x_domain[0]
                    else:
                        y_position = False
                    _add_domain(layout, 'y', y_label, y_domain,
                                y_anchor, y_position)
                    y_cnt += 1

                grid_ref[r][c] = (x_label, y_label)  # fill in ref

    # Loop through insets
    if insets:
        for i_inset, inset in enumerate(insets):

            r = inset['cell'][0] - 1
            c = inset['cell'][1] - 1

            # Throw exception if r | c is out of range
            if not (0 <= r < rows):
                raise Exception("Some 'cell' row value is out of range. "
                                "Note: the starting cell is (1, 1)")
            if not (0 <= c < cols):
                raise Exception("Some 'cell' col value is out of range. "
                                "Note: the starting cell is (1, 1)")

            # Get inset x domain using grid
            x_s = grid[r][c][0] + inset['l'] * width
            if inset['w'] == 'to_end':
                x_e = grid[r][c][0] + width
            else:
                x_e = x_s + inset['w'] * width
            x_domain = [x_s, x_e]

            # Get inset y domain using grid
            y_s = grid[r][c][1] + inset['b'] * height
            if inset['h'] == 'to_end':
                y_e = grid[r][c][1] + height
            else:
                y_e = y_s + inset['h'] * height
            y_domain = [y_s, y_e]

            if inset['is_3d']:

                # Add scene to layout
                s_label = 'scene{0}'.format(s_cnt)
                _add_domain_is_3d(layout, s_label, x_domain, y_domain)
                insets_ref[i_inset] = (s_label, )
                s_cnt += 1

            else:

                # Get axis label and anchor
                x_label = _get_label('x', False, False, x_cnt, False)
                y_label = _get_label('y', False, False, y_cnt, False)
                x_anchor, y_anchor = _get_anchors(r, c,
                                                  x_cnt, y_cnt,
                                                  False, False)

                # Add a xaxis to layout (N.B insets always have anchors)
                _add_domain(layout, 'x', x_label, x_domain, x_anchor, False)
                x_cnt += 1

                # Add a yaxis to layout (N.B insets always have anchors)
                _add_domain(layout, 'y', y_label, y_domain, y_anchor, False)
                y_cnt += 1

                insets_ref[i_inset] = (x_label, y_label)  # fill in ref

    # [grid_str] Set the grid's string representation
    sp = "  "            # space between cell
    s_str = "[ "         # cell start string
    e_str = " ]"         # cell end string
    colspan_str = '       -'     # colspan string
    rowspan_str = '       |'     # rowspan string
    empty_str = '    (empty) '   # empty cell string

    # Init grid_str with intro message
    grid_str = "This is the format of your plot grid:\n"

    # Init tmp list of lists of strings (sorta like 'grid_ref' but w/ strings)
    _tmp = [['' for c in range(cols)] for r in range(rows)]

    # Define cell string as function of (r, c) and grid_ref
    def _get_cell_str(r, c, ref):
        return '({r},{c}) {ref}'.format(r=r+1, c=c+1, ref=','.join(ref))

    # Find max len of _cell_str, and define a padding function
    cell_len = max([len(_get_cell_str(r, c, ref))
                    for r, row_ref in enumerate(grid_ref)
                    for c, ref in enumerate(row_ref)
                    if ref]) + len(s_str) + len(e_str)

    def _pad(s, cell_len=cell_len):
        return ' ' * (cell_len - len(s))

    # Loop through specs, fill in _tmp
    for r, spec_row in enumerate(specs):
        for c, spec in enumerate(spec_row):

            ref = grid_ref[r][c]
            if ref is None:
                if _tmp[r][c] == '':
                    _tmp[r][c] = empty_str + _pad(empty_str)
                continue

            cell_str = s_str + _get_cell_str(r, c, ref)

            if spec['colspan'] > 1:
                for cc in range(1, spec['colspan']-1):
                    _tmp[r][c+cc] = colspan_str + _pad(colspan_str)
                _tmp[r][c+spec['colspan']-1] = (
                    colspan_str + _pad(colspan_str + e_str)) + e_str
            else:
                cell_str += e_str

            if spec['rowspan'] > 1:
                for rr in range(1, spec['rowspan']-1):
                    _tmp[r+rr][c] = rowspan_str + _pad(rowspan_str)
                for cc in range(spec['colspan']):
                    _tmp[r+spec['rowspan']-1][c+cc] = (
                        rowspan_str + _pad(rowspan_str))

            _tmp[r][c] = cell_str + _pad(cell_str)

    # Append grid_str using data from _tmp in the correct order
    for r in row_seq[::-1]:
        grid_str += sp.join(_tmp[r]) + '\n'

    # Append grid_str to include insets info
    if insets:
        grid_str += "\nWith insets:\n"
        for i_inset, inset in enumerate(insets):

            r = inset['cell'][0] - 1
            c = inset['cell'][1] - 1
            ref = grid_ref[r][c]

            grid_str += (
                s_str + ','.join(insets_ref[i_inset]) + e_str +
                ' over ' +
                s_str + _get_cell_str(r, c, ref) + e_str + '\n'
            )

    if print_grid:
        print(grid_str)

    fig = graph_objs.Figure(layout=layout)

    fig._grid_ref = grid_ref
    fig._grid_str = grid_str

    return fig
예제 #23
0
def create_ternary_contour(
    coordinates,
    values,
    pole_labels=["a", "b", "c"],
    width=500,
    height=500,
    ncontours=None,
    showscale=False,
    coloring=None,
    colorscale="Bluered",
    linecolor=None,
    title=None,
    interp_mode="ilr",
    showmarkers=False,
):
    """
    Ternary contour plot.

    Parameters
    ----------

    coordinates : list or ndarray
        Barycentric coordinates of shape (2, N) or (3, N) where N is the
        number of data points. The sum of the 3 coordinates is expected
        to be 1 for all data points.
    values : array-like
        Data points of field to be represented as contours.
    pole_labels : list of str, default ['a', 'b', 'c']
        Names of the three poles of the triangle.
    width : int
        Figure width.
    height : int
        Figure height.
    ncontours : int or None
        Number of contours to display (5 is used if None).
    showscale : bool, default False
        If True, a colorbar showing the color scale is displayed.
        Forced to False when ``colorscale`` is None.
    coloring : None or 'lines'
        How to display contour. Filled contours if None, lines if ``lines``.
    colorscale : None or str (Plotly colormap)
        colorscale of the contours.
    linecolor : None or rgb color
        Color used for lines. ``colorscale`` has to be set to None, otherwise
        line colors are determined from ``colorscale``.
    title : str or None
        Title of ternary plot
    interp_mode : 'ilr' (default) or 'cartesian'
        Defines how data are interpolated to compute contours. If 'ilr',
        ILR (Isometric Log-Ratio) of compositional data is performed. If
        'cartesian', contours are determined in Cartesian space.
    showmarkers : bool, default False
        If True, markers corresponding to input compositional points are
        superimposed on contours, using the same colorscale.

    Returns
    -------
    fig : go.Figure
        Figure holding the contour traces, a scatterternary trace of the
        input points and, if ``showscale`` is set, a colorbar-only trace.

    Raises
    ------
    ImportError
        If scipy or scikit-image is not available.

    Examples
    ========

    Example 1: ternary contour plot with filled contours

    >>> import plotly.figure_factory as ff
    >>> import numpy as np
    >>> # Define coordinates
    >>> a, b = np.mgrid[0:1:20j, 0:1:20j]
    >>> mask = a + b <= 1
    >>> a = a[mask].ravel()
    >>> b = b[mask].ravel()
    >>> c = 1 - a - b
    >>> # Values to be displayed as contours
    >>> z = a * b * c
    >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z)
    >>> fig.show()

    It is also possible to give only two barycentric coordinates for each
    point, since the sum of the three coordinates is one:

    >>> fig = ff.create_ternary_contour(np.stack((a, b)), z)


    Example 2: ternary contour plot with line contours

    >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, coloring='lines')

    Example 3: customize number of contours

    >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, ncontours=8)

    Example 4: superimpose contour plot and original data as markers

    >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z, coloring='lines',
    ...                                 showmarkers=True)

    Example 5: customize title and pole labels

    >>> fig = ff.create_ternary_contour(np.stack((a, b, c)), z,
    ...                                 title='Ternary plot',
    ...                                 pole_labels=['clay', 'quartz', 'feldspar'])
    """
    if scipy_interp is None:
        raise ImportError("""\
    The create_ternary_contour figure factory requires the scipy package""")
    sk_measure = optional_imports.get_module("skimage")
    if sk_measure is None:
        raise ImportError("""\
    The create_ternary_contour figure factory requires the scikit-image
    package""")

    # Imported locally so this block stands on its own; numpy is always
    # installed once scipy is.
    import numpy as np

    # `values` is documented as array-like: normalise it so that plain
    # lists/tuples support the .min()/.max() calls below.
    values = np.asarray(values)

    if colorscale is None:
        showscale = False
    if ncontours is None:
        ncontours = 5
    coordinates = _prepare_barycentric_coord(coordinates)
    v_min, v_max = values.min(), values.max()
    grid_z, gr_x, gr_y = _compute_grid(coordinates,
                                       values,
                                       interp_mode=interp_mode)

    layout = _ternary_layout(pole_labels=pole_labels,
                             width=width,
                             height=height,
                             title=title)

    contour_trace, discrete_cm = _contour_trace(
        gr_x,
        gr_y,
        grid_z,
        ncontours=ncontours,
        colorscale=colorscale,
        linecolor=linecolor,
        interp_mode=interp_mode,
        coloring=coloring,
        v_min=v_min,
        v_max=v_max,
    )

    fig = go.Figure(data=contour_trace, layout=layout)

    # The marker trace is always added (it carries the hover information);
    # it is simply made fully transparent when markers are not requested.
    opacity = 1 if showmarkers else 0
    a, b, c = coordinates
    hovertemplate = (pole_labels[0] + ": %{a:.3f}<br>" + pole_labels[1] +
                     ": %{b:.3f}<br>" + pole_labels[2] + ": %{c:.3f}<br>"
                     "z: %{marker.color:.3f}<extra></extra>")

    fig.add_scatterternary(
        a=a,
        b=b,
        c=c,
        mode="markers",
        marker={
            "color": values,
            "colorscale": colorscale,
            "line": {
                "color": "rgb(120, 120, 120)",
                # Marker outline is drawn (width 1) unless contours are lines.
                "width": int(coloring != "lines")
            },
        },
        opacity=opacity,
        hovertemplate=hovertemplate,
    )
    if showscale:
        # Without visible markers the colorbar should reflect the discrete
        # colormap returned alongside the contour traces.
        if not showmarkers:
            colorscale = discrete_cm
        # Dummy trace with no data points, present only to show a colorbar.
        colorbar = {
            "type": "scatterternary",
            "a": [None],
            "b": [None],
            "c": [None],
            "marker": {
                "cmin": v_min,
                "cmax": v_max,
                "colorscale": colorscale,
                "showscale": True,
            },
            "mode": "markers",
        }
        fig.add_trace(colorbar)

    return fig
예제 #24
0
def get_subplots(rows=1,
                 columns=1,
                 horizontal_spacing=0.1,
                 vertical_spacing=0.15,
                 print_grid=False):
    """Return a Figure whose 'layout' holds one x/y axis pair per subplot.

    Subplots are numbered row by row starting at 1; subplot N uses
    'xaxisN' (anchored to 'yN') and 'yaxisN' (anchored to 'xN'). Row 0 gets
    the lowest y-domain.

    Example 1:
        # stack two subplots vertically
        fig = tools.get_subplots(rows=2)
        fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x1', yaxis='y1')]
        fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')]

    Example 2:
        # print out string showing the subplot grid you've put in the layout
        fig = tools.get_subplots(rows=3, columns=2, print_grid=True)

    key (types, default=default):
        description.

    rows (int, default=1):
        Number of rows, evenly spaced vertically on the figure.

    columns (int, default=1):
        Number of columns, evenly spaced horizontally on the figure.

    horizontal_spacing (float in [0,1], default=0.1):
        Space between subplot columns. Applied to all columns.

    vertical_spacing (float in [0,1], default=0.15):
        Space between subplot rows. Applied to all rows.

    print_grid (True | False, default=False):
        If True, prints a tab-delimited string representation of your plot
        grid, with the last row printed first.

    """
    layout = graph_objs.Layout()

    # Size of a single cell once the gaps between cells are removed.
    cell_width = (1 - horizontal_spacing * (columns - 1)) / columns
    cell_height = (1 - vertical_spacing * (rows - 1)) / rows

    cell_no = 0
    for row_idx in range(rows):
        # The y-domain only depends on the row, so compute it once per row.
        y_lo = (cell_height + vertical_spacing) * row_idx
        y_hi = y_lo + cell_height
        for col_idx in range(columns):
            cell_no += 1
            x_lo = (cell_width + horizontal_spacing) * col_idx
            x_hi = x_lo + cell_width

            layout['xaxis{0}'.format(cell_no)] = graph_objs.XAxis(
                domain=[x_lo, x_hi], anchor='y{0}'.format(cell_no))
            layout['yaxis{0}'.format(cell_no)] = graph_objs.YAxis(
                domain=[y_lo, y_hi], anchor='x{0}'.format(cell_no))

    if print_grid:
        print("This is the format of your plot grid!")
        # One text line per row, then emitted last-row-first.
        row_lines = [
            "".join("[{0}]\t".format(row_idx * columns + col_idx + 1)
                    for col_idx in range(columns)) + '\n'
            for row_idx in range(rows)
        ]
        print("".join(reversed(row_lines)))

    # Passing through Figure forces validation of the assembled layout.
    return graph_objs.Figure(dict(layout=layout))
예제 #25
0
def create_dendrogram(X,
                      orientation="bottom",
                      labels=None,
                      colorscale=None,
                      distfun=None,
                      linkagefun=lambda x: sch.linkage(x, 'complete'),
                      hovertext=None,
                      color_threshold=None):
    """
    BETA function that returns a dendrogram Plotly figure object.

    :param (ndarray) X: Matrix of observations as array of arrays
    :param (str) orientation: 'top', 'right', 'bottom', or 'left'
    :param (list) labels: List of axis category labels(observation labels)
    :param (list) colorscale: Optional colorscale for dendrogram tree
    :param (function) distfun: Function to compute the pairwise distance from
                               the observations. Defaults to
                               ``scs.distance.pdist`` when None.
    :param (function) linkagefun: Function to compute the linkage matrix from
                               the pairwise distances
    :param (list[list]) hovertext: List of hovertext for constituent traces of dendrogram
                               clusters
    :param (double) color_threshold: Value at which the separation of clusters will be made

    :raises ImportError: if any of the scipy modules is unavailable
    :raises PlotlyError: if ``X.shape`` does not have exactly two dimensions
    :rtype (graph_objs.Figure): figure built from the dendrogram's data and
        layout

    Example 1: Simple bottom oriented dendrogram
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_dendrogram

    import numpy as np

    X = np.random.rand(10,10)
    dendro = create_dendrogram(X)
    plot_url = py.plot(dendro, filename='simple-dendrogram')

    ```

    Example 2: Dendrogram to put on the left of the heatmap
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_dendrogram

    import numpy as np

    X = np.random.rand(5,5)
    names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
    dendro = create_dendrogram(X, orientation='right', labels=names)
    dendro['layout'].update({'width':700, 'height':500})

    py.iplot(dendro, filename='vertical-dendrogram')
    ```

    Example 3: Dendrogram with Pandas
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_dendrogram

    import numpy as np
    import pandas as pd

    Index= ['A','B','C','D','E','F','G','H','I','J']
    df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
    fig = create_dendrogram(df, labels=Index)
    url = py.plot(fig, filename='pandas-dendrogram')
    ```
    """
    if not scp or not scs or not sch:
        # Adjacent literals instead of a backslash continuation, which used
        # to embed the source indentation in the message.
        raise ImportError("FigureFactory.create_dendrogram requires scipy, "
                          "scipy.spatial and scipy.hierarchy")

    if len(X.shape) != 2:
        # The exception was previously constructed but never raised, so
        # invalid input slipped through to the dendrogram builder.
        raise exceptions.PlotlyError("X should be 2-dimensional array.")

    if distfun is None:
        distfun = scs.distance.pdist

    dendrogram = _Dendrogram(X,
                             orientation,
                             labels,
                             colorscale,
                             distfun=distfun,
                             linkagefun=linkagefun,
                             hovertext=hovertext,
                             color_threshold=color_threshold)

    return graph_objs.Figure(data=dendrogram.data, layout=dendrogram.layout)
예제 #26
0
def create_dendrogram(
    X,
    orientation="bottom",
    labels=None,
    colorscale=None,
    distfun=None,
    linkagefun=lambda x: sch.linkage(x, "complete"),
    hovertext=None,
    color_threshold=None,
):
    """
    Function that returns a dendrogram Plotly figure object.

    See also https://dash.plot.ly/dash-bio/clustergram.

    :param (ndarray) X: Matrix of observations as array of arrays
    :param (str) orientation: 'top', 'right', 'bottom', or 'left'
    :param (list) labels: List of axis category labels(observation labels)
    :param (list) colorscale: Optional colorscale for dendrogram tree
    :param (function) distfun: Function to compute the pairwise distance from
                               the observations. Defaults to
                               ``scs.distance.pdist`` when None.
    :param (function) linkagefun: Function to compute the linkage matrix from
                               the pairwise distances
    :param (list[list]) hovertext: List of hovertext for constituent traces of dendrogram
                               clusters
    :param (double) color_threshold: Value at which the separation of clusters will be made

    :raises ImportError: if any of the scipy modules is unavailable
    :raises PlotlyError: if ``X.shape`` does not have exactly two dimensions
    :rtype (graph_objs.Figure): figure built from the dendrogram's data and
        layout

    Example 1: Simple bottom oriented dendrogram

    >>> from plotly.figure_factory import create_dendrogram

    >>> import numpy as np

    >>> X = np.random.rand(10,10)
    >>> fig = create_dendrogram(X)
    >>> fig.show()

    Example 2: Dendrogram to put on the left of the heatmap

    >>> from plotly.figure_factory import create_dendrogram

    >>> import numpy as np

    >>> X = np.random.rand(5,5)
    >>> names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
    >>> dendro = create_dendrogram(X, orientation='right', labels=names)
    >>> dendro.update_layout({'width':700, 'height':500}) # doctest: +SKIP
    >>> dendro.show()

    Example 3: Dendrogram with Pandas

    >>> from plotly.figure_factory import create_dendrogram

    >>> import numpy as np
    >>> import pandas as pd

    >>> Index= ['A','B','C','D','E','F','G','H','I','J']
    >>> df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
    >>> fig = create_dendrogram(df, labels=Index)
    >>> fig.show()
    """
    if not scp or not scs or not sch:
        # Adjacent literals instead of a backslash continuation, which used
        # to embed the source indentation in the message.
        raise ImportError(
            "FigureFactory.create_dendrogram requires scipy, "
            "scipy.spatial and scipy.hierarchy"
        )

    if len(X.shape) != 2:
        # The exception was previously constructed but never raised, so
        # invalid input slipped through to the dendrogram builder.
        raise exceptions.PlotlyError("X should be 2-dimensional array.")

    if distfun is None:
        distfun = scs.distance.pdist

    dendrogram = _Dendrogram(
        X,
        orientation,
        labels,
        colorscale,
        distfun=distfun,
        linkagefun=linkagefun,
        hovertext=hovertext,
        color_threshold=color_threshold,
    )

    return graph_objs.Figure(data=dendrogram.data, layout=dendrogram.layout)
예제 #27
0
def create_quiver(x,
                  y,
                  u,
                  v,
                  scale=0.1,
                  arrow_scale=0.3,
                  angle=math.pi / 9,
                  scaleratio=None,
                  **kwargs):
    """
    Build a quiver (vector field) figure drawn as a single line trace.

    :param (list|ndarray) x: x coordinates of the arrow locations
    :param (list|ndarray) y: y coordinates of the arrow locations
    :param (list|ndarray) u: x components of the arrow vectors
    :param (list|ndarray) v: y components of the arrow vectors
    :param (float in [0,1]) scale: scales size of the arrows (ideally to
        avoid overlap). Default = .1
    :param (float in [0,1]) arrow_scale: value multiplied to length of barb
        to get length of arrowhead. Default = .3
    :param (angle in radians) angle: angle of arrowhead. Default = pi/9
    :param (positive float) scaleratio: the ratio between the scale of the
        y-axis and the scale of the x-axis (scale_y / scale_x). Default =
        None, the scale ratio is not fixed.
    :param kwargs: kwargs passed through plotly.graph_objs.Scatter
        for more information on valid kwargs call
        help(plotly.graph_objs.Scatter)

    :rtype (graph_objs.Figure): figure with one 'lines' Scatter trace
        (barbs followed by arrowheads) and a 'closest' hovermode layout.

    Example 1: Trivial Quiver
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_quiver

    import math

    # 1 Arrow from (0,0) to (1,1)
    fig = create_quiver(x=[0], y=[0], u=[1], v=[1], scale=1)

    py.plot(fig, filename='quiver')
    ```

    Example 2: Quiver plot using meshgrid
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_quiver

    import numpy as np
    import math

    # Add data
    x,y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2))
    u = np.cos(x)*y
    v = np.sin(x)*y

    #Create quiver
    fig = create_quiver(x, y, u, v)

    # Plot
    py.plot(fig, filename='quiver')
    ```

    Example 3: Styling the quiver plot
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_quiver
    import numpy as np
    import math

    # Add data
    x, y = np.meshgrid(np.arange(-np.pi, math.pi, .5),
                       np.arange(-math.pi, math.pi, .5))
    u = np.cos(x)*y
    v = np.sin(x)*y

    # Create quiver
    fig = create_quiver(x, y, u, v, scale=.2, arrow_scale=.3, angle=math.pi/6,
                        name='Wind Velocity', line=dict(width=1))

    # Add title to layout
    fig['layout'].update(title='Quiver Plot')

    # Plot
    py.plot(fig, filename='quiver')
    ```

    Example 4: Forcing a fix scale ratio to maintain the arrow length
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_quiver

    import numpy as np

    # Add data
    x,y = np.meshgrid(np.arange(0.5, 3.5, .5), np.arange(0.5, 4.5, .5))
    u = x
    v = y
    angle = np.arctan(v / u)
    norm = 0.25
    u = norm * np.cos(angle)
    v = norm * np.sin(angle)

    # Create quiver with a fix scale ratio
    fig = create_quiver(x, y, u, v, scale = 1, scaleratio = 0.5)

    # Plot
    py.plot(fig, filename='quiver')
    ```
    """
    utils.validate_equal_length(x, y, u, v)
    utils.validate_positive_scalars(arrow_scale=arrow_scale, scale=scale)

    fixed_ratio = scaleratio is not None

    # scaleratio is only forwarded when given, so that _Quiver falls back
    # to its own default otherwise.
    quiver_args = [x, y, u, v, scale, arrow_scale, angle]
    if fixed_ratio:
        quiver_args.append(scaleratio)
    quiver_obj = _Quiver(*quiver_args)

    barb_x, barb_y = quiver_obj.get_barbs()
    arrow_x, arrow_y = quiver_obj.get_quiver_arrows()

    # Barbs and arrowheads are concatenated into one line trace.
    trace = graph_objs.Scatter(x=barb_x + arrow_x,
                               y=barb_y + arrow_y,
                               mode="lines",
                               **kwargs)

    layout_options = dict(hovermode="closest")
    if fixed_ratio:
        # Tie the y-axis scale to the x-axis with the requested ratio.
        layout_options["yaxis"] = dict(scaleratio=scaleratio,
                                       scaleanchor="x")
    layout = graph_objs.Layout(**layout_options)

    return graph_objs.Figure(data=[trace], layout=layout)
예제 #28
0
def create_2d_density(
    x,
    y,
    colorscale="Earth",
    ncontours=20,
    hist_color=(0, 0, 0.5),
    point_color=(0, 0, 0.5),
    point_size=2,
    title="2D Density Plot",
    height=600,
    width=600,
):
    """
    Return a figure that overlays a scatter of the points and 2D density
    contours, with a histogram of ``x`` and a histogram of ``y`` placed on
    secondary axes along the edges.

    **deprecated**, use instead
    :func:`plotly.express.density_heatmap`.

    :param (list|array) x: x-axis data for plot generation
    :param (list|array) y: y-axis data for plot generation
    :param (str|tuple|list) colorscale: either a plotly scale name, an rgb
        or hex color, a color tuple or a list or tuple of colors. An rgb
        color is of the form 'rgb(x, y, z)' where x, y, z belong to the
        interval [0, 255] and a color tuple is a tuple of the form
        (a, b, c) where a, b and c belong to [0, 1]. If colorscale is a
        list, it must contain the valid color types aforementioned as its
        members.
    :param (int) ncontours: the number of 2D contours to draw on the plot
    :param (str|tuple) hist_color: the color of the plotted histograms
    :param (str|tuple) point_color: the color of the scatter points
    :param (float) point_size: the marker size of the scatter points
    :param (str) title: set the title for the plot
    :param (float) height: the height of the chart
    :param (float) width: the width of the chart

    :raises (PlotlyError): if ``x`` or ``y`` contains an element that is
        not a number, or if ``x`` and ``y`` differ in length.
    :return: a ``graph_objs.Figure`` with four traces, in this order:
        points, density contours, x histogram, y histogram.

    Examples
    --------

    Example 1: Simple 2D Density Plot

    >>> from plotly.figure_factory import create_2d_density
    >>> import numpy as np

    >>> # Make data points
    >>> t = np.linspace(-1,1.2,2000)
    >>> x = (t**3)+(0.3*np.random.randn(2000))
    >>> y = (t**6)+(0.3*np.random.randn(2000))

    >>> # Create a figure
    >>> fig = create_2d_density(x, y)

    >>> # Plot the data
    >>> fig.show()

    Example 2: Using Parameters

    >>> from plotly.figure_factory import create_2d_density

    >>> import numpy as np

    >>> # Make data points
    >>> t = np.linspace(-1,1.2,2000)
    >>> x = (t**3)+(0.3*np.random.randn(2000))
    >>> y = (t**6)+(0.3*np.random.randn(2000))

    >>> # Create custom colorscale
    >>> colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)',
    ...              (1, 1, 0.2), (0.98,0.98,0.98)]

    >>> # Create a figure
    >>> fig = create_2d_density(x, y, colorscale=colorscale,
    ...       hist_color='rgb(255, 237, 222)', point_size=3)

    >>> # Plot the data
    >>> fig.show()
    """

    # Every element of both inputs must be numeric. x is scanned before y
    # and the scan stops at the first offending value.
    has_non_number = any(
        not isinstance(value, Number) for series in (x, y) for value in series
    )
    if has_non_number:
        raise plotly.exceptions.PlotlyError(
            "All elements of your 'x' and 'y' lists must be numbers.")

    # The two inputs are paired point by point, so their lengths must agree.
    if len(x) != len(y):
        raise plotly.exceptions.PlotlyError(
            "Both lists 'x' and 'y' must be the same length.")

    # Normalise all color inputs to rgb; the contour colors are additionally
    # turned into a linear colorscale.
    contour_scale = make_linear_colorscale(
        clrs.validate_colors(colorscale, "rgb"))
    hist_rgb = clrs.validate_colors(hist_color, "rgb")
    point_rgb = clrs.validate_colors(point_color, "rgb")

    def _plain_axis(start, stop):
        # Axis restricted to [start, stop] of the plot area, no grid/zero line.
        return dict(domain=[start, stop], showgrid=False, zeroline=False)

    traces = [
        # Raw points, drawn semi-transparent.
        graph_objs.Scatter(
            x=x,
            y=y,
            mode="markers",
            name="points",
            marker={"color": point_rgb[0], "size": point_size,
                    "opacity": 0.4},
        ),
        # Density contours of the same points.
        graph_objs.Histogram2dContour(
            x=x,
            y=y,
            name="density",
            ncontours=ncontours,
            colorscale=contour_scale,
            reversescale=True,
            showscale=False,
        ),
        # Histogram of x, plotted against the secondary y axis.
        graph_objs.Histogram(
            x=x,
            name="x density",
            marker={"color": hist_rgb[0]},
            yaxis="y2",
        ),
        # Histogram of y, plotted against the secondary x axis.
        graph_objs.Histogram(
            y=y,
            name="y density",
            marker={"color": hist_rgb[0]},
            xaxis="x2",
        ),
    ]

    # Primary axes take the first 85% of each direction; the secondary axes
    # used by the histograms take the remaining 15%.
    layout = graph_objs.Layout(
        showlegend=False,
        autosize=False,
        title=title,
        height=height,
        width=width,
        xaxis=_plain_axis(0, 0.85),
        yaxis=_plain_axis(0, 0.85),
        margin=dict(t=50),
        hovermode="closest",
        bargap=0,
        xaxis2=_plain_axis(0.85, 1),
        yaxis2=_plain_axis(0.85, 1),
    )

    return graph_objs.Figure(data=traces, layout=layout)
예제 #29
0
def create_violin(data, data_header=None, group_header=None, colors=None,
                  use_colorscale=False, group_stats=None, rugplot=True,
                  sort=False, height=450, width=600,
                  title='Violin and Rug Plot'):
    """
    Returns figure for a violin plot

    Without ``group_header`` a single violin is drawn from ``data``; with
    ``group_header`` one violin is drawn per group of a pandas dataframe.

    :param (list|array) data: accepts either a list of numerical values
        or a pandas dataframe with at least one column of numbers. A list
        is rejected unless it is nonempty and every element is a number.
        A dataframe is required when 'group_header' is given.
    :param (str) data_header: the header of the data column to be used
        from an inputted pandas dataframe. Not applicable if 'data' is
        a list of numeric values.
    :param (str) group_header: applicable if grouping data by a variable.
        'group_header' must be set to the name of the grouping variable.
    :param (str|tuple|list|dict) colors: either a plotly scale name,
        an rgb or hex color, a color tuple, a list of colors or a
        dictionary. An rgb color is of the form 'rgb(x, y, z)' where
        x, y and z belong to the interval [0, 255] and a color tuple is a
        tuple of the form (a, b, c) where a, b and c belong to [0, 1].
        If colors is a list, it must contain valid color types as its
        members.
    :param (bool) use_colorscale: only applicable if grouping by another
        variable. Will implement a colorscale based on the first 2 colors
        of param colors. This means colors must be a list with at least 2
        colors in it (Plotly colorscales are accepted since they map to a
        list of two rgb colors). Any falsy value disables the colorscale.
        Default = False
    :param (dict) group_stats: a dictionary where each key is a unique
        value from the group_header column in data. Each value must be a
        number and will be used to color the violin plots if a colorscale
        is being used.
    :param (bool) rugplot: determines if a rugplot is drawn on violin plot.
        Default = True
    :param (bool) sort: determines if violins are sorted
        alphabetically (True) or by input order (False). Default = False
    :param (float) height: the height of the violin plot.
    :param (float) width: the width of the violin plot.
    :param (str) title: the title of the violin plot.

    :raises (PlotlyError): if a list 'data' is empty or holds non-numbers;
        if a dataframe is given without 'data_header'; if 'group_header'
        is given but 'data' is not a pandas dataframe (including when
        pandas is not available); or, with 'use_colorscale', if 'colors'
        is a dictionary, has fewer than 2 colors, or 'group_stats' is not
        a dictionary.
    :return: the violin plot figure.

    Example 1: Single Violin Plot
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_violin
    from plotly.graph_objs import graph_objs

    import numpy as np
    from scipy import stats

    # create list of random values
    data_list = np.random.randn(100).tolist()

    # create violin fig
    fig = create_violin(data_list, colors='#604d9e')

    # plot
    py.iplot(fig, filename='Violin Plot')
    ```

    Example 2: Multiple Violin Plots with Qualitative Coloring
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_violin
    from plotly.graph_objs import graph_objs

    import numpy as np
    import pandas as pd
    from scipy import stats

    # create dataframe
    np.random.seed(619517)
    Nr=250
    y = np.random.randn(Nr)
    gr = np.random.choice(list("ABCDE"), Nr)
    norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]

    for i, letter in enumerate("ABCDE"):
        y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
    df = pd.DataFrame(dict(Score=y, Group=gr))

    # create violin fig
    fig = create_violin(df, data_header='Score', group_header='Group',
                        sort=True, height=600, width=1000)

    # plot
    py.iplot(fig, filename='Violin Plot with Coloring')
    ```

    Example 3: Violin Plots with Colorscale
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_violin
    from plotly.graph_objs import graph_objs

    import numpy as np
    import pandas as pd
    from scipy import stats

    # create dataframe
    np.random.seed(619517)
    Nr=250
    y = np.random.randn(Nr)
    gr = np.random.choice(list("ABCDE"), Nr)
    norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]

    for i, letter in enumerate("ABCDE"):
        y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
    df = pd.DataFrame(dict(Score=y, Group=gr))

    # define header params
    data_header = 'Score'
    group_header = 'Group'

    # make groupby object with pandas
    group_stats = {}
    groupby_data = df.groupby([group_header])

    for group in "ABCDE":
        data_from_group = groupby_data.get_group(group)[data_header]
        # take a stat of the grouped data
        stat = np.median(data_from_group)
        # add to dictionary
        group_stats[group] = stat

    # create violin fig
    fig = create_violin(df, data_header='Score', group_header='Group',
                        height=600, width=1000, use_colorscale=True,
                        group_stats=group_stats)

    # plot
    py.iplot(fig, filename='Violin Plot with Colorscale')
    ```
    """

    # Validate colors
    if isinstance(colors, dict):
        valid_colors = utils.validate_colors_dict(colors, 'rgb')
    else:
        valid_colors = utils.validate_colors(colors, 'rgb')

    # `pd` may be falsy (presumably an optional import that failed), so it
    # must be checked before touching `pd.core`. Doing it once here keeps
    # the grouped branch from raising AttributeError on a missing pandas.
    is_dataframe = bool(pd) and isinstance(data, pd.core.frame.DataFrame)

    # validate data and choose plot type
    if group_header is None:
        if isinstance(data, list):
            if not data:
                raise exceptions.PlotlyError("If data is a list, it must be "
                                             "nonempty and contain either "
                                             "numbers or dictionaries.")

            if not all(isinstance(element, Number) for element in data):
                raise exceptions.PlotlyError("If data is a list, it must "
                                             "contain only numbers.")

        if is_dataframe:
            if data_header is None:
                raise exceptions.PlotlyError("data_header must be the "
                                             "column name with the "
                                             "desired numeric data for "
                                             "the violin plot.")

            # reduce the dataframe to the plain list of values to plot
            data = data[data_header].values.tolist()

        # call the plotting functions
        plot_data, plot_xrange = violinplot(data, fillcolor=valid_colors[0],
                                            rugplot=rugplot)

        layout = graph_objs.Layout(
            title=title,
            autosize=True,
            font=graph_objs.Font(size=11),
            height=height,
            showlegend=False,
            width=width,
            xaxis=make_XAxis('', plot_xrange),
            yaxis=make_YAxis(''),
            hovermode='closest'
        )
        # a single violin needs no visible y axis decoration
        layout['yaxis'].update(dict(showline=False,
                                    showticklabels=False,
                                    ticks=''))

        return graph_objs.Figure(data=graph_objs.Data(plot_data),
                                 layout=layout)

    # Grouped plots: everything below needs a dataframe and a data column.
    if not is_dataframe:
        raise exceptions.PlotlyError("Error. You must use a pandas "
                                     "DataFrame if you are using a "
                                     "group header.")

    if data_header is None:
        raise exceptions.PlotlyError("data_header must be the column "
                                     "name with the desired numeric "
                                     "data for the violin plot.")

    # Truthiness rather than `is False`, so falsy flags such as 0 or
    # numpy.bool_(False) also select the non-colorscale path.
    if not use_colorscale:
        if isinstance(valid_colors, dict):
            # validate colors dict choice below
            return violin_dict(
                data, data_header, group_header, valid_colors,
                use_colorscale, group_stats, rugplot, sort,
                height, width, title
            )
        return violin_no_colorscale(
            data, data_header, group_header, valid_colors,
            use_colorscale, group_stats, rugplot, sort,
            height, width, title
        )

    if isinstance(valid_colors, dict):
        raise exceptions.PlotlyError("The colors param cannot be "
                                     "a dictionary if you are "
                                     "using a colorscale.")

    if len(valid_colors) < 2:
        raise exceptions.PlotlyError("colors must be a list with "
                                     "at least 2 colors. A "
                                     "Plotly scale is allowed.")

    if not isinstance(group_stats, dict):
        raise exceptions.PlotlyError("Your group_stats param "
                                     "must be a dictionary.")

    return violin_colorscale(
        data, data_header, group_header, valid_colors,
        use_colorscale, group_stats, rugplot, sort, height,
        width, title
    )
예제 #30
0
def create_distplot(
    hist_data,
    group_labels,
    bin_size=1.0,
    curve_type="kde",
    colors=None,
    rug_text=None,
    histnorm=DEFAULT_HISTNORM,
    show_hist=True,
    show_curve=True,
    show_rug=True,
):
    """
    BETA function that creates a distplot similar to seaborn.distplot

    The distplot can be composed of all or any combination of the following
    3 components: (1) histogram, (2) curve: (a) kernel density estimation
    or (b) normal curve, and (3) rug plot. Additionally, multiple distplots
    (from multiple datasets) can be created in the same plot.

    Only the components that are switched on are built, so a hidden
    component neither costs time nor can make the call fail.

    :param (list[list]) hist_data: Use list of lists to plot multiple data
        sets on the same plot.
    :param (list[str]) group_labels: Names for each data set. Must have
        the same length as hist_data.
    :param (list[float]|float) bin_size: Size of histogram bins. A single
        number is applied to every data set. Default = 1.
    :param (str) curve_type: 'kde' or 'normal'. Default = 'kde'
    :param (list[str]) colors: Colors for traces.
    :param (list[list]) rug_text: Hovertext values for rug_plot,
    :param (str) histnorm: 'probability density' or 'probability'
        Default = 'probability density'
    :param (bool) show_hist: Add histogram to distplot? Default = True
    :param (bool) show_curve: Add curve to distplot? Default = True
    :param (bool) show_rug: Add rug to distplot? Default = True
    :return: a ``graph_objs.Figure`` whose traces are ordered histogram,
        curve, rug (each only if shown).

    Example 1: Simple distplot of 1 data set
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_distplot

    hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5,
                  3.5, 4.1, 4.4, 4.5, 4.5,
                  5.0, 5.0, 5.2, 5.5, 5.5,
                  5.5, 5.5, 5.5, 6.1, 7.0]]

    group_labels = ['distplot example']

    fig = create_distplot(hist_data, group_labels)

    url = py.plot(fig, filename='Simple distplot', validate=False)
    ```

    Example 2: Two data sets and added rug text
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_distplot

    # Add histogram data
    hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6,
               -0.9, -0.07, 1.95, 0.9, -0.2,
               -0.5, 0.3, 0.4, -0.37, 0.6]
    hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59,
               1.0, 0.8, 1.7, 0.5, 0.8,
               -0.3, 1.2, 0.56, 0.3, 2.2]

    # Group data together
    hist_data = [hist1_x, hist2_x]

    group_labels = ['2012', '2013']

    # Add text
    rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1',
          'f1', 'g1', 'h1', 'i1', 'j1',
          'k1', 'l1', 'm1', 'n1', 'o1']

    rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2',
          'f2', 'g2', 'h2', 'i2', 'j2',
          'k2', 'l2', 'm2', 'n2', 'o2']

    # Group text together
    rug_text_all = [rug_text_1, rug_text_2]

    # Create distplot
    fig = create_distplot(
        hist_data, group_labels, rug_text=rug_text_all, bin_size=.2)

    # Add title
    fig['layout'].update(title='Dist Plot')

    # Plot!
    url = py.plot(fig, filename='Distplot with rug text', validate=False)
    ```

    Example 3: Plot with normal curve and hide rug plot
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_distplot
    import numpy as np

    x1 = np.random.randn(190)
    x2 = np.random.randn(200)+1
    x3 = np.random.randn(200)-1
    x4 = np.random.randn(210)+2

    hist_data = [x1, x2, x3, x4]
    group_labels = ['2012', '2013', '2014', '2015']

    fig = create_distplot(
        hist_data, group_labels, curve_type='normal',
        show_rug=False, bin_size=.4)

    url = py.plot(fig, filename='hist and normal curve', validate=False)
    ```

    Example 4: Distplot with Pandas
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_distplot
    import numpy as np
    import pandas as pd

    df = pd.DataFrame({'2012': np.random.randn(200),
                       '2013': np.random.randn(200)+1})
    py.iplot(create_distplot([df[c] for c in df.columns], df.columns),
                             filename='examples/distplot with pandas',
                             validate=False)
    ```
    """
    # None stands in for "no colors / no rug text" to avoid mutable defaults.
    if colors is None:
        colors = []
    if rug_text is None:
        rug_text = []

    validate_distplot(hist_data, curve_type)
    utils.validate_equal_length(hist_data, group_labels)

    # A scalar bin size applies to every data set.
    if isinstance(bin_size, (float, int)):
        bin_size = [bin_size] * len(hist_data)

    # Shared constructor arguments; every component still gets its own
    # fresh _Distplot instance so the builders cannot affect each other.
    builder_args = (
        hist_data,
        histnorm,
        group_labels,
        bin_size,
        curve_type,
        colors,
        rug_text,
        show_hist,
        show_curve,
    )

    # Build only what will be displayed, in the order histogram, curve, rug.
    data = []
    if show_hist:
        data.extend(_Distplot(*builder_args).make_hist())
    if show_curve:
        curve_builder = _Distplot(*builder_args)
        if curve_type == "normal":
            data.extend(curve_builder.make_normal())
        else:
            data.extend(curve_builder.make_kde())
    if show_rug:
        data.extend(_Distplot(*builder_args).make_rug())

    # With a rug, the main plot is squeezed into the upper part of the
    # figure and the rug gets its own y axis underneath; without one the
    # main plot uses the full height.
    layout_kwargs = dict(
        barmode="overlay",
        hovermode="closest",
        legend=dict(traceorder="reversed"),
        xaxis1=dict(domain=[0.0, 1.0], anchor="y2", zeroline=False),
        yaxis1=dict(domain=[0.35 if show_rug else 0.0, 1],
                    anchor="free",
                    position=0.0),
    )
    if show_rug:
        layout_kwargs["yaxis2"] = dict(domain=[0, 0.25],
                                       anchor="x1",
                                       dtick=1,
                                       showticklabels=False)

    layout = graph_objs.Layout(**layout_kwargs)
    return graph_objs.Figure(data=data, layout=layout)