def ridge_plot(d, value, groupby, step=30, overlap=0.8, sort=None):
    """Build a ridgeline chart: one overlapping histogram of ``value`` per ``groupby`` row.

    Args:
        d: DataFrame with the data. ``d[value]`` is read directly to derive the
            colour-scale domain.
        value: Name of the numeric column that is binned and counted.
        groupby: Name of the column whose values become facet rows.
        step: Row height in pixels.
        overlap: Fraction of ``step`` by which a ridge may rise above its row.
        sort: Optional explicit ordering of the row values; falsy means none.

    Returns:
        A configured (top-level) altair chart.
    """
    y_range = [step, -step * overlap]
    # Domain runs from max to min, so the colour scheme is applied in reverse.
    color_scale = alt.Scale(
        domain=[d[value].max(), d[value].min()], scheme="redyellowblue"
    )
    row_channel = alt.Row(
        f"{groupby}:N",
        title=None,
        sort=alt.SortArray(sort) if sort else None,
        header=alt.Header(labelAngle=0, labelAlign="right", format="%B"),
    )

    histogram = (
        alt.Chart(d)
        .transform_joinaggregate(mean_value=f"mean({value})", groupby=[groupby])
        .transform_bin(["bin_max", "bin_min"], value)
        .transform_aggregate(
            value="count()", groupby=[groupby, "mean_value", "bin_min", "bin_max"]
        )
        # Bins without observations get a count of 0.
        .transform_impute(
            impute="value", groupby=[groupby, "mean_value"], key="bin_min", value=0
        )
    )
    ridges = histogram.mark_area(
        interpolate="monotone", fillOpacity=0.8, stroke="lightgray", strokeWidth=0.5
    ).encode(
        alt.X(
            "bin_min:Q",
            bin="binned",
            title="activation",
            axis=alt.Axis(format="%", labelFlush=False),
        ),
        alt.Y("value:Q", scale=alt.Scale(range=y_range), axis=None),
        alt.Fill("mean_value:Q", legend=None, scale=color_scale),
        row_channel,
    )
    framed = ridges.properties(bounds="flush", height=step)
    return framed.configure_facet(spacing=0).configure_view(stroke=None)
def test_shorthand_typecodes(mydata):
    """Shorthand type codes and aggregates expand into explicit channel definitions."""
    opts = {
        "xvar": "name",
        "yvar": "amount:O",
        "fillvar": "max(amount)",
        "sizevar": "mean(amount):O",
    }
    channel_group = ChannelGroup(opts, mydata)
    expected = {
        "x": alt.X(field="name", type="nominal"),
        "y": alt.Y(field="amount", type="ordinal"),
        "fill": alt.Fill(field="amount", type="quantitative", aggregate="max"),
        "size": alt.Size(field="amount", type="ordinal", aggregate="mean"),
    }
    assert channel_group == expected
def generate_ridgeline_plot(data, x_lab_country_name):
    """Generate a monthly ridgeline plot of ``response_ratio`` (COVID-19 CAN & USA data).

    Parameters
    ----------
    data
        Input data set from the preprocessed csv; must provide the ``date``
        and ``response_ratio`` fields.
    x_lab_country_name
        Country name inserted into the x-axis title.

    Returns
    -------
    altair object
        The faceted ridgeline chart, one row per month.
    """
    step = 40
    overlap = 1

    counts = (
        alt.Chart(data, height=step)
        .transform_timeunit(Month='month(date)')
        .transform_joinaggregate(
            mean_response_ratio='mean(response_ratio)', groupby=['Month'])
        .transform_bin(['bin_max', 'bin_min'], 'response_ratio')
        .transform_aggregate(
            value='count()',
            groupby=['Month', 'mean_response_ratio', 'bin_min', 'bin_max'])
        # Empty bins get a count of 0.
        .transform_impute(
            impute='value',
            groupby=['Month', 'mean_response_ratio'],
            key='bin_min',
            value=0)
    )
    areas = counts.mark_area(
        interpolate='monotone', fillOpacity=0.8, stroke='lightgray', strokeWidth=0.5
    ).encode(
        alt.X('bin_min:Q', bin='binned',
              title=f'Mean Response Ratio in {x_lab_country_name}'),
        alt.Y('value:Q', scale=alt.Scale(range=[step, -step * overlap]), axis=None),
        alt.Fill('mean_response_ratio:Q'),
    )
    month_row = alt.Row(
        'Month:T',
        title=None,
        header=alt.Header(labelAngle=0, labelAlign='right', format='%B'),
    )
    return areas.facet(row=month_row).properties(title='', bounds='flush')
def plot_ridgeline(input_df, time_unit_field, value_col):
    """Graph a ridgeline plot of ``value_col`` grouped by the year of ``time_unit_field``.

    ``input_df`` is copied; the copy's ``time_unit_field`` column is passed
    through ``transform_date`` before plotting.

    Returns:
        A configured, faceted altair chart (one row per "Decade" value).
    """
    frame = input_df.copy()
    frame[time_unit_field] = frame[time_unit_field].apply(transform_date)

    step = 30  # height of each ridge (kde)
    overlap = 1
    mean_expr = 'mean(' + value_col + ')'

    binned = (
        alt.Chart(frame, height=step)
        .transform_timeunit(as_="Decade", timeUnit="year", field=time_unit_field)
        .transform_joinaggregate(mean_val=mean_expr, groupby=["Decade"])
        .transform_bin(['bin_max', 'bin_min'], value_col, bin=alt.Bin(maxbins=10))
        .transform_aggregate(
            value='count()', groupby=["Decade", 'mean_val', 'bin_min', 'bin_max'])
        .transform_impute(
            impute='value', groupby=['Decade', 'mean_val'], key='bin_min', value=0)
    )
    ridges = binned.mark_area(
        interpolate='monotone', fillOpacity=0.8, stroke='lightgray', strokeWidth=0.5
    ).encode(
        alt.X('bin_min:Q', bin='binned', title='Timbre 2 Average By Decades'),
        alt.Y('value:Q', scale=alt.Scale(range=[step, -step * overlap]), axis=None),
        # Hand-tuned colour domain.
        alt.Fill('mean_val:Q', legend=None,
                 scale=alt.Scale(domain=[50, -100], scheme='redyellowblue')),
    )
    # The facet row is temporal ("T"), so the decade field must be date-like.
    decade_row = alt.Row(
        'Decade:T',
        title=None,
        header=alt.Header(labelAngle=0, labelAlign='right', format='%Ys'),
    )
    faceted = ridges.facet(row=decade_row).properties(
        title='Timbre Avg by Decade (Ridgeline)', bounds='flush')
    return (faceted.configure_facet(spacing=0)
            .configure_view(stroke=None)
            .configure_title(anchor='end'))
def visualize_emb(vis_dict):
    """Render an embedding as stacked area rows and show it with Streamlit.

    Args:
        vis_dict: Mapping with key ``'dict'`` (data passed to ``pd.DataFrame``;
            the chart reads an ``x`` column from it) and key ``'c_names'``
            (the columns folded into one facet row each).
    """
    # Named ``frame_data`` rather than ``dict`` to avoid shadowing the builtin.
    frame_data = vis_dict['dict']
    c_names = vis_dict['c_names']
    emb_vis_data = pd.DataFrame(frame_data)
    step = 20  # per-row height in pixels

    emb_chart = alt.Chart(emb_vis_data).transform_fold(
        c_names, as_=['embedding', 'lv']
    ).mark_area(
        interpolate='monotone',
        fillOpacity=0.8,
        stroke='lightgray',
        strokeWidth=0.2,
    ).encode(
        alt.X('x:Q', title=None,
              scale=alt.Scale(domain=[0, 512], range=[0, 1500])),
        # NOTE(review): ``rangeStep`` is a Vega-Lite 3 scale property that was
        # removed in Vega-Lite 4 — confirm the pinned Altair version accepts it.
        alt.Y('lv:Q', title="", scale=alt.Scale(rangeStep=40), axis=None),
        alt.Fill('embedding:N', legend=None,
                 scale=alt.Scale(scheme='redyellowblue')),
        row=alt.Row('embedding:N', title=None,
                    header=alt.Header(labelAngle=360)),
    ).properties(
        bounds='flush',
        title='Вектор статьи',
        height=step,
        width=1200,
    ).configure_facet(
        spacing=0
    ).configure_view(
        stroke=None
    ).configure_title(
        anchor='middle'
    )
    # NOTE(review): ``width=-1`` relies on an older st.altair_chart signature —
    # confirm against the installed Streamlit version.
    st.altair_chart(emb_chart, width=-1)
def generate_ridgeline_plot(data, attribute):
    '''Ridge line plot: overlapping per-species histograms of ``attribute``.

    Args:
        data: Chart data; must provide ``species`` and the ``attribute`` field.
        attribute: Field to histogram; also used as a key into the module-level
            ``metadata_description`` mapping for the chart title.

    Returns:
        A configured altair chart faceted by ``species``.
    '''
    step = 40
    overlap = 1
    field_name = str(attribute)

    counts = (
        alt.Chart(data)
        .transform_joinaggregate(
            mean_attribute=f"mean({field_name})", groupby=['species'])
        .transform_bin(['bin_max', 'bin_min'], field_name)
        .transform_aggregate(
            value='count()',
            groupby=['species', 'mean_attribute', 'bin_min', 'bin_max'])
        .transform_impute(
            impute='value', groupby=['species', 'mean_attribute'],
            key='bin_min', value=0)
    )
    # NOTE(review): the fill domain [30, 5] is a fixed constant rather than
    # derived from ``attribute``'s values, and height=100 differs from the
    # ``step`` (40) used for the y range — confirm both are intended.
    layers = counts.mark_area(
        interpolate='monotone', fillOpacity=0.4, stroke='lightgray', strokeWidth=0.3
    ).encode(
        alt.X('bin_min:Q', bin='binned', title=field_name),
        alt.Y('value:Q', scale=alt.Scale(range=[step, -step * overlap]), axis=None),
        alt.Fill('mean_attribute:Q', legend=None,
                 scale=alt.Scale(domain=[30, 5], scheme='redyellowblue')),
        alt.Row('species:O', title='Species',
                header=alt.Header(labelAngle=0, labelAlign='right')),
    )
    heading = 'Comparison: {}'.format(str(metadata_description[attribute]))
    sized = layers.properties(bounds='flush', title=heading, height=100, width=700)
    return (sized.configure_facet(spacing=0)
            .configure_view(stroke=None)
            .configure_title(anchor='end'))
def grafico_con_columnas(
        source,
        y='test_acc',
        title="Exactitud en validacion desacoplado segun destilacion",
        shape='student',
        column='feat_dist',
        color='layer',
        fill=None):
    """Point chart faceted into one column per ``column`` value.

    Args:
        source: DataFrame; every column not named by an encoding argument is dropped.
        y: Column on the y axis.
        title: Chart title.
        shape: Column for point shape (ordinal).
        column: Column whose values become facet columns (ordinal).
        color: Column used for both the x axis and the colour (nominal).
        fill: Optional column for the fill channel (ordinal, pastel1 scheme).

    Returns:
        An interactive, configured chart sized 70 x 600.
    """
    # Columns referenced by the encodings (``fill`` may be None).
    kept = [y, shape, column, color, fill]
    source = source.drop(columns=[name for name in source.columns if name not in kept])

    channels = {
        "shape": alt.Shape(f"{shape}:O",
                           legend=alt.Legend(title=global_titles[shape])),
        "y": alt.Y(y, title=global_titles[y]),
        "column": alt.Column(f"{column}:O", title=global_titles[column]),
        "x": alt.X(f"{color}:N", title=global_titles[color]),
        "color": alt.Color(f"{color}:N",
                           legend=alt.Legend(title=global_titles[color])),
        "opacity": alt.value(0.5),
    }
    if fill is not None:
        channels["fill"] = alt.Fill(
            f"{fill}:O",
            legend=alt.Legend(title=global_titles[fill]),
            scale=alt.Scale(scheme='pastel1'))

    points = alt.Chart(source, title=title).mark_point(size=100).encode(**channels)
    styled = points.configure_axis(titleFontSize=12, labelFontSize=12)
    styled = styled.configure_title(fontSize=15).interactive()
    return styled.properties(width=70, height=600)
def probabilistic_chart(
    probability_scale_range: Tuple[float, float],
    belief_horizon_unit: str,
    sensor_name: str,
    sensor_unit: str,
):
    """Layer probability areas (per belief horizon) under hover-highlighted lines.

    The charts are created without data (``alt.Chart()``); presumably the data
    is supplied by an enclosing chart — TODO confirm with callers.

    Args:
        probability_scale_range: Pixel range for the probability (y) scale.
        belief_horizon_unit: Unit shown in the belief-horizon tooltip title.
        sensor_name: Sensor name used in the x-axis title.
        sensor_unit: Sensor unit used in the x-axis title.

    Returns:
        ``alt.layer`` of the area chart (bottom) and line chart (top).
    """
    hover = selectors.ridgeline_hover_brush
    shared = alt.Chart().encode(
        x=alt.X(
            "event_value:Q",
            bin="binned",
            scale=alt.Scale(padding=0),
            title=f"{sensor_name} ({sensor_unit})",
        ),
        y=alt.Y(
            "probability:Q",
            scale=alt.Scale(range=probability_scale_range),
            axis=None,
        ),
    )

    areas = shared.mark_area(interpolate="monotone", fillOpacity=0.6).encode(
        fill=alt.Fill(
            "belief_horizon:N",
            sort="ascending",
            legend=None,
            scale=alt.Scale(scheme="viridis"),
        ),
        # NOTE(review): belief_horizon is nominal for fill but quantitative here.
        tooltip=[
            alt.Tooltip("event_value:Q", title="Value", format=".2f"),
            alt.Tooltip("probability:Q", title="Probability", format=".2f"),
            alt.Tooltip(
                "belief_horizon:Q",
                title=f"Belief horizon ({belief_horizon_unit})",
            ),
        ],
    )

    # Lines turn thick and black while the hover selection is active.
    lines = shared.mark_line(interpolate="monotone").encode(
        stroke=alt.condition(hover, alt.value("black"), alt.value("lightgray")),
        strokeWidth=alt.condition(hover, alt.value(2.5), alt.value(0.5)),
    )
    return alt.layer(areas, lines)
def totals_by_stage(d: pd.DataFrame) -> alt.Chart:
  """Plots total runtimes for each stage.

  Args:
    d: A dataframe of runtimes; must contain every column in RUNTIME_COLUMNS.

  Returns:
    An altair chart with one bar per stage, in RUNTIME_COLUMNS order.
  """
  # Sum only the runtime columns. Summing the whole frame also aggregates
  # unrelated columns (e.g. string labels), which is wasted work and can raise
  # on mixed-type object columns.
  stage_totals = (
      d[RUNTIME_COLUMNS].sum()
      .rename_axis('Stage')
      .to_frame('Runtime (seconds)')
      .reset_index())
  stage_totals['Runtime'] = stage_totals['Runtime (seconds)'].apply(
      format_runtime_string)
  return alt.Chart(stage_totals).mark_bar().encode(
      x='Runtime (seconds)',
      y=alt.Y('Stage', sort=None),
      tooltip=['Runtime'],
      fill=alt.Fill('Stage', sort=None)).properties(title='Overall runtime by stage')
def individual_region_bars(small_df: pd.DataFrame,
                           title: Union[str, Dict[str, str]] = '') -> alt.Chart:
  """Makes a stacked bar chart with runtime of each stage for individual regions.

  Args:
    small_df: A dataframe of regions, each of which will be shown as a bar. It
      needs 'region', 'Runtime' and every column in RUNTIME_COLUMNS.
    title: A title for the plot. If a dict, it is passed through as Vega-Lite
      title parameters (e.g. 'text' and/or 'subtitle').

  Returns:
    An altair chart.
  """
  region_runtimes = small_df[['region', 'Runtime'] + RUNTIME_COLUMNS]
  # Fold to long form: one row per (region, stage) pair.
  long_form = alt.Chart(region_runtimes).transform_fold(
      RUNTIME_COLUMNS, as_=['Stage', 'runtime_by_stage'])
  stacked = long_form.mark_bar().encode(
      x=alt.X('region:N', sort=None),
      y=alt.Y(
          'runtime_by_stage:Q',
          scale=alt.Scale(type='linear'),
          title='Runtime (seconds)'),
      fill=alt.Fill('Stage:N', sort=None),
      tooltip='Runtime:N',
  )
  return stacked.properties(title=title)
def test_titled_vars(mydata):
    """A ``|Title`` suffix becomes the channel title, leading whitespace included."""
    opts = {
        "xvar": "name",
        "yvar": "amount:O|The Amount",
        "fillvar": "category|Cats",
        "sizevar": "mean(amount):O| Big Ups!",
    }
    channel_group = ChannelGroup(opts, mydata)
    expected = {
        "x": alt.X(field="name", type="nominal"),
        "y": alt.Y(field="amount", title="The Amount", type="ordinal"),
        "fill": alt.Fill(field="category", title="Cats", type="nominal"),
        "size": alt.Size(
            field="amount", type="ordinal", aggregate="mean", title=" Big Ups!"
        ),
    }
    assert channel_group == expected
# Wage area plus two outline strokes (the second shifted up by 2px).
area = base_wheat.mark_area(color="#a4cedb", opacity=0.7).encode(
    x=alt.X("year:Q"),
    y=alt.Y("wages:Q"),
)
area_line_1 = area.mark_line(color="#000", opacity=0.7)
area_line_2 = area.mark_line(yOffset=-2, color="#EE8182")

# Bars spanning start..end from base_monarchs, filled black/white by 'commonwealth'.
top_bars = base_monarchs.mark_bar(stroke="#000").encode(
    x=alt.X("start:Q"),
    x2=alt.X2("end"),
    y=alt.Y("y:Q"),
    y2=alt.Y2("offset"),
    fill=alt.Fill(
        "commonwealth:N",
        legend=None,
        scale=alt.Scale(range=["black", "white"]),
    ),
)
top_text = base_monarchs.mark_text(
    yOffset=14, fontSize=9, fontStyle="italic"
).encode(
    x=alt.X("x:Q"),
    y=alt.Y("off2:Q"),
    text=alt.Text("name:N"),
)

(bars + area + area_line_1 + area_line_2 + top_bars + top_text).properties(
    width=900, height=400
).configure_axis(
    title=None, gridColor="white", gridOpacity=0.25, domain=False
).configure_view(stroke="transparent")
'Month', 'mean_temp', 'bin_min', 'bin_max' ]).transform_impute( impute='value', groupby=['Month', 'mean_temp'], key='bin_min', value=0).mark_area( interpolate='monotone', fillOpacity=0.8, stroke='lightgray', strokeWidth=0.5).encode( alt.X('bin_min:Q', bin='binned', title='Maximum Daily Temperature (C)'), alt.Y('value:Q', scale=alt.Scale(range=[step, -step * overlap]), axis=None), alt.Fill('mean_temp:Q', legend=None, scale=alt.Scale(domain=[30, 5], scheme='redyellowblue')) ).facet(row=alt.Row('Month:T', title=None, header=alt.Header(labelAngle=0, labelAlign='right', format='%B')) ).properties(title='Seattle Weather', bounds='flush').configure_facet( spacing=0).configure_view( stroke=None).configure_title( anchor='end')
@pytest.mark.xfail(raises=NotImplementedError,
                   reason="specifying timeUnit is not supported yet")
def test_timeUnit():
    """Shorthand with a timeUnit ('date(...)') is expected to raise NotImplementedError."""
    timeunit_chart = alt.Chart(df).mark_point().encode(alt.X('date(combination)'))
    convert(timeunit_chart)


# Plots
chart_quant = alt.Chart(df_quant).mark_point().encode(
    alt.X(field='a', type='quantitative'),
    alt.Y('b'),
    alt.Color('c:Q'),
    alt.Size('s'),
)
chart_fill_quant = alt.Chart(df_quant).mark_point().encode(
    alt.X(field='a', type='quantitative'),
    alt.Y('b'),
    alt.Fill('fill:Q'),
)


@pytest.mark.parametrize("chart", [chart_quant, chart_fill_quant])
def test_quantitative_scatter(chart):
    """The converted mapping for quantitative channels is accepted by plt.scatter."""
    plt.scatter(**convert(chart))
    plt.show()


@pytest.mark.parametrize("channel", [alt.Color("years"), alt.Fill("years")])
def test_scatter_temporal(channel):
    """A 'years' colour/fill mapping plots once a y series is supplied."""
    scatter_kwargs = convert(
        alt.Chart(df).mark_point().encode(alt.X("years"), channel))
    scatter_kwargs['y'] = df['quantitative'].values
    plt.scatter(**scatter_kwargs)
def test_convert_fill_success_temporal(column):
    """A fill on ``column`` converts to matplotlib date numbers under key 'c'."""
    fill_chart = alt.Chart(df).mark_point().encode(alt.Fill(column))
    converted = convert(fill_chart)
    expected = list(mdates.date2num(df[column].values))
    assert list(converted['c']) == expected
def create_map(alcohol_type='beer', region="World"):
    """
    Create choropleth heatmap based on alcoholic consumption.

    Parameters
    ----------
    alcohol_type : str {'wine', 'beer', 'spirit'}
        Type of alcohol to show on choropleth. Every column of ``df`` whose
        name contains this string is used.
    region : str {'World', 'Asia', 'Europe', 'Africa', 'Americas', 'Oceania'}
        Region the map projection is set up for. For anything other than
        'World', the bar chart keeps only rows whose ``region`` equals it.

    Returns
    -------
    altair Chart object
        Choropleth of chosen alcohol type, horizontally concatenated with a
        bar chart of the top 20 countries.

    Examples
    --------
    >>> create_map('spirit')
    """
    # Per region: [projection scale, translate x, translate y, display name].
    region_dict = {
        "World": [140, 450, 400, 'the World'],
        "Asia": [400, -190, 520, 'Asia'],
        "Europe": [800, 300, 1100, 'Europe'],
        "Africa": [400, 300, 310, 'Africa'],
        "Americas": [275, 950, 310, 'the Americas'],
        "Oceania": [500, -800, 50, 'Oceania']
    }

    # Set colour scheme of map.
    if alcohol_type == 'wine':
        map_color = ['#f9f9f9', '#720b18']
    elif alcohol_type == 'beer':
        map_color = ['#f9f9f9', '#DAA520']
    else:
        map_color = ['#f9f9f9', '#67b2e5', '#1f78b5']

    # Positions follow df's column order. As labelled in the tooltips below:
    # 0 = total servings, 1 = proportion of servings, 2 = global rank,
    # 3 = continent rank, 4 = 'country'.
    cols = [x for x in df.columns if alcohol_type in x]
    cols.append('country')

    # This is to select the rank column matching the chosen scope.
    col_to_filter = cols[2] if region == 'World' else cols[3]

    proj_scale, translate_x, translate_y, region_name = region_dict[region]

    def tooltip(proportion_title):
        """Tooltip spec shared by the map and the bar chart."""
        return [
            {"field": cols[4], "type": "nominal", 'title': "Country"},
            {"field": cols[1], "type": "quantitative",
             'title': proportion_title, 'format': '.2f'},
            {"field": cols[0], "type": "quantitative",
             'title': f'Total {alcohol_type} servings'},
            {"field": cols[3], "type": "quantitative", 'title': 'Continent rank'},
            {"field": cols[2], "type": "quantitative", 'title': 'Global rank'},
        ]

    # Create map plot.
    map_plot = alt.Chart(
        alt.topo_feature(data.world_110m.url, 'countries')
    ).mark_geoshape(stroke='white', strokeWidth=0.5).encode(
        alt.Color(
            field=cols[1],
            type='quantitative',
            scale=alt.Scale(domain=[0, 1], range=map_color),
            legend=alt.Legend(
                orient='top',
                title=f'Proportion of total servings per person from {alcohol_type}')
        ),
        tooltip=tooltip(f'Proportion of total servings from {alcohol_type}'),
    ).transform_lookup(
        lookup='id', from_=alt.LookupData(df, 'id', fields=cols)
    ).project(
        type='mercator', scale=proj_scale, translate=[translate_x, translate_y]
    ).properties(width=900, height=600)

    if region != 'World':
        scope_filter = alt.datum.region == region
    else:
        # Intended to keep every country (assumes total_servings is never negative).
        scope_filter = alt.datum.total_servings >= 0

    bar = alt.Chart(df).mark_bar().encode(
        alt.X(field=cols[1], type='quantitative', title="",
              scale=alt.Scale(domain=[0, 1])),
        alt.Y(field='country', type='nominal',
              sort=alt.EncodingSortField(field=cols[1], op='max', order='descending'),
              title=''),
        alt.Fill(field=cols[1], type='quantitative',
                 scale=alt.Scale(domain=[0, 1], range=map_color), legend=None),
        tooltip=tooltip(
            f'Proportion of total servings per person from {alcohol_type}'),
    ).transform_filter(
        scope_filter
    ).transform_window(
        # Number rows by descending proportion; the top 20 are kept below.
        # The field is the scope's rank column, not the literal name 'col_to_filter'.
        sort=[alt.SortField(cols[1], order="descending")],
        rank=f"rank({col_to_filter})",
    ).transform_filter(
        alt.datum.rank <= 20
    ).properties(
        title=f"Top 20 Countries that love {alcohol_type.title()} in {region_name}",
        width=200,
        height=600)

    return alt.hconcat(map_plot, bar).configure_legend(
        gradientLength=300,
        gradientThickness=20,
        titleLimit=0,
        labelFontSize=15,
        titleFontSize=20).configure_axis(
            labelFontSize=15, titleFontSize=20).configure_title(fontSize=20)
def _uv_ridge_static(frame, ylabel):
    """Draw overlapping seaborn KDE ridges, one row per 'anfreq_label', in a panel pane."""
    sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
    palette = sns.cubehelix_palette(10, rot=-.25, light=.7)
    grid = sns.FacetGrid(frame, row="anfreq_label", hue="anfreq_label",
                         aspect=15, height=.5, palette=palette)

    # Filled densities first, then a white outline on top, then a baseline.
    grid.map(sns.kdeplot, ylabel, bw_adjust=.5, clip_on=False,
             fill=True, alpha=1, linewidth=1.5)
    grid.map(sns.kdeplot, ylabel, clip_on=False, color="w", lw=2, bw_adjust=.5)
    grid.map(plt.axhline, y=0, lw=2, clip_on=False)

    def _write_row_label(x, color, label):
        # Write the row's label at the left edge, in axes coordinates.
        axes = plt.gca()
        axes.text(0, .2, label, fontweight="bold", color=color,
                  ha="left", va="center", transform=axes.transAxes)

    grid.map(_write_row_label, ylabel)

    # Negative vertical spacing makes neighbouring rows overlap.
    grid.fig.subplots_adjust(hspace=-.25)
    # Remove axes details that don't play well with the overlap.
    grid.set_titles("")
    grid.set(yticks=[])
    grid.despine(bottom=True, left=True)
    plt.close()
    return pn.pane.Matplotlib(grid.fig, tight=True)


def _uv_ridge_interactive(frame, ylabel, afreq):
    """Build an altair ridgeline of binned ``ylabel`` counts, one row per 'anfreq_label'."""
    step = 30
    overlap = 2
    frame = frame.dropna()
    lowest = frame[ylabel].min()
    highest = frame[ylabel].max()

    row_header = alt.Header(labelAngle=0, labelAlign="left")
    if afreq in ['Month Start', 'Month End']:
        # Month labels are ordered chronologically rather than alphabetically.
        row = alt.Row(
            "anfreq_label:N",
            title=afreq,
            sort=[
                'January', 'February', 'March', 'April', 'May', 'June', 'July',
                'August', 'September', 'October', 'November', 'December'
            ],
            header=row_header)
    else:
        row = alt.Row("anfreq_label:N", header=row_header)

    ridgeline = alt.Chart(frame, height=step).mark_area(
        interpolate="monotone", fillOpacity=0.8, stroke="lightgray", strokeWidth=0.5
    ).encode(
        alt.X(f"{ylabel}:Q", bin=True, title=ylabel, axis=alt.Axis(format='~s')),
        alt.Y(f"count({ylabel}):Q",
              scale=alt.Scale(range=[step, -step * overlap]),
              impute=alt.ImputeParams(value=0),
              axis=None),
        alt.Fill(f"mean({ylabel}):Q", legend=None,
                 scale=alt.Scale(domain=[highest, lowest], scheme="redyellowblue")),
        row,
    )
    return ridgeline.properties(bounds="flush", width=525).configure_facet(spacing=0)


def uv_ridgePlot(data, engine, xlabel, ylabel, afreq):
    """Ridgeline plot of the 'plotX1' column, one ridge per 'anfreq_label' value.

    Args:
        data: DataFrame with 'plotX1' and 'anfreq_label' columns (not modified).
        engine: 'Static' (seaborn, wrapped in a panel Matplotlib pane) or
            'Interactive' (altair). More than 15 distinct labels always selects
            'Interactive'.
        xlabel: Unused.
        ylabel: Name given to 'plotX1'; used as the plotted field and axis title.
        afreq: Frequency label; 'Month Start'/'Month End' order rows by month.

    Returns:
        The pane or chart, or None for any other ``engine`` value.
    """
    frame = data.copy().rename(columns={'plotX1': ylabel})
    if frame['anfreq_label'].nunique() > 15:
        engine = 'Interactive'
    if engine == 'Static':
        return _uv_ridge_static(frame, ylabel)
    if engine == 'Interactive':
        return _uv_ridge_interactive(frame, ylabel, afreq)
    return None
def _frozen_weights_ridge(df, step, width, subtitle, xaxis_title, title_main,
                          overlap, max_bins, color_scheme):
    """Build one faceted ridgeline of 'Dev Metric', one row per 'Frozen Weights Pct'.

    ``df`` must be non-empty: ``min``/``max`` of an empty column raise ValueError.
    """
    domain_ = [min(df['Dev Metric']), max(df['Dev Metric'])]
    # Colour domain: the metric range reversed, each end shifted by 0.05.
    fill_domain = [domain_[1] - 0.05, domain_[0] + 0.05]
    return (
        alt.Chart(df, height=step)
        .transform_joinaggregate(mean_acc='mean(Dev Metric)',
                                 groupby=['Frozen Weights Pct'])
        .transform_bin(['bin_max', 'bin_min'], 'Dev Metric',
                       bin=alt.Bin(maxbins=max_bins))
        .transform_aggregate(value='count()',
                             groupby=['Frozen Weights Pct', 'mean_acc',
                                      'bin_min', 'bin_max'])
        # 'value' is a per-bin count, so bins with no rows are filled with 0.
        .transform_impute(impute='value',
                          groupby=['Frozen Weights Pct', 'mean_acc'],
                          key='bin_min', value=0)
        .mark_area(interpolate='monotone', fillOpacity=0.8,
                   stroke='lightgray', strokeWidth=0.5)
        .encode(
            alt.X('bin_min:Q', bin='binned', title=xaxis_title,
                  scale=alt.Scale(domain=domain_)),
            alt.Y('value:Q', scale=alt.Scale(range=[step, -step * overlap]),
                  axis=None),
            alt.Fill('mean_acc:Q', legend=None,
                     scale=alt.Scale(domain=fill_domain, scheme=color_scheme)))
        .properties(width=width, height=step)
        .facet(
            row=alt.Row(
                'Frozen Weights Pct:O',
                title='Frozen Weights Pct (Binned)',
                header=alt.Header(
                    labelAngle=0,
                    labelAlign='right',
                    labelFontSize=15,
                    labelFont='Lato',
                    labelColor=berkeley_palette['pacific'],
                    titleFontSize=20,
                ),
            )
        )
        .properties(title={'text': title_main, 'subtitle': subtitle},
                    bounds='flush')
    )


def altair_frozen_weights_performance_ridge_plot(data, xaxis_title = "Dev Metric", title_main = "Dense Variably Unfrozen", task_name = "MSR", step_all = 75, width_all = 600, step_small = 30, width_small = 400, overlap = 1, max_bins = 30, color_scheme = 'redyellowblue', return_all = True):
    """Ridgeline plots of 'Dev Metric' distributions per 'Frozen Weights Pct'.

    Args:
        data: DataFrame with 'Frozen Weights Pct', 'Epoch' and 'Dev Metric' columns.
        xaxis_title: X-axis title.
        title_main: Main title of every sub-chart.
        task_name: Prefix of every subtitle.
        step_all, width_all: Row height and width of the all-epochs chart.
        step_small, width_small: Row height and width of each per-epoch chart.
        overlap: Fraction of a row height a ridge may rise above its row.
        max_bins: Maximum number of histogram bins.
        color_scheme: Vega colour scheme for the mean-metric fill.
        return_all: If False, return only the all-epochs chart. Otherwise also
            build charts for epochs 1-4 (each epoch must have rows).

    Returns:
        A configured altair chart. With ``return_all``, it is a 2x2 grid of the
        epoch charts next to the all-epochs chart.

    Raises:
        AssertionError: If ``data`` is not a DataFrame or lacks a required column.
    """
    assert isinstance(data, pd.DataFrame), "Parameter `data` must be of type pandas.core.frame.DataFrame."
    required = ['Frozen Weights Pct', 'Epoch', 'Dev Metric']
    assert all(col in data.columns for col in required), "Parameter `data` must contain the following columns: ['Frozen Weights Pct', 'Epoch', 'Dev Metric']."

    shared = dict(xaxis_title=xaxis_title, title_main=title_main,
                  overlap=overlap, max_bins=max_bins, color_scheme=color_scheme)

    def _finish(chart):
        # Zero facet spacing is what lets neighbouring ridges overlap.
        return (chart.configure_facet(spacing=0)
                .configure_view(stroke=None)
                .configure_title(anchor='middle'))

    # Generate the combined epochs plot.
    c0 = _frozen_weights_ridge(data, step_all, width_all,
                               " ".join([task_name, "- All Epochs"]), **shared)
    if not return_all:
        return _finish(c0)

    # Generate the individual epoch plots (epochs 1..4).
    subplots = [
        _frozen_weights_ridge(data[data['Epoch'] == i], step_small, width_small,
                              " ".join([task_name, "- Epoch", str(i)]), **shared)
        for i in range(1, 5)
    ]
    epoch_grid = alt.vconcat(alt.hconcat(subplots[0], subplots[1]),
                             alt.hconcat(subplots[2], subplots[3]))
    return _finish(alt.hconcat(epoch_grid, c0))
def grafico_con_barra_xy_CE(
        source=data['students']['kd']['KD'],
        title="Ratio entre exactitud en validacion y accuracy en entrenamiento.",
        y="test/train",
        x="temp",
        color='student',
        shape='log_dist',
        bs=True,
        fill=None,
        xscale='log',
        yscale='linear'):
    """Interactive scatter of ``y`` vs ``x``, optionally layered over reference rules.

    Args:
        source: DataFrame of results; columns not used by an encoding are dropped.
        title: Chart title.
        y, x: Columns for the axes (both quantitative, base-10, zero=False).
        color, shape: Columns for the colour and shape channels.
        bs: If true, draw rules from ``bar_source`` (coloured by 'model')
            under the scatter.
        fill: Optional column for the fill channel.
        xscale, yscale: Scale types; a non-linear type is appended to the axis title.

    Returns:
        The chart (layered when ``bs``) sized 600 x 600, with axis/title fonts configured.
    """
    # Keep only the columns referenced by an encoding (``fill`` may be None).
    used_columns = [y, x, color, shape, fill]
    source = source.drop(columns=[c for c in source.columns if c not in used_columns])

    def axis_title(field, scale_type):
        label = global_titles[field]
        return label if scale_type == 'linear' else label + " [%s]" % scale_type

    encodings = {
        'y': alt.Y(y, type='quantitative', title=axis_title(y, yscale),
                   scale=alt.Scale(zero=False, base=10, type=yscale)),
        'x': alt.X(x, type='quantitative', title=axis_title(x, xscale),
                   scale=alt.Scale(zero=False, base=10, type=xscale)),
        'color': alt.Color(color, legend=alt.Legend(title=global_titles[color])),
        'shape': alt.Shape(shape, legend=alt.Legend(title=global_titles[shape])),
    }
    if fill is not None:
        encodings["fill"] = alt.Fill(fill)

    chart = alt.Chart(source, title=title).mark_point(size=100).encode(
        **encodings).interactive()
    if bs:
        rules = alt.Chart(bar_source).mark_rule(opacity=0.5).encode(
            y=y,
            color=alt.Color('model', legend=alt.Legend(title="Modelo")),
            stroke=alt.Stroke(
                'model', legend=alt.Legend(title="Modelo en Cross Entropy")),
            size=alt.value(2))
        chart = rules + chart
    # properties()/configure_*() return new chart objects, so the result must
    # be returned; otherwise the size and font settings are lost.
    return chart.properties(width=600, height=600).configure_axis(
        titleFontSize=12).configure_title(fontSize=15)
def grafico_con_barra_y_CE(
        source=data['students']['kd']['KD'],
        title="Diferencia de accuracy en entrenamiento con respecto a Resnet101.",
        y="delta_teacher",
        x="feat_dist",
        fill='layer',
        color='student',
        shape='layer',
        bs=True,
        scale='linear'):
    """Scatter of ``y`` per ``x`` category, optionally layered over reference rules.

    Args:
        source: DataFrame of results; columns not used by an encoding are dropped.
        title: Chart title.
        y: Quantitative column on the y axis.
        x: Column on the x axis. The special values 'dists' and 'feat,block'
            build that column from 'feat_dist' plus 'log_dist' / 'layer'.
        fill, color, shape: Columns for those channels (fill/shape ordinal).
        bs: If true, draw rules from ``bar_source`` (coloured by 'model')
            under the scatter.
        scale: Y scale type; a non-linear type is appended to the axis title.

    Returns:
        The chart (layered when ``bs``) sized 600 x 600, with axis/title fonts configured.
    """
    # Derived x columns; work on a copy so the caller's frame is untouched.
    if x == 'dists':
        source = source.copy()
        source['dists'] = source['feat_dist'] + ", " + source['log_dist']
    elif x == 'feat,block':
        source = source.copy()
        source['feat,block'] = [
            feat + ", " + str(layer)
            for feat, layer in zip(source['feat_dist'], source['layer'])
        ]

    # Keep only the columns referenced by an encoding.
    used_columns = [y, x, fill, color, shape]
    source = source.drop(columns=[c for c in source.columns if c not in used_columns])

    y_title = (global_titles[y] if scale == 'linear'
               else global_titles[y] + " [%s]" % scale)
    chart = alt.Chart(source, title=title).mark_point(size=100).encode(
        y=alt.Y(
            y,
            type='quantitative',
            scale=alt.Scale(zero=True, base=2, constant=1, type=scale),
            title=y_title),
        x=alt.X(x, title=global_titles[x]),
        color=alt.Color(color, legend=alt.Legend(title=global_titles[color])),
        fill=alt.Fill('%s:O' % fill,
                      legend=alt.Legend(title=global_titles[fill]),
                      scale=alt.Scale(scheme='pastel1')),
        shape=alt.Shape("%s:O" % shape,
                        legend=alt.Legend(title=global_titles[shape])),
        opacity=alt.value(0.5))

    # Build the rule layer only when it is used.
    if bs:
        rules = alt.Chart(bar_source).mark_rule(opacity=0.5).encode(
            y=y,
            color=alt.Color('model', legend=alt.Legend(title="Modelo")),
            stroke=alt.Stroke(
                'model', legend=alt.Legend(title="Modelo en Cross Entropy")),
            size=alt.value(2))
        chart = rules + chart
    # properties()/configure_*() return new chart objects, so the result must
    # be returned; otherwise the size and font settings are lost.
    return chart.properties(width=600, height=600).configure_axis(
        titleFontSize=12).configure_title(fontSize=15)