Esempio n. 1
0
def ridge_plot(d, value, groupby, step=30, overlap=0.8, sort=None,
               x_title="activation", x_format="%"):
    """Build an Altair ridgeline plot: one binned-count area per group.

    Each distinct value of ``groupby`` becomes a faceted row holding a
    histogram of ``value`` drawn as a smoothed area, coloured by the group's
    mean of ``value``. Empty bins are imputed with a count of 0 so each area
    returns to its baseline.

    Parameters
    ----------
    d : pandas.DataFrame
        Source data; must contain the ``value`` and ``groupby`` columns.
    value : str
        Numeric column to bin.
    groupby : str
        Column whose distinct values become rows.
    step : int
        Pixel height of each row.
    overlap : float
        The Y range runs from ``step`` to ``-step * overlap``, so an area can
        extend ``overlap`` row-heights into the row above.
    sort : list, optional
        Explicit row order. When falsy, ``sort=None`` is passed to ``alt.Row``.
    x_title : str
        X-axis title (defaults to the previously hard-coded ``"activation"``).
    x_format : str
        d3 format for x-axis labels (defaults to the previously hard-coded ``"%"``).

    Returns
    -------
    altair.Chart
        The configured ridgeline chart.
    """
    # [max, min] ordering puts higher group means at the start (red end)
    # of the 'redyellowblue' scheme.
    fill_domain = [d[value].max(), d[value].min()]
    return (
        alt.Chart(d)
        .transform_joinaggregate(mean_value=f"mean({value})", groupby=[groupby])
        .transform_bin(["bin_max", "bin_min"], value)
        .transform_aggregate(
            value="count()", groupby=[groupby, "mean_value", "bin_min", "bin_max"]
        )
        .transform_impute(
            impute="value", groupby=[groupby, "mean_value"], key="bin_min", value=0
        )
        .mark_area(
            interpolate="monotone", fillOpacity=0.8, stroke="lightgray", strokeWidth=0.5
        )
        .encode(
            alt.X(
                "bin_min:Q",
                bin="binned",
                title=x_title,
                axis=alt.Axis(format=x_format, labelFlush=False),
            ),
            alt.Y("value:Q", scale=alt.Scale(range=[step, -step * overlap]), axis=None),
            alt.Fill(
                "mean_value:Q",
                legend=None,
                scale=alt.Scale(domain=fill_domain, scheme="redyellowblue"),
            ),
            alt.Row(
                f"{groupby}:N",
                title=None,
                sort=alt.SortArray(sort) if sort else None,
                # NOTE(review): "%B" is a month-name time format, but this row
                # field is typed nominal (:N). It looks copied from a monthly
                # example; confirm it has the intended effect.
                header=alt.Header(labelAngle=0, labelAlign="right", format="%B"),
            ),
        )
        .properties(bounds="flush", height=step)
        .configure_facet(spacing=0)
        .configure_view(stroke=None)
    )
Esempio n. 2
0
def test_shorthand_typecodes(mydata):
    """Shorthand type codes (``:O``) and aggregates (``max(...)``) parse into channels."""
    shorthand = {
        "xvar": "name",
        "yvar": "amount:O",
        "fillvar": "max(amount)",
        "sizevar": "mean(amount):O",
    }
    channels = ChannelGroup(shorthand, mydata)

    expected = dict(
        x=alt.X(field="name", type="nominal"),
        y=alt.Y(field="amount", type="ordinal"),
        fill=alt.Fill(field="amount", type="quantitative", aggregate="max"),
        size=alt.Size(field="amount", type="ordinal", aggregate="mean"),
    )
    assert channels == expected
def generate_ridgeline_plot(data, x_lab_country_name):
    """Generate a month-by-month ridgeline plot for the COVID-19 CAN & USA dataset.

    Each month (the ``month`` time unit of ``date``) becomes a faceted row
    containing a histogram of ``response_ratio`` drawn as a smoothed area and
    coloured by that month's mean ratio. Empty bins are imputed as 0.

    Parameters
    ----------
    data
        Input data set from the preprocessed csv; must provide ``date`` and
        ``response_ratio`` columns.
    x_lab_country_name
        Name of the country, interpolated into the x-axis title.

    Returns
    -------
    altair.FacetChart
        The faceted ridgeline chart.
    """
    step = 40      # pixel height of one row
    overlap = 1    # Y range reaches -step * overlap, letting areas overlap the row above

    monthly_bins = (
        alt.Chart(data, height=step)
        .transform_timeunit(Month='month(date)')
        .transform_joinaggregate(
            mean_response_ratio='mean(response_ratio)', groupby=['Month'])
        .transform_bin(['bin_max', 'bin_min'], 'response_ratio')
        .transform_aggregate(
            value='count()',
            groupby=['Month', 'mean_response_ratio', 'bin_min', 'bin_max'])
        .transform_impute(
            impute='value',
            groupby=['Month', 'mean_response_ratio'],
            key='bin_min',
            value=0)
    )

    areas = monthly_bins.mark_area(
        interpolate='monotone',
        fillOpacity=0.8,
        stroke='lightgray',
        strokeWidth=0.5,
    ).encode(
        alt.X('bin_min:Q',
              bin='binned',
              title=f'Mean Response Ratio in {x_lab_country_name}'),
        alt.Y('value:Q',
              scale=alt.Scale(range=[step, -step * overlap]),
              axis=None),
        alt.Fill('mean_response_ratio:Q'),
    )

    month_row = alt.Row(
        'Month:T',
        title=None,
        header=alt.Header(labelAngle=0, labelAlign='right', format='%B'),
    )
    return areas.facet(row=month_row).properties(title='', bounds='flush')
def plot_ridgeline(input_df, time_unit_field, value_col,
                   x_title='Timbre 2 Average By Decades',
                   chart_title='Timbre Avg by Decade (Ridgeline)',
                   color_domain=(50, -100)):
    """Graph a ridgeline plot of ``value_col`` with one row per decade.

    Parameters
    ----------
    input_df : pandas.DataFrame
        Source data. It is copied, so the caller's frame is not modified.
    time_unit_field : str
        Column converted with ``transform_date`` and grouped by its ``year``
        time unit. Rows are labelled "<year>s", so the converted values
        presumably fall on decade starts; confirm against ``transform_date``.
    value_col : str
        Numeric column whose distribution (10 bins max) is drawn per row.
    x_title, chart_title : str
        Axis and chart titles. The defaults reproduce the original timbre plot.
    color_domain : sequence of two numbers
        Domain of the fill scale for the per-decade mean. The first value maps
        to the red end of ``redyellowblue``. The default is the original
        hard-coded ``[50, -100]``.

    Returns
    -------
    altair chart
        The faceted, configured ridgeline chart.
    """
    df = input_df.copy()
    df[time_unit_field] = df[time_unit_field].apply(transform_date)
    step = 30  # height of each row (kde)
    overlap = 1

    ridgeline = alt.Chart(df, height=step).transform_timeunit(
        as_="Decade", timeUnit="year", field=time_unit_field
    ).transform_joinaggregate(
        mean_val=f'mean({value_col})', groupby=["Decade"]
    ).transform_bin(
        ['bin_max', 'bin_min'], value_col, bin=alt.Bin(maxbins=10)
    ).transform_aggregate(
        value='count()',
        groupby=["Decade", 'mean_val', 'bin_min', 'bin_max']
    ).transform_impute(
        impute='value',
        groupby=['Decade', 'mean_val'],
        key='bin_min',
        value=0
    ).mark_area(
        interpolate='monotone',
        fillOpacity=0.8,
        stroke='lightgray',
        strokeWidth=0.5
    ).encode(
        alt.X('bin_min:Q', bin='binned', title=x_title),
        alt.Y('value:Q', scale=alt.Scale(range=[step, -step * overlap]),
              axis=None),
        alt.Fill(
            'mean_val:Q',
            legend=None,
            scale=alt.Scale(domain=list(color_domain),
                            scheme='redyellowblue'),
        ),
    ).facet(
        row=alt.Row(
            # 'Decade' comes from a year time unit, hence typed temporal (:T)
            # so the '%Y' header format applies.
            'Decade:T',
            title=None,
            header=alt.Header(labelAngle=0, labelAlign='right',
                              format='%Ys'),
        )
    ).properties(
        title=chart_title,
        bounds='flush',
    ).configure_facet(
        spacing=0
    ).configure_view(
        stroke=None
    ).configure_title(anchor='end')
    return ridgeline
Esempio n. 5
0
def visualize_emb(vis_dict):
    """Render each embedding column as its own area-chart row in Streamlit.

    Parameters
    ----------
    vis_dict : dict
        ``vis_dict['dict']`` is passed to ``pd.DataFrame`` and must yield an
        ``x`` column plus the columns listed in ``vis_dict['c_names']``.
        Those columns are folded into long form (``embedding``, ``lv``) and
        drawn one per row.

    Returns
    -------
    None
        The chart is emitted through ``st.altair_chart``.
    """
    # Named data_dict (not ``dict``) so the builtin is not shadowed.
    data_dict = vis_dict['dict']
    c_names = vis_dict['c_names']
    emb_vis_data = pd.DataFrame(data_dict)
    step = 20  # pixel height of each row

    emb_chart = alt.Chart(emb_vis_data).transform_fold(
        c_names,
        as_=['embedding', 'lv'],
    ).mark_area(
        interpolate='monotone',
        fillOpacity=0.8,
        stroke='lightgray',
        strokeWidth=0.2,
    ).encode(
        alt.X('x:Q', title=None,
              scale=alt.Scale(domain=[0, 512], range=[0, 1500])),
        # The former `scale=alt.Scale(rangeStep=40)` was dropped. rangeStep
        # only applies to discrete (band/point) scales, so it had no effect
        # on this quantitative axis, and it is not a valid Scale property in
        # Vega-Lite 4+ (Altair 4+).
        alt.Y('lv:Q', title="", axis=None),
        alt.Fill('embedding:N', legend=None,
                 scale=alt.Scale(scheme='redyellowblue')),
        row=alt.Row('embedding:N', title=None,
                    header=alt.Header(labelAngle=360)),
    ).properties(
        bounds='flush', title='Вектор статьи', height=step, width=1200,
    ).configure_facet(
        spacing=0,
    ).configure_view(
        stroke=None,
    ).configure_title(
        anchor='middle',
    )
    # NOTE(review): `width` is deprecated / removed in newer Streamlit
    # releases of st.altair_chart; confirm against the pinned version and
    # consider `use_container_width` instead.
    st.altair_chart(emb_chart, width=-1)
def generate_ridgeline_plot(data, attribute, color_domain=None):
    '''Ridgeline plot: overlapping histograms of ``attribute``, one row per species.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a ``species`` column and the ``attribute`` column.
    attribute
        Column to bin. It is also the key looked up in ``metadata_description``
        for the chart title.
    color_domain : sequence of two numbers, optional
        Domain of the fill scale for the per-species mean. It defaults to
        ``[max, min]`` of ``attribute`` in ``data``. The previous fixed
        ``[30, 5]`` came from a temperature example and did not depend on the
        attribute being plotted.

    Returns
    -------
    altair chart
        The configured ridgeline chart.
    '''
    step = 40
    overlap = 1
    field = str(attribute)

    if color_domain is None:
        # Group means always lie within the column's [min, max]. Listing max
        # first keeps high means at the red end of 'redyellowblue'. float()
        # avoids numpy scalar types in the JSON spec.
        color_domain = [float(data[attribute].max()),
                        float(data[attribute].min())]

    graph = alt.Chart(data).transform_joinaggregate(
        mean_attribute=f"mean({field})",
        groupby=['species'],
    ).transform_bin(
        ['bin_max', 'bin_min'], field,
    ).transform_aggregate(
        value='count()',
        groupby=['species', 'mean_attribute', 'bin_min', 'bin_max'],
    ).transform_impute(
        impute='value',
        groupby=['species', 'mean_attribute'],
        key='bin_min',
        value=0,
    ).mark_area(
        interpolate='monotone',
        fillOpacity=0.4,
        stroke='lightgray',
        strokeWidth=0.3,
    ).encode(
        alt.X('bin_min:Q', bin='binned', title=field),
        alt.Y('value:Q',
              scale=alt.Scale(range=[step, -step * overlap]),
              axis=None),
        alt.Fill('mean_attribute:Q',
                 legend=None,
                 scale=alt.Scale(domain=list(color_domain),
                                 scheme='redyellowblue')),
        alt.Row('species:O',
                title='Species',
                header=alt.Header(labelAngle=0, labelAlign='right')),
    ).properties(
        bounds='flush',
        title='Comparison: {}'.format(str(metadata_description[attribute])),
        # NOTE(review): row height is 100 while the Y range is derived from
        # step=40; confirm this mismatch is intentional.
        height=100,
        width=700,
    ).configure_facet(
        spacing=0,
    ).configure_view(
        stroke=None,
    ).configure_title(
        anchor='end',
    )

    return graph
Esempio n. 7
0
def grafico_con_columnas(
        source,
        y='test_acc',
        title="Exactitud en validacion desacoplado segun destilacion",
        shape='student',
        column='feat_dist',
        color='layer',
        fill=None):
    """Faceted point chart of ``y`` by ``color``, one facet column per ``column`` value.

    ``source`` is first reduced to the columns named by ``y``, ``shape``,
    ``column``, ``color`` and ``fill``. Axis and legend titles come from the
    ``global_titles`` mapping.
    """
    # Columns referenced by the encodings (fill may be None, matching nothing).
    used = (y, shape, column, color, fill)
    source = source.drop(columns=[c for c in source.columns if c not in used])

    encodings = {
        "shape": alt.Shape(f"{shape}:O",
                           legend=alt.Legend(title=global_titles[shape])),
        "y": alt.Y(y, title=global_titles[y]),
        "column": alt.Column(f"{column}:O", title=global_titles[column]),
        # The colour field drives both the x position and the colour.
        "x": alt.X(f"{color}:N", title=global_titles[color]),
        "color": alt.Color(f"{color}:N",
                           legend=alt.Legend(title=global_titles[color])),
        "opacity": alt.value(0.5),
    }
    if fill is not None:
        encodings["fill"] = alt.Fill(
            f"{fill}:O",
            legend=alt.Legend(title=global_titles[fill]),
            scale=alt.Scale(scheme='pastel1'),
        )

    chart = alt.Chart(source, title=title).mark_point(size=100)
    chart = chart.encode(**encodings)
    chart = chart.configure_axis(titleFontSize=12, labelFontSize=12)
    chart = chart.configure_title(fontSize=15).interactive()
    return chart.properties(width=70, height=600)
Esempio n. 8
0
def probabilistic_chart(
    probability_scale_range: Tuple[float, float],
    belief_horizon_unit: str,
    sensor_name: str,
    sensor_unit: str,
):
    """Layer a filled probability area with a hover-highlighted outline.

    Both layers share binned ``event_value`` on x and ``probability`` on y.
    The area is coloured by ``belief_horizon``. The outline turns black and
    thicker while ``selectors.ridgeline_hover_brush`` is active.
    No data is attached here (``alt.Chart()`` is empty).
    """
    hover = selectors.ridgeline_hover_brush
    x_title = sensor_name + " (" + sensor_unit + ")"
    horizon_title = "%s (%s)" % ("Belief horizon", belief_horizon_unit)

    shared = alt.Chart().encode(
        x=alt.X("event_value:Q", bin="binned",
                scale=alt.Scale(padding=0), title=x_title),
        y=alt.Y("probability:Q",
                scale=alt.Scale(range=probability_scale_range), axis=None),
    )

    outline = shared.mark_line(interpolate="monotone").encode(
        stroke=alt.condition(hover, alt.value("black"), alt.value("lightgray")),
        strokeWidth=alt.condition(hover, alt.value(2.5), alt.value(0.5)),
    )

    # NOTE(review): belief_horizon is nominal (:N) for the fill but
    # quantitative (:Q) in the tooltip; confirm that is intended.
    value_tip = alt.Tooltip("event_value:Q", title="Value", format=".2f")
    probability_tip = alt.Tooltip("probability:Q", title="Probability", format=".2f")
    horizon_tip = alt.Tooltip("belief_horizon:Q", title=horizon_title)
    filled = shared.mark_area(interpolate="monotone", fillOpacity=0.6).encode(
        fill=alt.Fill("belief_horizon:N", sort="ascending", legend=None,
                      scale=alt.Scale(scheme="viridis")),
        tooltip=[value_tip, probability_tip, horizon_tip],
    )

    return alt.layer(filled, outline)
def totals_by_stage(d: pd.DataFrame) -> alt.Chart:
    """Plot the total runtime of each stage as a bar chart.

    Args:
      d: A dataframe of runtimes; must contain every column in
        ``RUNTIME_COLUMNS``.

    Returns:
      An altair chart with one bar per stage, in ``RUNTIME_COLUMNS`` order,
      with the formatted runtime as tooltip.
    """
    # Sum only the runtime columns. Summing the whole frame first would also
    # aggregate unrelated (possibly non-numeric) columns, which is wasted work
    # and can raise for mixed-type object columns.
    stage_totals = (
        d[RUNTIME_COLUMNS].sum()
        .rename_axis('Stage')
        .reset_index(name='Runtime (seconds)')
    )
    stage_totals['Runtime'] = stage_totals['Runtime (seconds)'].apply(
        format_runtime_string)
    return alt.Chart(stage_totals).mark_bar().encode(
        x='Runtime (seconds)',
        # sort=None keeps the data (RUNTIME_COLUMNS) order.
        y=alt.Y('Stage', sort=None),
        tooltip=['Runtime'],
        fill=alt.Fill('Stage',
                      sort=None)).properties(title='Overall runtime by stage')
Esempio n. 10
0
def individual_region_bars(small_df: pd.DataFrame,
                           title: Union[str, Dict[str, str]] = '') -> alt.Chart:
  """Build a stacked bar chart of per-stage runtime, one bar per region.

  Args:
    small_df: Dataframe with 'region', 'Runtime' and every RUNTIME_COLUMNS
      column; each row becomes one bar.
    title: Chart title. A dict may carry 'title' and/or 'subtitle'.

  Returns:
    An altair chart whose bars are stacked by stage (folded from
    RUNTIME_COLUMNS) with the 'Runtime' value as tooltip.
  """
  subset = small_df[['region', 'Runtime'] + RUNTIME_COLUMNS]

  # Wide runtime columns -> long (Stage, runtime_by_stage) pairs for stacking.
  long_form = alt.Chart(subset).transform_fold(
      RUNTIME_COLUMNS, as_=['Stage', 'runtime_by_stage'])

  channels = dict(
      x=alt.X('region:N', sort=None),
      y=alt.Y('runtime_by_stage:Q',
              scale=alt.Scale(type='linear'),
              title='Runtime (seconds)'),
      fill=alt.Fill('Stage:N', sort=None),
      tooltip='Runtime:N',
  )
  return long_form.mark_bar().encode(**channels).properties(title=title)
Esempio n. 11
0
def test_titled_vars(mydata):
    """A ``|Title`` suffix on a shorthand sets the channel title verbatim."""
    shorthand = {
        "xvar": "name",
        "yvar": "amount:O|The Amount",
        "fillvar": "category|Cats",
        "sizevar": "mean(amount):O| Big Ups!",
    }
    channels = ChannelGroup(shorthand, mydata)

    expected = dict(
        x=alt.X(field="name", type="nominal"),
        y=alt.Y(field="amount", title="The Amount", type="ordinal"),
        fill=alt.Fill(field="category", title="Cats", type="nominal"),
        # The leading space in " Big Ups!" must be preserved.
        size=alt.Size(field="amount", type="ordinal", aggregate="mean",
                      title=" Big Ups!"),
    )
    assert channels == expected
Esempio n. 12
0
# Wage area (light blue) plus two line strokes that reuse its encoding.
area = base_wheat.mark_area(color="#a4cedb", opacity=0.7).encode(
    x=alt.X("year:Q"),
    y=alt.Y("wages:Q"),
)
area_line_1 = area.mark_line(color="#000", opacity=0.7)
area_line_2 = area.mark_line(yOffset=-2, color="#EE8182")

# Bars spanning start..end on x and y..offset on y, filled black/white by
# 'commonwealth', plus italic 'name' labels positioned at (x, off2).
top_bars = base_monarchs.mark_bar(stroke="#000").encode(
    x=alt.X("start:Q"),
    x2=alt.X2("end"),
    y=alt.Y("y:Q"),
    y2=alt.Y2("offset"),
    fill=alt.Fill(
        "commonwealth:N",
        legend=None,
        scale=alt.Scale(range=["black", "white"]),
    ),
)
top_text = base_monarchs.mark_text(
    yOffset=14, fontSize=9, fontStyle="italic",
).encode(
    x=alt.X("x:Q"),
    y=alt.Y("off2:Q"),
    text=alt.Text("name:N"),
)

# Layer and style everything; left as a bare expression so notebooks render it.
(
    bars + area + area_line_1 + area_line_2 + top_bars + top_text
).properties(
    width=900,
    height=400,
).configure_axis(
    title=None,
    gridColor="white",
    gridOpacity=0.25,
    domain=False,
).configure_view(stroke="transparent")
Esempio n. 13
0
                'Month', 'mean_temp', 'bin_min', 'bin_max'
            ]).transform_impute(
                impute='value',
                groupby=['Month', 'mean_temp'],
                key='bin_min',
                value=0).mark_area(
                    interpolate='monotone',
                    fillOpacity=0.8,
                    stroke='lightgray',
                    strokeWidth=0.5).encode(
                        alt.X('bin_min:Q',
                              bin='binned',
                              title='Maximum Daily Temperature (C)'),
                        alt.Y('value:Q',
                              scale=alt.Scale(range=[step, -step * overlap]),
                              axis=None),
                        alt.Fill('mean_temp:Q',
                                 legend=None,
                                 scale=alt.Scale(domain=[30, 5],
                                                 scheme='redyellowblue'))
                    ).facet(row=alt.Row('Month:T',
                                        title=None,
                                        header=alt.Header(labelAngle=0,
                                                          labelAlign='right',
                                                          format='%B'))
                            ).properties(title='Seattle Weather',
                                         bounds='flush').configure_facet(
                                             spacing=0).configure_view(
                                                 stroke=None).configure_title(
                                                     anchor='end')
Esempio n. 14
0

@pytest.mark.xfail(raises=NotImplementedError,
                   reason="specifying timeUnit is not supported yet")
def test_timeUnit():
    """Expected to fail: ``convert`` raises NotImplementedError for timeUnit shorthand."""
    points = alt.Chart(df).mark_point()
    convert(points.encode(alt.X('date(combination)')))


# Plots

# Quantitative scatter fixtures: one maps colour and size, the other maps fill.
chart_quant = alt.Chart(df_quant).mark_point().encode(
    alt.X(field='a', type='quantitative'),
    alt.Y('b'),
    alt.Color('c:Q'),
    alt.Size('s'),
)
chart_fill_quant = alt.Chart(df_quant).mark_point().encode(
    alt.X(field='a', type='quantitative'),
    alt.Y('b'),
    alt.Fill('fill:Q'),
)


@pytest.mark.parametrize("chart", [chart_quant, chart_fill_quant])
def test_quantitative_scatter(chart):
    """The mapping produced by ``convert`` is accepted by ``plt.scatter``."""
    scatter_kwargs = convert(chart)
    plt.scatter(**scatter_kwargs)
    plt.show()


@pytest.mark.parametrize("channel", [alt.Color("years"), alt.Fill("years")])
def test_scatter_temporal(channel):
    """A 'years' colour/fill channel converts into kwargs usable by ``plt.scatter``."""
    chart = alt.Chart(df).mark_point().encode(alt.X("years"), channel)
    scatter_kwargs = convert(chart)
    # The chart encodes no y channel, so y is supplied from the frame.
    scatter_kwargs['y'] = df['quantitative'].values
    plt.scatter(**scatter_kwargs)
Esempio n. 15
0
def test_convert_fill_success_temporal(column):
    """A temporal fill column converts to matplotlib date numbers under ``'c'``."""
    chart = alt.Chart(df).mark_point().encode(alt.Fill(column))
    colors = convert(chart)['c']
    assert list(colors) == list(mdates.date2num(df[column].values))
Esempio n. 16
0
def _alcohol_tooltips(cols, alcohol_type, proportion_title):
    """Return the tooltip spec shared by both charts of ``create_map``.

    ``cols`` is the column list built in ``create_map``; it is assumed to hold
    four ``alcohol_type`` columns followed by ``'country'`` at index 4. Only
    the proportion entry's title differs between the map and the bar chart.
    """
    return [
        {
            "field": cols[4],
            "type": "nominal",
            'title': "Country"
        },
        {
            "field": cols[1],
            "type": "quantitative",
            'title': proportion_title,
            'format': '.2f'
        },
        {
            "field": cols[0],
            "type": "quantitative",
            'title': f'Total {alcohol_type} servings'
        },
        {
            "field": cols[3],
            "type": "quantitative",
            'title': 'Continent rank'
        },
        {
            "field": cols[2],
            "type": "quantitative",
            'title': 'Global rank'
        },
    ]


def create_map(alcohol_type='beer', region="World"):
    """
    Create a choropleth map and a top-20 bar chart of alcohol consumption.

    Parameters
    ----------
    alcohol_type : str {'wine', 'beer', 'spirit'}
        Type of alcohol to show. Any value other than 'wine' or 'beer' uses
        the spirit colour scheme.
    region : str {'World', 'Asia', 'Europe', 'Africa', 'Americas', 'Oceania'}
        Region the map projection zooms to and the bar chart is filtered by.

    Returns
    -------
    altair HConcatChart
        The choropleth (left) and the top-20 bar chart (right), side by side.

    Raises
    ------
    KeyError
        If ``region`` is not one of the supported regions.

    Examples
    --------
    >>> create_map('spirit')
    """

    # Per region: [projection scale, translate x, translate y, display name].
    region_dict = {
        "World": [140, 450, 400, 'the World'],
        "Asia": [400, -190, 520, 'Asia'],
        "Europe": [800, 300, 1100, 'Europe'],
        "Africa": [400, 300, 310, 'Africa'],
        "Americas": [275, 950, 310, 'the Americas'],
        "Oceania": [500, -800, 50, 'Oceania']
    }

    # set colour scheme of map
    if alcohol_type == 'wine':
        map_color = ['#f9f9f9', '#720b18']
    elif alcohol_type == 'beer':
        map_color = ['#f9f9f9', '#DAA520']
    else:
        map_color = ['#f9f9f9', '#67b2e5', '#1f78b5']

    # Columns mentioning the alcohol type, then 'country'. The indexing below
    # assumes four such columns. Going by the tooltip titles they are total
    # servings, proportion, global rank and continent rank; confirm against df.
    cols = [x for x in df.columns if alcohol_type in x]
    cols.append('country')

    # Rank column matching the view: global rank for the World, otherwise
    # the continent rank.
    col_to_filter = cols[2] if region == 'World' else cols[3]

    # Create map plot
    map_plot = alt.Chart(
        alt.topo_feature(data.world_110m.url, 'countries')
    ).mark_geoshape(stroke='white', strokeWidth=0.5).encode(
        alt.Color(
            field=cols[1],
            type='quantitative',
            scale=alt.Scale(domain=[0, 1], range=map_color),
            legend=alt.Legend(
                orient='top',
                title=
                f'Proportion of total servings per person from {alcohol_type}')
        ),
        tooltip=_alcohol_tooltips(
            cols, alcohol_type,
            f'Proportion of total servings from {alcohol_type}'),
    ).transform_lookup(
        lookup='id', from_=alt.LookupData(df, 'id', fields=cols)).project(
            type='mercator',
            scale=region_dict[region][0],
            translate=[region_dict[region][1],
                       region_dict[region][2]]).properties(
                           width=900,
                           height=600,
                       )

    # 'World' keeps every row via an always-true condition on total_servings.
    region_filter = (alt.datum.region == region if region != 'World'
                     else alt.datum.total_servings >= 0)

    bar = alt.Chart(df).mark_bar().encode(
        alt.X(
            field=cols[1],
            type='quantitative',
            title="",
            scale=alt.Scale(domain=[0, 1]),
        ),
        alt.Y(field='country',
              type='nominal',
              sort=alt.EncodingSortField(field=cols[1],
                                         op='max',
                                         order='descending'),
              title=''),
        alt.Fill(field=cols[1],
                 type='quantitative',
                 scale=alt.Scale(domain=[0, 1], range=map_color),
                 legend=None),
        tooltip=_alcohol_tooltips(
            cols, alcohol_type,
            f'Proportion of total servings per person from {alcohol_type}'),
    ).transform_filter(
        region_filter
    ).transform_window(
        sort=[alt.SortField(cols[1], order="descending")],
        # Previously the literal text "rank(col_to_filter)" was passed, so the
        # computed column was never used. Vega's rank op orders rows by the
        # window `sort` above (proportion, descending) and ignores the field,
        # so the resulting ranks are unchanged.
        rank=f"rank({col_to_filter})"
    ).transform_filter(
        alt.datum.rank <= 20
    ).properties(
        title=
        f"Top 20 Countries that love {alcohol_type.title()} in {region_dict[region][3]}",
        width=200,
        height=600)

    return alt.hconcat(map_plot, bar).configure_legend(
        gradientLength=300,
        gradientThickness=20,
        titleLimit=0,
        labelFontSize=15,
        titleFontSize=20).configure_axis(
            labelFontSize=15, titleFontSize=20).configure_title(fontSize=20)
Esempio n. 17
0
def _uv_ridge_static(data, ylabel):
    """Seaborn KDE ridgeline (one row per 'anfreq_label') wrapped in a Panel pane."""
    sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
    grid = sns.FacetGrid(
        data,
        row="anfreq_label",
        hue="anfreq_label",
        aspect=15,
        height=.5,
        palette=sns.cubehelix_palette(10, rot=-.25, light=.7),
    )

    # Filled densities, then a white outline on top, then a baseline.
    grid.map(sns.kdeplot, ylabel, bw_adjust=.5, clip_on=False,
             fill=True, alpha=1, linewidth=1.5)
    grid.map(sns.kdeplot, ylabel, clip_on=False, color="w", lw=2,
             bw_adjust=.5)
    grid.map(plt.axhline, y=0, lw=2, clip_on=False)

    # Keep the `color` / `label` parameter names: seaborn's FacetGrid.map
    # supplies them as keyword arguments.
    def _draw_label(x, color, label):
        ax = plt.gca()
        ax.text(0, .2, label, fontweight="bold", color=color,
                ha="left", va="center", transform=ax.transAxes)

    grid.map(_draw_label, ylabel)

    # Negative vertical spacing makes rows overlap; strip clashing decorations.
    grid.fig.subplots_adjust(hspace=-.25)
    grid.set_titles("")
    grid.set(yticks=[])
    grid.despine(bottom=True, left=True)
    plt.close()
    return pn.pane.Matplotlib(grid.fig, tight=True)


def _uv_ridge_interactive(data, ylabel, afreq):
    """Altair binned-count ridgeline, one row per 'anfreq_label'."""
    step = 30
    overlap = 2
    data = data.dropna()
    lowest = data[ylabel].min()
    highest = data[ylabel].max()

    row_options = {"header": alt.Header(labelAngle=0, labelAlign="left")}
    if afreq in ['Month Start', 'Month End']:
        # Month labels get a title and calendar order.
        row_options["title"] = afreq
        row_options["sort"] = [
            'January', 'February', 'March', 'April', 'May',
            'June', 'July', 'August', 'September', 'October',
            'November', 'December'
        ]

    chart = alt.Chart(data, height=step).mark_area(
        interpolate="monotone",
        fillOpacity=0.8,
        stroke="lightgray",
        strokeWidth=0.5,
    ).encode(
        alt.X(f"{ylabel}:Q", bin=True, title=ylabel,
              axis=alt.Axis(format='~s')),
        alt.Y(f"count({ylabel}):Q",
              scale=alt.Scale(range=[step, -step * overlap]),
              impute=alt.ImputeParams(value=0),
              axis=None),
        # [max, min] domain: high means at the red end of the scheme.
        alt.Fill(f"mean({ylabel}):Q", legend=None,
                 scale=alt.Scale(domain=[highest, lowest],
                                 scheme="redyellowblue")),
        alt.Row("anfreq_label:N", **row_options),
    )
    return chart.properties(bounds="flush", width=525).configure_facet(spacing=0)


def uv_ridgePlot(data, engine, xlabel, ylabel, afreq):
    """Ridgeline of the 'plotX1' values (renamed to ``ylabel``) per 'anfreq_label'.

    ``engine`` selects 'Static' (seaborn, returned as a Panel Matplotlib pane)
    or 'Interactive' (Altair chart). More than 15 distinct labels force
    'Interactive'. Any other engine returns None. ``xlabel`` is unused.
    """
    frame = data.copy()
    frame.rename(columns={'plotX1': ylabel}, inplace=True)
    if frame['anfreq_label'].nunique() > 15:
        engine = 'Interactive'

    if engine == 'Static':
        return _uv_ridge_static(frame, ylabel)
    if engine == 'Interactive':
        return _uv_ridge_interactive(frame, ylabel, afreq)
    return None
Esempio n. 18
0
def _frozen_weights_ridge(df, *, step, width, overlap, max_bins, color_scheme,
                          xaxis_title, title_main, subtitle):
    """Build one faceted ridgeline of 'Dev Metric', one row per 'Frozen Weights Pct'.

    Raises ValueError when ``df`` is empty, since there is no metric range
    to plot.
    """
    if df.empty:
        raise ValueError(f"No rows to plot for {subtitle!r}.")
    # pandas min/max skip NaN (builtin min/max do not); float() keeps numpy
    # scalars out of the JSON spec.
    lo = float(df['Dev Metric'].min())
    hi = float(df['Dev Metric'].max())
    header = alt.Header(
        labelAngle=0, labelAlign='right', labelFontSize=15, labelFont='Lato',
        labelColor=berkeley_palette['pacific'], titleFontSize=20,
    )
    return (
        alt.Chart(df, height=step)
        .transform_joinaggregate(mean_acc='mean(Dev Metric)',
                                 groupby=['Frozen Weights Pct'])
        .transform_bin(['bin_max', 'bin_min'], 'Dev Metric',
                       bin=alt.Bin(maxbins=max_bins))
        .transform_aggregate(value='count()',
                             groupby=['Frozen Weights Pct', 'mean_acc',
                                      'bin_min', 'bin_max'])
        # Bins without observations have a count of 0. The metric minimum
        # used to be imputed here, putting metric units into a count field.
        .transform_impute(impute='value',
                          groupby=['Frozen Weights Pct', 'mean_acc'],
                          key='bin_min', value=0)
        .mark_area(interpolate='monotone', fillOpacity=0.8,
                   stroke='lightgray', strokeWidth=0.5)
        .encode(
            alt.X('bin_min:Q', bin='binned', title=xaxis_title,
                  scale=alt.Scale(domain=[lo, hi])),
            alt.Y('value:Q', scale=alt.Scale(range=[step, -step * overlap]),
                  axis=None),
            # Reversed domain, pulled 0.05 inward at both ends.
            alt.Fill('mean_acc:Q', legend=None,
                     scale=alt.Scale(domain=[hi - 0.05, lo + 0.05],
                                     scheme=color_scheme)),
        )
        .properties(width=width, height=step)
        .facet(row=alt.Row('Frozen Weights Pct:O',
                           title='Frozen Weights Pct (Binned)',
                           header=header))
        .properties(title={'text': title_main, 'subtitle': subtitle},
                    bounds='flush')
    )


def altair_frozen_weights_performance_ridge_plot(
        data, xaxis_title="Dev Metric", title_main="Dense Variably Unfrozen",
        task_name="MSR", step_all=75, width_all=600, step_small=30,
        width_small=400, overlap=1, max_bins=30,
        color_scheme='redyellowblue', return_all=True):
    """Ridgeline plots of 'Dev Metric' distributions per 'Frozen Weights Pct'.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain 'Frozen Weights Pct', 'Epoch' and 'Dev Metric'.
    xaxis_title, title_main, task_name : str
        Axis title, chart title, and the prefix of each subtitle.
    step_all, width_all : int
        Row height / width of the combined all-epochs chart.
    step_small, width_small : int
        Row height / width of each per-epoch chart.
    overlap : float
        Multiple of the row height an area may extend into the row above.
    max_bins : int
        Maximum number of histogram bins.
    color_scheme : str
        Vega colour scheme for the per-row mean fill.
    return_all : bool
        If False, return only the all-epochs chart. Otherwise return it next
        to a 2x2 grid of the epoch 1-4 charts.

    Returns
    -------
    altair chart
        The configured chart (or concatenation of charts).

    Raises
    ------
    AssertionError
        If ``data`` is not a DataFrame or lacks the required columns.
    ValueError
        If ``data`` (or, when ``return_all``, any of epochs 1-4) has no rows.
    """
    assert type(data) is pd.core.frame.DataFrame, "Parameter `data` must be of type pandas.core.frame.DataFrame."
    assert all(e in data.columns.to_list() for e in ['Frozen Weights Pct', 'Epoch', 'Dev Metric']), "Parameter `data` must contain the following columns: ['Frozen Weights Pct', 'Epoch', 'Dev Metric']."

    shared = dict(overlap=overlap, max_bins=max_bins,
                  color_scheme=color_scheme, xaxis_title=xaxis_title,
                  title_main=title_main)

    # generate the combined epochs plot
    c0 = _frozen_weights_ridge(
        data, step=step_all, width=width_all,
        subtitle=" ".join([task_name, "- All Epochs"]), **shared)

    # if not returning all plots, then return the main "All Epochs" plot
    if not return_all:
        return c0.configure_facet(spacing=0).configure_view(stroke=None).configure_title(anchor='middle')

    # Per-epoch plots; epochs 1-4 are hard-coded to fill the 2x2 grid below.
    subplots = [
        _frozen_weights_ridge(
            data[data['Epoch'] == i], step=step_small, width=width_small,
            subtitle=" ".join([task_name, "- Epoch", str(i)]), **shared)
        for i in range(1, 5)
    ]

    grid = alt.vconcat(alt.hconcat(subplots[0], subplots[1]),
                       alt.hconcat(subplots[2], subplots[3]))
    viz = alt.hconcat(grid, c0)\
        .configure_facet(spacing=0)\
        .configure_view(stroke=None)\
        .configure_title(anchor='middle')

    return viz
Esempio n. 19
0
def grafico_con_barra_xy_CE(
        source=data['students']['kd']['KD'],
        title="Ratio entre exactitud en validacion y accuracy en entrenamiento.",
        y="test/train",
        x="temp",
        color='student',
        shape='log_dist',
        bs=True,
        fill=None,
        xscale='log',
        yscale='linear'):
    """Scatter plot of ``y`` against ``x``, optionally with baseline rules.

    Each row of ``source`` becomes a point encoded by ``color`` and
    ``shape`` (and ``fill`` when given).  When ``bs`` is true, rules drawn
    from the module-level ``bar_source`` (coloured/stroked by its ``model``
    column) are layered under the points.

    Args:
        source: DataFrame holding the plotted columns.  Only the columns
            named by ``y``, ``x``, ``color``, ``shape`` and ``fill`` are kept.
        title: Chart title.
        y: Column for the quantitative y axis; also a key of
            ``global_titles`` (axis title).
        x: Column for the quantitative x axis; also a key of
            ``global_titles`` (axis title).
        color: Column for the color channel (legend title from
            ``global_titles``).
        shape: Column for the shape channel (legend title from
            ``global_titles``).
        bs: If true, layer the ``bar_source`` rules with the scatter.
        fill: Optional column for the fill channel; omitted when ``None``.
        xscale: Vega-Lite scale type for x; a non-``'linear'`` value is
            appended to the axis title as ``" [<type>]"``.
        yscale: Same as ``xscale`` for the y axis.

    Returns:
        An Altair chart (a layer chart when ``bs`` is true) with
        width/height 600, axis-title font size 12 and title font size 15.
    """
    # Keep only the columns that are actually encoded (fill may be None,
    # which simply matches no column).
    keep = {y, x, color, shape, fill}
    source = source.drop(columns=[c for c in source.columns if c not in keep])

    def _axis_title(field, scale_type):
        # Linear scales get the plain title; others are tagged, e.g. "T [log]".
        base = global_titles[field]
        return base if scale_type == 'linear' else f"{base} [{scale_type}]"

    encodings = {
        'y': alt.Y(y,
                   type='quantitative',
                   title=_axis_title(y, yscale),
                   scale=alt.Scale(zero=False, base=10, type=yscale)),
        'x': alt.X(x,
                   type='quantitative',
                   title=_axis_title(x, xscale),
                   scale=alt.Scale(zero=False, base=10, type=xscale)),
        'color': alt.Color(color,
                           legend=alt.Legend(title=global_titles[color])),
        'shape': alt.Shape(shape,
                           legend=alt.Legend(title=global_titles[shape])),
    }
    if fill is not None:
        encodings["fill"] = alt.Fill(fill)

    chart = alt.Chart(
        source,
        title=title).mark_point(size=100).encode(**encodings).interactive()

    if bs:
        bar = alt.Chart(bar_source).mark_rule(opacity=0.5).encode(
            y=y,
            color=alt.Color('model', legend=alt.Legend(title="Modelo")),
            stroke=alt.Stroke(
                'model', legend=alt.Legend(title="Modelo en Cross Entropy")),
            size=alt.value(2))
        chart = bar + chart

    # Altair's properties()/configure_*() return new chart objects instead of
    # mutating in place, so the configured result must be the one returned.
    # NOTE(review): configured charts are top-level only in Altair; confirm no
    # caller nests this result inside another layer/concat composition.
    return chart.properties(width=600, height=600).configure_axis(
        titleFontSize=12).configure_title(fontSize=15)
Esempio n. 20
0
def grafico_con_barra_y_CE(
        source=data['students']['kd']['KD'],
        title="Diferencia de accuracy en entrenamiento con respecto a Resnet101.",
        y="delta_teacher",
        x="feat_dist",
        fill='layer',
        color='student',
        shape='layer',
        bs=True,
        scale='linear'):
    """Scatter plot of quantitative ``y`` against ``x``, optionally with baseline rules.

    When ``bs`` is true, rules drawn from the module-level ``bar_source``
    (coloured/stroked by its ``model`` column) are layered under the points.

    Args:
        source: DataFrame holding the plotted columns.  Only the columns
            named by ``y``, ``x``, ``fill``, ``color`` and ``shape`` are kept.
        title: Chart title.
        y: Column for the quantitative y axis; also a key of
            ``global_titles``.
        x: Column for the x axis.  Two special values build a derived column
            on a copy of ``source``:
            ``'dists'`` -> ``feat_dist + ", " + log_dist``;
            ``'feat,block'`` -> ``feat_dist + ", " + str(layer)``.
        fill: Column for the fill channel, encoded as ordinal with the
            ``pastel1`` scheme.
        color: Column for the color channel.
        shape: Column for the shape channel, encoded as ordinal.
        bs: If true, layer the ``bar_source`` rules with the scatter.
        scale: Vega-Lite scale type for y (base 2, constant 1, zero
            included); a non-``'linear'`` value is appended to the axis title.

    Returns:
        An Altair chart (a layer chart when ``bs`` is true) with
        width/height 600, axis-title font size 12 and title font size 15.
    """
    # Derived x columns are built on a copy so the caller's frame is untouched.
    if x == 'dists':
        source = source.copy()
        source['dists'] = source['feat_dist'] + ", " + source['log_dist']
    elif x == 'feat,block':
        source = source.copy()
        source['feat,block'] = [
            feat + ", " + str(layer)
            for feat, layer in zip(source['feat_dist'], source['layer'])
        ]

    # Keep only the columns that are actually encoded.
    keep = {y, x, fill, color, shape}
    source = source.drop(columns=[c for c in source.columns if c not in keep])

    y_title = (global_titles[y] if scale == 'linear'
               else f"{global_titles[y]} [{scale}]")

    chart = alt.Chart(source, title=title).mark_point(size=100).encode(
        y=alt.Y(
            y,
            type='quantitative',
            scale=alt.Scale(zero=True, base=2, constant=1, type=scale),
            title=y_title),
        x=alt.X(x, title=global_titles[x]),
        color=alt.Color(color, legend=alt.Legend(title=global_titles[color])),
        fill=alt.Fill('%s:O' % fill,
                      legend=alt.Legend(title=global_titles[fill]),
                      scale=alt.Scale(scheme='pastel1')),
        shape=alt.Shape("%s:O" % shape,
                        legend=alt.Legend(title=global_titles[shape])),
        opacity=alt.value(0.5))

    if bs:
        bar = alt.Chart(bar_source).mark_rule(opacity=0.5).encode(
            y=y,
            color=alt.Color('model', legend=alt.Legend(title="Modelo")),
            stroke=alt.Stroke(
                'model', legend=alt.Legend(title="Modelo en Cross Entropy")),
            size=alt.value(2))
        chart = bar + chart

    # Altair's properties()/configure_*() return new chart objects instead of
    # mutating in place, so the configured result must be the one returned.
    # NOTE(review): configured charts are top-level only in Altair; confirm no
    # caller nests this result inside another layer/concat composition.
    return chart.properties(width=600, height=600).configure_axis(
        titleFontSize=12).configure_title(fontSize=15)