示例#1
0
def plot_pay_gap_2016(jobs_gender, major_category=None):
    """Scatter plot of % female workforce vs. female/male pay ratio for 2016.

    Parameters
    ----------
    jobs_gender : occupation table with year, total_workers, workers_female,
        total_earnings_female / total_earnings_male columns.
    major_category : optional major occupation category to restrict to.
        Fixed: now defaults to None (no extra filter) — one call site passes
        an already-filtered frame without a category, which previously
        raised TypeError for the missing required argument.

    Returns
    -------
    A ggplot object (not yet drawn).
    """
    data = jobs_gender >> filter(_.year == 2016, _.total_workers >= 20000)

    if major_category is not None:
        data = data >> filter(_.major_category == major_category)

    #     >> arrange(desc(wage_percent_of_male))
    data = data >> mutate(
        percent_female=_.workers_female / _.total_workers,
        wage_percent_female=_.total_earnings_female /
        _.total_earnings_male,
    )

    plot = data >> ggplot(
        aes(
            "percent_female",
            "wage_percent_female",
            color="minor_category",
            size="total_workers",
            label="occupation",
        ))

    return (
        plot + geom_point() +
        scale_size_continuous(range=[1, 10], guide=False) + labs(
            x="% of workforce reported as female",
            y="% of median female salary / median male",
            title="Gender disparity and pay gap in 2016",
            subtitle="Only occupations with at least 20,000 workers total",
            color="Minor category",
        )
        #        scale_x_continuous(labels = percent_format()) +
        #        scale_y_continuous(labels = percent_format())
    )
示例#2
0
def test_filter_via_group_by_desc_arrange(backend):
    # A cumsum filter after a descending arrange is applied per group.
    tbl = backend.load_df(x=[3, 2, 1, 2, 3, 4], g=[1, 1, 1, 2, 2, 2])

    query = group_by(_.g) >> arrange(desc(_.x)) >> filter(_.x.cumsum() > 3)
    expected = data_frame(x=[2, 1, 4, 3, 2], g=[1, 1, 2, 2, 2])

    assert_equal_query(tbl, query, expected)
示例#3
0
def test_filter_via_group_by_agg_two_args(backend):
    # Two predicates in one filter combine with AND: above the group mean
    # but not the group maximum.
    tbl = backend.load_df(x=range(1, 11), g=[1] * 5 + [2] * 5)

    query = group_by(_.g) >> filter(_.x > _.x.mean(), _.x != _.x.max())
    assert_equal_query(tbl, query, data_frame(x=[4, 9], g=[1, 2]))
示例#4
0
def test_filter_via_group_by(backend):
    # row_number restarts inside each group, so each group keeps 2 rows.
    src = data_frame(x=range(1, 11), g=[1] * 5 + [2] * 5)
    tbl = backend.load_df(src)

    expected = data_frame(x=[1, 2, 6, 7], g=[1, 1, 2, 2])
    assert_equal_query(
        tbl, group_by(_.g) >> filter(row_number(_) < 3), expected)
示例#5
0
def update_graph(val_cyl):
    """Return a plotly figure of hp vs mpg for cars with ``val_cyl`` cylinders."""
    import plotly.tools as tls

    from siuba import _, filter
    from siuba.data import mtcars

    subset = mtcars >> filter(_.cyl == val_cyl)
    fig = subset >> ggplot(aes("hp", "mpg"))
    fig = fig + geom_point() + ggtitle("Hp vs Mpg for cyl = %s" % val_cyl)

    return tls.mpl_to_plotly(fig.draw())
示例#6
0
def test_filter_vector(backend, func, simple_data):
    """Check vectorized filter predicates against eager pandas evaluation.

    Runs ``func`` as a lazy filter on the backend table and eagerly on the
    raw data, both ungrouped and grouped.
    """
    # sqlite backend can't run these vectorized predicates
    if backend.name == 'sqlite':
        pytest.skip()

    df = backend.load_cached_df(simple_data)

    # Fixed: removed dead assignment `res = data_frame(y=func(simple_data))`
    # — the result was never used.

    # ungrouped
    assert_equal_query(
        df,
        filter(func),
        filter(simple_data, func),
        # ignore dtypes, since sql -> an empty data frame has object columns
        check_dtype=False)

    # grouped (vs slow_filter)
    assert_equal_query(df,
                       group_by(_.g) >> filter(func),
                       simple_data >> group_by(_.g) >> filter(func),
                       check_dtype=False)
示例#7
0
def distinct_events(tbl, time_col, user_col, type):
    """Keep each user's first or last event; any other ``type`` is a no-op.

    Rows are ordered by ``time_col`` (negated for "last"), then the top row
    per ``user_col`` group is kept.
    """
    if type not in ["first", "last"]:
        return tbl

    order_key = _[time_col] if type == "first" else -_[time_col]

    result = (
        tbl
        >> group_by(_[user_col])
        >> arrange(order_key)
        >> filter(row_number(_) == 1)
        >> ungroup()
    )

    return result
示例#8
0
def test_select_mutate_filter(dfs):
    # select renames a -> x, mutate derives y, filter keeps the y == 2 row
    query = (select(_.x == _.a)
             >> mutate(y=_.x * 2)
             >> filter(_.y == 2))

    assert_equal_query(dfs, query, data_frame(x=1, y=2))
# -

# Load the hockey scoring datasets (the commented-out URLs below show the
# raw tidytuesday sources these helpers presumably wrap — confirm).
top250 = data_top250()
game_goals = data_game_goals()
# top250 =     pd.read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-03-03/top_250.csv') \
#         .rename(columns = {'total_games': 'total_goals'})
# game_goals = pd.read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-03-03/game_goals.csv')

# +
# filter top 8 scorers

# NOTE(review): dense_rank < 10 can keep more than 8 players when totals
# tie; the name "top8" reflects the intended result, not a guarantee.
top8 = (
    top250 >>
    mutate(dense_rank=_.total_goals.rank(method="dense", ascending=False)) >>
    filter(_.dense_rank < 10)
    #     >> select(_.raw_rank, _.dense_rank, _.total_goals, _.player)
)

# Per-game goal rows restricted to the top-ranked players.
top8_games = game_goals >> inner_join(_, top8, "player")

# +
st.write("Goals by month")

# Show which top players fall outside the game-level data — presumably it
# only covers careers starting in 1979 or later (confirm against source).
st.write("Top 8 players not in our data")
top8 >> filter(_.yr_start < 1979)

# +
from pandas.tseries.offsets import MonthBegin
from siuba.experimental.pd_groups import fast_summarize
示例#10
0
def test_filter_via_group_by_agg(backend):
    # The mean is computed per group; rows above their own group's mean stay.
    tbl = backend.load_df(x=range(1, 11), g=[1] * 5 + [2] * 5)

    expected = data_frame(x=[4, 5, 9, 10], g=[1, 1, 2, 2])
    assert_equal_query(
        tbl, group_by(_.g) >> filter(_.x > _.x.mean()), expected)
示例#11
0
def test_filter_basic(backend):
    """Filter on a simple comparison, checked against pandas boolean indexing."""
    df = data_frame(x=[1, 2, 3, 4, 5], y=[5, 4, 3, 2, 1])
    dfs = backend.load_df(df)

    # Fixed: removed a stray closing parenthesis that followed this call
    # (it was a SyntaxError).
    assert_equal_query(dfs, filter(_.x > 3), df[lambda _: _.x > 3])


# Load the jobs/gender data, surfacing progress in the UI while it runs.
data_load_state = st.text('Loading data...')
jobs_gender = load_data()
data_load_state.text('Loading data... done!')
# -

if st.checkbox('Show raw data'):
    st.subheader('Raw data')
    st.write(jobs_gender)

# +
# Fixed: the label said 'Which number do you like best?' — a copy-paste
# leftover; this selectbox actually picks a major occupation category.
major_category = st.sidebar.selectbox('Major category',
                                      jobs_gender.major_category.unique())

# Only offer minor categories that exist within the chosen major category.
minor_options = jobs_gender \
        .loc[lambda d: d.major_category == major_category, 'minor_category'] \
        .unique()

minor_category = st.sidebar.selectbox('Minor category', minor_options)

# +

filtered_jobs = filter(jobs_gender, _.major_category == major_category,
                       _.minor_category == minor_category)

# Fixed: plot_pay_gap_2016 takes the major category as its second argument;
# it was previously called with only the filtered frame (a TypeError).
p = plot_pay_gap_2016(filtered_jobs, major_category)

st.plotly_chart(tools.mpl_to_plotly(p.draw()))
示例#13
0
def test_filter_basic_two_args(backend):
    # Multiple filter arguments combine with logical AND.
    src = data_frame(x=[1, 2, 3, 4, 5], y=[5, 4, 3, 2, 1])
    tbl = backend.load_df(src)

    expected = src[lambda _: (_.x > 3) & (_.y < 2)]
    assert_equal_query(tbl, filter(_.x > 3, _.y < 2), expected)
示例#14
0
def test_raw_sql_filter(backend, df):
    # A raw SQL string can be used directly as a filter predicate.
    expected = data_frame(x=['a'], y=[1])
    assert_equal_query(df, filter(sql_raw("y = 1")), expected)
示例#15
0
def after_join(
        lhs, rhs,
        by_time, by_user,
        mode = "inner",
        type = "first-firstafter",
        max_gap = None,
        min_gap = None,
        gap_col = None,
        suffix = ("_x", "_y")
        ):
    """Join lhs events to rhs events that occur at or after them in time.

    Parameters
    ----------
    lhs, rhs : event tables usable with arrange/mutate/filter verbs.
    by_time : time column name (or an lhs/rhs pair accepted by _get_key_tuple).
    by_user : user key column name (or an lhs/rhs pair).
    mode : "inner", "left", "right", "full"/"outer", "semi", or "anti".
    type : '<lhs type>-<rhs type>' string, e.g. "first-firstafter",
        controlling which event per user each side keeps (see distinct_events).
    max_gap, min_gap, gap_col : not implemented; any non-None value raises.
    suffix : intended join suffixes — currently unused (see TODO below).

    Returns
    -------
    The joined table with the internal __idx/__idy row ids removed.

    Raises
    ------
    NotImplementedError : for gap arguments or unsupported lhs types.
    ValueError : if ``mode`` is not recognized.
    """

    if max_gap is not None or min_gap is not None or gap_col is not None:
        raise NotImplementedError("max_gap, min_gap, gap_col not implemented")

    # Get type of join for both tables, from e.g. "first-firstafter"
    type_lhs, type_rhs = type.split("-")

    # Split join keys into (lhs column, rhs column) pairs
    by_time_x, by_time_y = _get_key_tuple(by_time)
    by_user_x, by_user_y = _get_key_tuple(by_user)

    # mutate in row_number ----
    # __idx / __idy give each side a stable row id (after sorting by user,
    # then time) so matched pairs can be mapped back to original rows below.
    lhs_i = (lhs
            >> arrange(_[by_user_x], _[by_time_x])
            >> mutate(__idx = row_number(_))
            >> distinct_events(by_time_x, by_user_x, type_lhs)
            )

    rhs_i = (rhs
            >> arrange(_[by_user_y], _[by_time_y])
            >> mutate(__idy = row_number(_))
            >> distinct_events(by_time_y, by_user_y, type_rhs)
            )

    # Handle when time column is in the other table
    # (same name on both sides collides, so refer to the suffixed columns)
    if by_time_x == by_time_y:
        # TODO: don't use implicit join suffix below
        pair_time_x, pair_time_y = by_time_x + "_x", by_time_y + "_y"
    else:
        pair_time_x, pair_time_y = by_time_x, by_time_y

    # Inner join by user, filter by time
    # keep only pairs where the rhs event is at or after the lhs event
    pairs = filter(
            inner_join(lhs_i, rhs_i, by_user),
            _[pair_time_x] <= _[pair_time_y]
            )

    # TODO: firstwithin
    if type_lhs in ["firstwithin", "lastbefore"]:
        raise NotImplementedError("Can't currently handle lhs type %s" % type_lhs)

    # Handle firstafter by subsetting
    # for each lhs row (__idx) keep only the earliest qualifying rhs event
    if type_rhs == "firstafter":
        pairs = (pairs
                >> arrange(_[pair_time_y])
                >> group_by(_.__idx)
                >> filter(row_number(_) == 1)
                >> ungroup()
                )


    distinct_pairs = select(pairs, _.__idx, _.__idy)


    if mode in ["inner", "left", "right", "full", "outer"]:
        # attach pair ids to lhs rows, then pull in the matching rhs rows
        by_dict = dict([(by_user_x, by_user_y), ("__idy", "__idy")])
        res = (lhs_i
                >> join(_, distinct_pairs, on = "__idx", how = mode) 
                # TODO: suffix arg
                >> join(_, rhs_i , on = by_dict, how = mode)#, suffix = suffix)
                >> select(-_["__idx", "__idy"])
                )
    elif mode in ["semi", "anti"]:
        # semi/anti only needs lhs rows that do (or don't) have a pair
        join_func = semi_join if mode == "semi" else anti_join
        res = (lhs_i
                >> join_func(_, distinct_pairs, "__idx")
                >> select(-_["__idx", "__idy"])
                )

    else:
        raise ValueError("mode not recognized: %s" %mode)

    return res
示例#16
0
@st.cache
def load_data(nrows):
    """Read at most ``nrows`` rows from DATA_URL, lowercase the column
    names, and parse the date column.

    Cached by streamlit, so repeated app reruns don't re-download the file.
    """
    data = pd.read_csv(DATA_URL, nrows=nrows)
    # Fixed: pass the lambda directly instead of binding it to a name
    # (PEP 8 E731); str() guards against non-string column labels.
    data = data.rename(columns=lambda col: str(col).lower())
    data[DATE_COLUMN] = pd.to_datetime(data[DATE_COLUMN])
    return data


# Load the first 10,000 rows, showing a status message while it runs.
data_load_state = st.text('Loading data...')
data = load_data(10000)
data_load_state.text('Loading data... done!')
# -

if st.checkbox('Show raw data'):
    st.subheader('Raw data')
    st.write(data)

# Cylinder count to plot — mtcars only has 4-, 6-, and 8-cylinder cars.
option = st.sidebar.selectbox('Which number do you like best?', [4, 6, 8])

# +
from siuba.data import mtcars
import plotly.tools as tls

# NOTE(review): tls is only needed if the commented-out st.plotly_chart
# call at the bottom is re-enabled.

# Scatter of horsepower vs mpg for cars with the chosen cylinder count.
p = mtcars >> filter(
    _.cyl == option) >> ggplot(aes('hp', 'mpg')) + geom_point()

st.pyplot(p.draw())

#st.plotly_chart()