Exemplo n.º 1
0
def _plot_with_data_bar(source: alt.Chart,
                        line: alt.Chart,
                        type_: ValueType,
                        width: Optional[int] = None,
                        height: Optional[int] = None) -> alt.Chart:
    """Overlay an interactive hover "data bar" on a line chart.

    Moving the mouse over the plot selects the ``days`` value nearest the
    cursor. At that x-position a gray vertical rule is drawn, and the points
    of ``line`` become visible. Each point is labelled with its
    ``type_.value`` field, formatted with no decimals.

    Args:
        source: Data handed to ``alt.Chart`` for the invisible hover targets
            and the rule layer; it must provide a ``days`` field.
            NOTE(review): annotated as ``alt.Chart`` but used as chart data,
            presumably a DataFrame — confirm.
        line: The base line chart; the points and labels are derived from it.
        type_: Object whose ``.value`` names the quantitative field shown in
            the hover labels.
        width: Chart width; ``None`` means 800.
        height: Chart height; ``None`` means 400.

    Returns:
        A layered chart (line, hover targets, points, rule, labels) sized to
        ``width`` x ``height``.
    """
    # TODO: Add validation that the x and y series names are correct
    if width is None:
        width = 800
    if height is None:
        height = 400

    # Single selection keyed on 'days', snapping to the nearest point on hover.
    hover = alt.selection(type='single', nearest=True, on='mouseover',
                          fields=['days'], empty='none')

    # Fully transparent marks across the chart; they only exist so the
    # selection can read the cursor's x-value.
    hover_targets = (
        alt.Chart(source)
        .mark_point()
        .encode(x='days:Q', opacity=alt.value(0))
        .add_selection(hover)
    )

    visible_when_hovered = alt.condition(hover, alt.value(1), alt.value(0))
    highlighted_points = line.mark_point().encode(opacity=visible_when_hovered)

    label_content = alt.condition(hover, f'{type_.value}:Q', alt.value(' '),
                                  format=".0f")
    value_labels = (
        line.mark_text(align='left', dx=5, dy=10)
        .encode(text=label_content)
    )

    cursor_rule = (
        alt.Chart(source)
        .mark_rule(color='gray')
        .encode(x='days:Q')
        .transform_filter(hover)
    )

    stacked = [line, hover_targets, highlighted_points, cursor_rule, value_labels]
    return alt.layer(*stacked).properties(width=width, height=height)
Exemplo n.º 2
0
def confusion_matrix(df=None, truth=None, pred=None, mapping=None):
    """Plot a confusion matrix as an Altair heatmap with per-cell counts.

    The heatmap colour and the overlaid text both show the number of rows
    for each (truth, pred) pair. As a side effect, scikit-learn's
    ``classification_report`` for the same labels is printed when it can be
    produced. Otherwise 'Skipping Report' is logged and plotting continues.

    Args:
        df: Optional DataFrame holding the labels. When ``None``, a new frame
            with columns ``'truth'`` and ``'pred'`` is built from ``truth``
            and ``pred``.
        truth: Column name of the true labels in ``df``, or, when ``df`` is
            ``None``, the true labels themselves.
        pred: Column name of the predicted labels in ``df``, or, when ``df``
            is ``None``, the predicted labels themselves.
        mapping: Optional dict used to relabel both columns before plotting.
            Every label present must be a key. The caller's ``df`` is not
            modified.

    Returns:
        The layered chart ``heatmap + text``, with true labels on Y and
        predicted labels on X.

    Raises:
        AssertionError: if ``mapping`` is truthy but not a ``dict``.
        KeyError: if ``mapping`` lacks a label present in the data.
    """
    if df is None:
        df = pd.DataFrame({'truth': truth, 'pred': pred})
        truth = 'truth'
        pred = 'pred'
    if mapping:
        assert isinstance(mapping, dict), 'mapping should be a dictionary'
        # Relabel a copy so the caller's DataFrame is not mutated.
        df = df.copy()
        df[truth] = df[truth].map(lambda x: mapping[x])
        df[pred] = df[pred].map(lambda x: mapping[x])

    # Use a larger canvas when there are more than four distinct true labels.
    sz = 450 if len(df[truth].unique()) > 4 else 250
    base = Chart(df, height=sz, width=sz).transform_aggregate(
        num_vals='count()', groupby=[truth, pred]).encode(
            alt.Y(f'{truth}:O', scale=alt.Scale(paddingInner=0)),
            alt.X(f'{pred}:O', scale=alt.Scale(paddingInner=0)),
        )

    hm = base.mark_rect().encode(color=alt.Color(
        'num_vals:Q', scale=alt.Scale(scheme="lightorange"), legend=None))

    # Count labels are always drawn in black.
    tx = base.mark_text(baseline='middle').encode(
        text='num_vals:Q',
        color=alt.value('black'))

    try:
        from sklearn.metrics import classification_report
        print(classification_report(df[truth], df[pred]))
    except Exception:
        # Best effort: a missing sklearn or a failing report must not prevent
        # the chart from being returned.
        logger.info('Skipping Report')
    return hm + tx