def cat_plot(
    data: pd.DataFrame,
    figsize: Tuple = (18, 18),
    top: int = 3,
    bottom: int = 3,
    bar_color_top: str = "#5ab4ac",
    bar_color_bottom: str = "#d8b365",
):
    """
    Two-dimensional visualization of the number and frequency of categorical features.

    For every non-numeric column a bar chart of its most and least frequent values
    and a short text summary are drawn. Below them, one heatmap covering all
    non-numeric columns shows where in the dataset those values occur.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame
        is provided, the index/column information is used to label the plots
    figsize : Tuple, optional
        Use to control the figure size, by default (18, 18)
    top : int, optional
        Show the "top" most frequent values in a column, by default 3
    bottom : int, optional
        Show the "bottom" most frequent values in a column, by default 3
    bar_color_top : str, optional
        Use to control the color of the bars indicating the most common values,
        by default "#5ab4ac"
    bar_color_bottom : str, optional
        Use to control the color of the bars indicating the least common values,
        by default "#d8b365"

    Returns
    -------
    Gridspec
        gs: Figure with array of Axes objects. None is returned (after printing a
        message) when the data contains no non-numeric columns.
    """
    # Coerce first so the validation below can rely on ``data.shape`` even when a
    # non-DataFrame input (e.g. a nested list) is passed. The copy protects the
    # caller's data from the in-place recoding done further down.
    data = pd.DataFrame(data).copy()

    # Validate Inputs
    _validate_input_int(top, "top")
    _validate_input_int(bottom, "bottom")
    _validate_input_range(top, "top", 0, data.shape[1])
    _validate_input_range(bottom, "bottom", 0, data.shape[1])
    _validate_input_sum_larger(1, "top and bottom", top, bottom)

    cols = data.select_dtypes(exclude=["number"]).columns.tolist()
    data = data[cols]
    for col in data.columns:
        if data[col].dtype.name in ("category", "string"):
            data[col] = data[col].astype("object")

    if not cols:
        print("No columns with categorical data were detected.")
        # Nothing to draw: a grid with zero columns cannot be created.
        return None

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=len(cols), wspace=0.21)

    for count, col in enumerate(cols):
        n_unique = data[col].nunique(dropna=True)
        value_counts = data[col].value_counts()
        lim_top, lim_bot = top, bottom

        # Shrink the limits when the column has fewer distinct values than requested.
        if n_unique < top + bottom:
            lim_top = int(n_unique // 2)
            lim_bot = int(n_unique // 2) + 1

        if n_unique <= 2:
            lim_top = lim_bot = int(n_unique // 2)

        value_counts_top = value_counts.iloc[0:lim_top]
        value_counts_idx_top = value_counts_top.index.tolist()
        # NOTE(review): when lim_bot is 0 the slice ``[-0:]`` selects the whole
        # series rather than nothing - confirm whether that is intended.
        value_counts_bot = value_counts.iloc[-lim_bot:]
        value_counts_idx_bot = value_counts_bot.index.tolist()

        if top == 0:
            value_counts_top = value_counts_idx_top = []

        if bottom == 0:
            value_counts_bot = value_counts_idx_bot = []

        # Recode the column for the heatmap: 10 = one of the top values,
        # 0 = one of the bottom values, 5 = anything else. Each row is then
        # averaged with the row before it (rolling window of 2).
        data.loc[data[col].isin(value_counts_idx_top), col] = 10
        data.loc[data[col].isin(value_counts_idx_bot), col] = 0
        data.loc[((data[col] != 10) & (data[col] != 0)), col] = 5
        data[col] = data[col].rolling(2, min_periods=1).mean()

        # Bar labels are cut to 20 characters; ``str()`` keeps this working for
        # non-string category values (e.g. booleans or timestamps).
        value_counts_idx_top = [str(elem)[:20] for elem in value_counts_idx_top]
        value_counts_idx_bot = [str(elem)[:20] for elem in value_counts_idx_bot]

        # Barcharts
        ax_top = fig.add_subplot(gs[:1, count:count + 1])
        ax_top.bar(value_counts_idx_top, value_counts_top, color=bar_color_top, width=0.85)
        ax_top.bar(value_counts_idx_bot, value_counts_bot, color=bar_color_bottom, width=0.85)
        ax_top.set(frame_on=False)
        ax_top.tick_params(axis="x", labelrotation=90)

        # Summary stats
        ax_bottom = fig.add_subplot(gs[1:2, count:count + 1])
        plt.subplots_adjust(hspace=0.075)
        ax_bottom.get_yaxis().set_visible(False)
        ax_bottom.get_xaxis().set_visible(False)
        ax_bottom.set(frame_on=False)
        ax_bottom.text(
            0,
            0,
            f"Unique values: {n_unique}\n\n"
            f"Top {lim_top} vals: {sum(value_counts_top)} ({sum(value_counts_top)/data.shape[0]*100:.1f}%)\n"
            f"Bot {lim_bot} vals: {sum(value_counts_bot)} ({sum(value_counts_bot)/data.shape[0]*100:.1f}%)",
            transform=ax_bottom.transAxes,
            color="#111111",
            fontsize=11,
        )

    # Heatmap: colormap runs from the "bottom" bar color over white to the "top"
    # bar color, matching the 0 / 5 / 10 coding applied above.
    color_bot_rgb = to_rgb(bar_color_bottom)
    color_white = to_rgb("#FFFFFF")
    color_top_rgb = to_rgb(bar_color_top)
    cat_plot_cmap = LinearSegmentedColormap.from_list(
        "cat_plot_cmap", [color_bot_rgb, color_white, color_top_rgb], N=200
    )
    ax_hm = fig.add_subplot(gs[2:, :])
    sns.heatmap(data, cmap=cat_plot_cmap, cbar=False, vmin=0, vmax=10, ax=ax_hm)
    # Keep every fifth y tick, rounded to the nearest ten.
    ax_hm.set_yticks(np.round(ax_hm.get_yticks()[0::5], -1))
    ax_hm.set_yticklabels(ax_hm.get_yticks())
    ax_hm.set_xticklabels(
        ax_hm.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="medium",
    )
    ax_hm.tick_params(length=1, colors="#111111")

    gs.figure.suptitle("Categorical data plot", x=0.5, y=0.91, fontsize=18, color="#111111")

    return gs
def train_dev_test_split(data, target, dev_size=0.1, test_size=0.1, stratify=None, random_state=408):
    """
    Split a dataset and a label column into train, dev and test sets.

    The data is first split into a train part and a held-out part of size
    ``dev_size + test_size``. If both sizes are non-zero, the held-out part is
    split a second time into dev and test.

    Parameters
    ----------
    data: pd.DataFrame or array-like
        Dataset to split. Must be a Pandas DataFrame when ``target`` is given as a
        column name, because the column is looked up in and dropped from ``data``.
    target: string, list, np.array or pd.Series
        Labels belonging to ``data``. A string is interpreted as the name of the
        label column in ``data``; that column is removed from the returned features.
        A list, array or Series is used as the labels directly.
    dev_size: float, default 0.1
        Should be between 0.0 and 1.0 and represent the proportion of the dataset
        to include in the dev split.
    test_size: float, default 0.1
        Should be between 0.0 and 1.0 and represent the proportion of the dataset
        to include in the test split.
    stratify: array-like, default None
        If not None, data is split in a stratified fashion, using the input as the
        class labels. The same labels are used for both the train/held-out split
        and the dev/test split. If None, neither split is stratified.
    random_state: integer, default 408
        Random_state is the seed used by the random number generator.

    Returns
    -------
    tuple
        ``(X_train, X_dev, X_test, y_train, y_dev, y_test)`` when both ``dev_size``
        and ``test_size`` are non-zero. If either is zero only one split is made
        and ``(X_train, X_held_out, y_train, y_held_out)`` is returned.
    """
    # Validate Inputs
    _validate_input_range(dev_size, "dev_size", 0, 1)
    _validate_input_range(test_size, "test_size", 0, 1)
    _validate_input_int(random_state, "random_state")
    _validate_input_sum_smaller(1, "Dev and test", dev_size, test_size)

    target_data = []
    if isinstance(target, str):
        target_data = data[target]
        data = data.drop(target, axis=1)
    elif isinstance(target, (list, pd.Series, np.ndarray)):
        target_data = pd.Series(target)

    # Split the stratification labels together with the data so that the matching
    # labels of the held-out rows are available for the second split.
    arrays = [data, target_data]
    if stratify is not None:
        arrays.append(stratify)

    first_split = train_test_split(
        *arrays,
        test_size=dev_size + test_size,
        random_state=random_state,
        stratify=stratify,
    )
    X_train, X_dev_test, y_train, y_dev_test = first_split[:4]

    if (dev_size == 0) or (test_size == 0):
        return X_train, X_dev_test, y_train, y_dev_test

    # Only stratify the dev/test split when stratification was requested;
    # forcing it on e.g. a continuous target makes the split fail.
    stratify_dev_test = first_split[5] if stratify is not None else None

    X_dev, X_test, y_dev, y_test = train_test_split(
        X_dev_test,
        y_dev_test,
        # Rescale: the held-out part is only (dev_size + test_size) of the data.
        test_size=test_size / (dev_size + test_size),
        random_state=random_state,
        stratify=stratify_dev_test,
    )
    return X_train, X_dev, X_test, y_train, y_dev, y_test