def test_update_dictionary():
    """Tests for update_dictionary."""
    # No conflicts: the result is the union of both dictionaries.
    default_dict = {"a": 1, "d": 4}
    overwrite_dict = {"b": 2, "c": 3}
    merged_dict = update_dictionary(default_dict, overwrite_dict=overwrite_dict)
    assert merged_dict == {"a": 1, "b": 2, "c": 3, "d": 4}

    # Values from ``overwrite_dict`` take precedence on conflicting keys.
    default_dict = {"a": 1, "d": 4}
    overwrite_dict = {"a": 2, "c": 3}
    merged_dict = update_dictionary(default_dict, overwrite_dict=overwrite_dict)
    assert merged_dict == {"a": 2, "c": 3, "d": 4}

    # ``overwrite_dict`` can be None or {}; either way the defaults are returned.
    default_dict = {"a": 1, "b": 4}
    merged_dict = update_dictionary(default_dict, overwrite_dict=None)
    assert merged_dict == default_dict
    merged_dict = update_dictionary(default_dict, overwrite_dict={})
    assert merged_dict == default_dict

    # With ``allow_unknown_keys=False``, the keys of ``overwrite_dict`` must be
    # a subset of the keys of ``default_dict``; such a subset is accepted and
    # merged as usual.
    default_dict = {"a": 1, "b": 4, "c": 5}
    overwrite_dict = {"a": 1, "b": 4}
    merged_dict = update_dictionary(
        default_dict,
        overwrite_dict=overwrite_dict,
        allow_unknown_keys=False)
    assert merged_dict == {"a": 1, "b": 4, "c": 5}
    overwrite_dict = {"a": 2}
    merged_dict = update_dictionary(
        default_dict,
        overwrite_dict=overwrite_dict,
        allow_unknown_keys=False)
    assert merged_dict == {"a": 2, "b": 4, "c": 5}

    # Any key that is not in ``default_dict`` is rejected, whether the
    # overwrite only partially overlaps the defaults ...
    overwrite_dict = {"a": 1, "b": 4, "d": 1}
    with pytest.raises(ValueError, match=r"Unexpected key\(s\) found"):
        update_dictionary(
            default_dict,
            overwrite_dict=overwrite_dict,
            allow_unknown_keys=False)
    # ... or is a strict superset of them.
    overwrite_dict = {"a": 1, "b": 4, "c": 5, "d": 1}
    with pytest.raises(ValueError, match=r"Unexpected key\(s\) found"):
        update_dictionary(
            default_dict,
            overwrite_dict=overwrite_dict,
            allow_unknown_keys=False)
def fit(self, X, y=None):
    """Fills in default imputation parameters and marks the transformer as fitted.

    Updates ``self.impute_params``.

    Parameters
    ----------
    X : `pandas.DataFrame`
        Training input data, e.g. each column is a timeseries.
        Columns are expected to be numeric.
    y : None
        Unused. A transformer needs no target, but the pipeline API
        requires this parameter.

    Returns
    -------
    self : object
        Returns self.
    """
    assert isinstance(X, pd.DataFrame)
    self._is_fitted = True

    if self.impute_algorithm is None:
        # No algorithm selected, so there are no algorithm-specific defaults to merge.
        return self

    # Parameters already present in ``self.impute_params`` take precedence over
    # the defaults registered for this algorithm in ``DEFAULT_PARAMS`` (if any).
    algorithm_defaults = DEFAULT_PARAMS.get(self.impute_algorithm, {})
    self.impute_params = update_dictionary(
        algorithm_defaults,
        overwrite_dict=self.impute_params)
    return self
def fit(
        self,
        X,
        y=None,
        time_col=cst.TIME_COL,
        value_col=cst.VALUE_COL,
        **fit_params):
    """Fits ``Silverkite`` forecast model.

    Parameters
    ----------
    X: `pandas.DataFrame`
        Input timeseries, with timestamp column,
        value column, and any additional regressors.
        The value column is the response, included in
        ``X`` to allow transformation by `sklearn.pipeline`.
    y: ignored
        The original timeseries values, ignored.
        (The ``y`` for fitting is included in ``X``).
    time_col: `str`
        Time column name in ``X``.
    value_col: `str`
        Value column name in ``X``.
    fit_params: `dict`
        Additional parameters for the null model.

    Returns
    -------
    self : self
        Fitted model is stored in ``self.model_dict``.
    """
    # ``fit_algorithm_dict`` receives its defaults here rather than in
    # ``__init__`` so that the estimator stays compatible with sklearn grid search.
    self.fit_algorithm_dict = update_dictionary(
        {
            "fit_algorithm": "ridge",
            "fit_algorithm_params": None,
        },
        overwrite_dict=self.fit_algorithm_dict)

    # The null model is fitted first.
    super().fit(
        X=X,
        y=y,
        time_col=time_col,
        value_col=value_col,
        **fit_params)

    algorithm_settings = self.fit_algorithm_dict
    self.model_dict = self.silverkite.forecast_simple(
        # data and time handling
        df=X,
        time_col=time_col,
        value_col=value_col,
        time_properties=self.time_properties,
        freq=self.freq,
        forecast_horizon=self.forecast_horizon,
        origin_for_time_vars=self.origin_for_time_vars,
        train_test_thresh=self.train_test_thresh,
        training_fraction=self.training_fraction,
        # regression algorithm
        fit_algorithm=algorithm_settings["fit_algorithm"],
        fit_algorithm_params=algorithm_settings["fit_algorithm_params"],
        # holidays and events
        holidays_to_model_separately=self.holidays_to_model_separately,
        holiday_lookup_countries=self.holiday_lookup_countries,
        holiday_pre_num_days=self.holiday_pre_num_days,
        holiday_post_num_days=self.holiday_post_num_days,
        holiday_pre_post_num_dict=self.holiday_pre_post_num_dict,
        daily_event_df_dict=self.daily_event_df_dict,
        # trend changepoints
        changepoints_dict=self.changepoints_dict,
        # seasonality
        yearly_seasonality=self.yearly_seasonality,
        quarterly_seasonality=self.quarterly_seasonality,
        monthly_seasonality=self.monthly_seasonality,
        weekly_seasonality=self.weekly_seasonality,
        daily_seasonality=self.daily_seasonality,
        max_daily_seas_interaction_order=self.max_daily_seas_interaction_order,
        max_weekly_seas_interaction_order=self.max_weekly_seas_interaction_order,
        # remaining model options
        autoreg_dict=self.autoreg_dict,
        seasonality_changepoints_dict=self.seasonality_changepoints_dict,
        min_admissible_value=self.min_admissible_value,
        max_admissible_value=self.max_admissible_value,
        uncertainty_dict=self.uncertainty_dict,
        growth_term=self.growth_term,
        regressor_cols=self.regressor_cols,
        feature_sets_enabled=self.feature_sets_enabled,
        extra_pred_cols=self.extra_pred_cols,
        regression_weight_col=self.regression_weight_col,
        simulation_based=self.simulation_based)

    # Sets attributes based on ``self.model_dict``.
    super().finish_fit()
    return self
def apply_prophet_model_components_defaults(self, model_components=None, time_properties=None):
    """Sets default values for ``model_components``.

    Called by ``get_hyperparameter_grid`` after ``time_properties`` is defined.
    Requires ``time_properties`` as well as ``model_components``
    so we do not simply override
    `~greykite.framework.templates.forecast_config_defaults.ForecastConfigDefaults.apply_model_components_defaults`.

    Parameters
    ----------
    model_components : :class:`~greykite.framework.templates.autogen.forecast_config.ModelComponentsParam` or None, default None
        Configuration of model growth, seasonality, events, etc.
        See the docstring of this class for details.
    time_properties : `dict` [`str`, `any`] or None, default None
        Time properties dictionary (likely produced by
        `~greykite.common.time_properties_forecast.get_forecast_time_properties`)
        with keys:

            ``"period"`` : `int`
                Period of each observation (i.e. minimum time between observations, in seconds).
            ``"simple_freq"`` : `SimpleTimeFrequencyEnum`
                ``SimpleTimeFrequencyEnum`` member corresponding to data frequency.
            ``"num_training_points"`` : `int`
                Number of observations for training.
            ``"num_training_days"`` : `int`
                Number of days for training.
            ``"start_year"`` : `int`
                Start year of the training period.
            ``"end_year"`` : `int`
                End year of the forecast period.
            ``"origin_for_time_vars"`` : `float`
                Continuous time representation of the first date in ``df``.

        Only ``"start_year"`` and ``"end_year"`` are read here.
        If None, start_year is set to 2015 and end_year to 2030.

    Returns
    -------
    model_components : :class:`~greykite.framework.templates.autogen.forecast_config.ModelComponentsParam`
        A (shallow) copy of the provided ``model_components`` with default values set.
        ``events`` is replaced by a dictionary with keys ``"holidays_df"``
        (built by ``self.get_prophet_holidays``) and ``"holidays_prior_scale"``.

    Raises
    ------
    ValueError
        Raised by ``update_dictionary`` when a component contains keys
        that are not among that component's defaults.
    """
    if model_components is None:
        model_components = ModelComponentsParam()
    else:
        # makes a copy to avoid mutating input
        model_components = dataclasses.replace(model_components)

    if time_properties is None:
        time_properties = {
            "start_year": 2015,
            "end_year": 2030,
        }

    # seasonality
    default_seasonality = {
        "seasonality_mode": ["additive"],
        "seasonality_prior_scale": [10.0],
        "yearly_seasonality": ['auto'],
        "weekly_seasonality": ['auto'],
        "daily_seasonality": ['auto'],
        "add_seasonality_dict": [None]
    }
    # If seasonality params are not provided, uses default params. Otherwise, prefers provided params.
    # `allow_unknown_keys=False` requires `model_components.seasonality` keys to be a subset of
    # `default_seasonality` keys.
    model_components.seasonality = update_dictionary(
        default_dict=default_seasonality,
        overwrite_dict=model_components.seasonality,
        allow_unknown_keys=False)

    # growth
    default_growth = {"growth_term": ["linear"]}
    model_components.growth = update_dictionary(
        default_dict=default_growth,
        overwrite_dict=model_components.growth,
        allow_unknown_keys=False)

    # events
    default_events = {
        "holiday_lookup_countries": "auto",  # see `get_prophet_holidays` for defaults
        "holiday_pre_num_days": [2],
        "holiday_post_num_days": [2],
        "start_year": time_properties["start_year"],
        "end_year": time_properties["end_year"],
        "holidays_prior_scale": [10.0]
    }
    events = update_dictionary(
        default_dict=default_events,
        overwrite_dict=model_components.events,
        allow_unknown_keys=False)

    # Creates events dictionary for prophet estimator.
    # Expands the range of holiday years by 1 year on each end, to ensure we have coverage of most relevant holidays.
    year_list = list(range(events["start_year"] - 1, events["end_year"] + 2))

    # Currently we support only one set of holiday_lookup_countries, holiday_pre_num_days and holiday_post_num_days.
    # Shows a warning if user supplies >1 set.
    if len(events["holiday_pre_num_days"]) > 1:
        warnings.warn(
            f"`events['holiday_pre_num_days']` list has more than 1 element. We currently support only 1 element. "
            f"Using {events['holiday_pre_num_days'][0]}.")
    if len(events["holiday_post_num_days"]) > 1:
        warnings.warn(
            f"`events['holiday_post_num_days']` list has more than 1 element. We currently support only 1 element. "
            f"Using {events['holiday_post_num_days'][0]}.")

    # Resolves `holiday_lookup_countries` to a single option. A local variable is
    # used so the merged events dict is not mutated.
    countries = events["holiday_lookup_countries"]
    if countries is not None and countries != "auto":
        if len(countries) > 1:
            # There are multiple elements
            if (any(isinstance(x, list) for x in countries)
                    or None in countries
                    or "auto" in countries):
                # Not a flat list of country names: treated as a list of options.
                warnings.warn(
                    f"`events['holiday_lookup_countries']` contains multiple options. "
                    f"We currently support only 1 option. Using {countries[0]}.")
                countries = countries[0]
        elif len(countries) == 1 and (
                isinstance(countries[0], (list, tuple))
                or countries[0] is None
                or countries[0] == "auto"):
            # There's only one element and it is an option (a list of countries,
            # None or "auto") rather than a country name; unwrap it, consistent
            # with the multi-option branch above. An empty list is passed
            # through as-is instead of failing on `[0]`.
            countries = countries[0]

    model_components.events = {
        "holidays_df": self.get_prophet_holidays(
            year_list=year_list,
            countries=countries,
            # holiday effect is modeled from "holiday_pre_num_days" days before
            # to "holiday_post_num_days" days after the holiday
            lower_window=-events["holiday_pre_num_days"][0],  # Prophet expects a negative value for `lower_window`
            upper_window=events["holiday_post_num_days"][0]),
        "holidays_prior_scale": events["holidays_prior_scale"]
    }

    # changepoints_dict
    default_changepoints = {
        "changepoint_prior_scale": [0.05],
        "changepoints": [None],
        "n_changepoints": [25],
        "changepoint_range": [0.8]
    }
    model_components.changepoints = update_dictionary(
        default_dict=default_changepoints,
        overwrite_dict=model_components.changepoints,
        allow_unknown_keys=False)

    # uncertainty
    default_uncertainty = {
        "mcmc_samples": [0],
        "uncertainty_samples": [1000]
    }
    model_components.uncertainty = update_dictionary(
        default_dict=default_uncertainty,
        overwrite_dict=model_components.uncertainty,
        allow_unknown_keys=False)

    # regressors
    default_regressors = {"add_regressor_dict": [None]}
    model_components.regressors = update_dictionary(
        default_dict=default_regressors,
        overwrite_dict=model_components.regressors,
        allow_unknown_keys=False)

    # there are no custom parameters for Prophet

    # sets to {} if None, for each item if
    # `model_components.hyperparameter_override` is a list of dictionaries
    model_components.hyperparameter_override = update_dictionaries(
        {},
        overwrite_dicts=model_components.hyperparameter_override)

    return model_components
def apply_default_model_components(model_components=None, time_properties=None):
    """Returns a copy of ``model_components`` with Silverkite default values filled in.

    Parameters
    ----------
    model_components : :class:`~greykite.framework.templates.autogen.forecast_config.ModelComponentsParam` or None, default None
        Configuration of model growth, seasonality, events, etc.
        See :func:`~greykite.framework.templates.silverkite_templates.silverkite_template` for details.
    time_properties : `dict` [`str`, `any`] or None, default None
        Time properties dictionary (likely produced by
        `~greykite.common.time_properties_forecast.get_forecast_time_properties`)
        with keys:

            ``"period"`` : `int`
                Period of each observation (i.e. minimum time between observations, in seconds).
            ``"simple_freq"`` : `SimpleTimeFrequencyEnum`
                ``SimpleTimeFrequencyEnum`` member corresponding to data frequency.
            ``"num_training_points"`` : `int`
                Number of observations for training.
            ``"num_training_days"`` : `int`
                Number of days for training.
            ``"start_year"`` : `int`
                Start year of the training period.
            ``"end_year"`` : `int`
                End year of the forecast period.
            ``"origin_for_time_vars"`` : `float`
                Continuous time representation of the first date in ``df``.

        Only ``"origin_for_time_vars"`` is read here.

    Returns
    -------
    model_components : :class:`~greykite.framework.templates.autogen.forecast_config.ModelComponentsParam`
        The provided ``model_components`` with default values set.
    """
    # Work on a shallow copy so the caller's object is left untouched.
    model_components = (
        ModelComponentsParam() if model_components is None
        else dataclasses.replace(model_components))

    # (attribute name, defaults) pairs, merged in this order. With
    # ``allow_unknown_keys=False`` any user key absent from the defaults is rejected.
    component_defaults = (
        ("seasonality", {
            "fs_components_df": [
                pd.DataFrame({
                    "name": ["tod", "tow", "tom", "toq", "toy"],
                    "period": [24.0, 7.0, 1.0, 1.0, 1.0],
                    "order": [3, 3, 1, 1, 5],
                    "seas_names": ["daily", "weekly", "monthly", "quarterly", "yearly"]
                })
            ],
        }),
        # ``growth`` must stay empty; growth terms are passed via ``extra_pred_cols``.
        ("growth", {}),
        ("events", {
            "daily_event_df_dict": [None],
        }),
        # "changepoint_detector" is deliberately not an allowed key, to prevent
        # leaking future information into the past. Pass ``changepoints_dict``
        # with method="auto" for automatic detection instead.
        ("changepoints", {
            "changepoints_dict": [None],
            "seasonality_changepoints_dict": [None],
        }),
        ("autoregression", {
            "autoreg_dict": [None],
        }),
        ("regressors", {}),
        ("lagged_regressors", {
            "lagged_regressor_dict": [None],
        }),
        ("uncertainty", {
            "uncertainty_dict": [None],
        }),
    )
    for attribute, defaults in component_defaults:
        merged = update_dictionary(
            defaults,
            overwrite_dict=getattr(model_components, attribute),
            allow_unknown_keys=False)
        setattr(model_components, attribute, merged)

    origin_for_time_vars = (
        None if time_properties is None
        else time_properties.get("origin_for_time_vars"))

    model_components.custom = update_dictionary(
        {
            "silverkite": [SilverkiteForecast()],  # NB: sklearn creates a copy in grid search
            "silverkite_diagnostics": [SilverkiteDiagnostics()],
            # Same origin for every split, based on the start year of the full dataset.
            # To use the first date of each training split, set to `None` in model_components.
            "origin_for_time_vars": [origin_for_time_vars],
            "extra_pred_cols": ["ct1"],  # linear growth
            "fit_algorithm_dict": [{
                "fit_algorithm": "linear",
                "fit_algorithm_params": None,
            }],
            "min_admissible_value": [None],
            "max_admissible_value": [None],
        },
        overwrite_dict=model_components.custom,
        allow_unknown_keys=False)

    # A None ``hyperparameter_override`` (or None items in a list of them) becomes {}.
    model_components.hyperparameter_override = update_dictionaries(
        {},
        overwrite_dicts=model_components.hyperparameter_override)

    return model_components
def plot_multivariate(
        df,
        x_col,
        y_col_style_dict="plotly",
        default_color="rgba(0, 145, 202, 1.0)",
        xlabel=None,
        ylabel=cst.VALUE_COL,
        title=None,
        showlegend=True):
    """Plots one or more lines against the same x-axis values.

    Parameters
    ----------
    df : `pandas.DataFrame`
        Data frame with ``x_col`` and columns named by the keys in ``y_col_style_dict``.
    x_col: `str`
        Which column to plot on the x-axis.
    y_col_style_dict: `dict` [`str`, `dict` or None] or "plotly" or "auto" or "auto-fill", default "plotly"
        The column(s) to plot on the y-axis, and how to style them.

        If a dictionary:

            - key : `str`
                column name in ``df``
            - value : `dict` or None
                Optional styling options, passed as kwargs to `go.Scatter`.
                If None, uses the default: line labeled by the column name.
                See reference page for `plotly.graph_objs.Scatter` for options
                (e.g. color, mode, width/size, opacity).
                https://plotly.com/python/reference/#scatter.

        If a string, plots all columns in ``df`` besides ``x_col`` against ``x_col``:

            - "plotly": plot lines with default plotly styling
            - "auto": plot lines with color ``default_color``, sorted by value (ascending)
            - "auto-fill": plot lines with color ``default_color``, sorted by value (ascending), and fills between lines

        If ``df`` has no columns besides ``x_col``, every string option yields a figure without lines.
    default_color: `str`, default "rgba(0, 145, 202, 1.0)" (blue)
        Default line color when ``y_col_style_dict`` is one of "auto", "auto-fill".
    xlabel : `str` or None, default None
        x-axis label. If None, default is ``x_col``.
    ylabel : `str` or None, default ``VALUE_COL``
        y-axis label
    title : `str` or None, default None
        Plot title. If None, default is based on axis labels.
    showlegend : `bool`, default True
        Whether to show the legend.

    Returns
    -------
    fig : `plotly.graph_objs.Figure`
        Interactive plotly graph of one or more columns
        in ``df`` against ``x_col``.

        See `~greykite.common.viz.timeseries_plotting.plot_forecast_vs_actual`
        return value for how to plot the figure and add customization.
    """
    if xlabel is None:
        xlabel = x_col
    if title is None and ylabel is not None:
        title = f"{ylabel} vs {xlabel}"

    auto_style = {"line": {"color": default_color}}
    if y_col_style_dict == "plotly":
        # Uses plotly default style
        y_col_style_dict = {col: None for col in df.columns if col != x_col}
    elif y_col_style_dict in ["auto", "auto-fill"]:
        # Columns ordered from low to high mean value
        means = df.drop(columns=x_col).mean()
        column_order = list(means.sort_values().index)
        if y_col_style_dict == "auto":
            # Lines with color `default_color`
            y_col_style_dict = {col: auto_style for col in column_order}
        else:
            # Lines with color `default_color`, each filled down to the previous
            # line ("tonexty"); the lowest line has no fill. A comprehension keeps
            # this safe when `column_order` is empty.
            y_col_style_dict = {
                col: auto_style if position == 0 else {
                    "line": {
                        "color": default_color
                    },
                    "fill": "tonexty"
                }
                for position, col in enumerate(column_order)
            }

    data = []
    default_style = dict(mode="lines")
    for column, style_dict in y_col_style_dict.items():
        # By default, column name in ``df`` is used to label the line
        default_col_style = update_dictionary(default_style, overwrite_dict={"name": column})
        # User can overwrite any of the default values, or remove them by setting key value to None
        style_dict = update_dictionary(default_col_style, overwrite_dict=style_dict)
        line = go.Scatter(
            x=df[x_col],
            y=df[column],
            **style_dict)
        data.append(line)

    layout = go.Layout(
        xaxis=dict(title=xlabel),
        yaxis=dict(title=ylabel),
        title=title,
        showlegend=showlegend,
        legend={'traceorder': 'reversed'}  # Matches the order of ``y_col_style_dict`` (bottom to top)
    )
    fig = go.Figure(data=data, layout=layout)
    return fig
def plot_multivariate_grouped(
        df,
        x_col,
        y_col_style_dict,
        grouping_x_col,
        grouping_x_col_values,
        grouping_y_col_style_dict,
        colors=DEFAULT_PLOTLY_COLORS,
        xlabel=None,
        ylabel=cst.VALUE_COL,
        title=None,
        showlegend=True):
    """Plots multiple lines against the same x-axis values. The lines can
    partially share the x-axis values.

    See parameter descriptions for a running example.

    Parameters
    ----------
    df : `pandas.DataFrame`
        Data frame with ``x_col`` and columns named by the keys in ``y_col_style_dict``,
        ``grouping_x_col``, ``grouping_y_col_style_dict``.

        For example::

            df = pd.DataFrame({
                "time": [dt(2018, 1, 1),
                         dt(2018, 1, 2),
                         dt(2018, 1, 3)],
                "y1": [8.5, 2.0, 3.0],
                "y2": [1.4, 2.1, 3.4],
                "y3": [4.2, 3.1, 3.0],
                "y4": [0, 1, 2],
                "y5": [10, 9, 8],
                "group": [1, 2, 1],
            })

        This will be our running example.
    x_col: `str`
        Which column to plot on the x-axis. "time" in our example.
    y_col_style_dict: `dict` [`str`, `dict` or None]
        The column(s) to plot on the y-axis, and how to style them.
        These columns are plotted against the complete x-axis.
        This dictionary (and the style dictionaries inside it) is not modified.

            - key : `str`
                column name in ``df``
            - value : `dict` or None
                Optional styling options, passed as kwargs to `go.Scatter`.
                If None, uses the default: line labeled by the column name.
                If line color is not given, it is added according to ``colors``.
                See reference page for `plotly.graph_objs.Scatter` for options
                (e.g. color, mode, width/size, opacity).
                https://plotly.com/python/reference/#scatter.

        For example::

            y_col_style_dict={
                "y1": {
                    "name": "y1_name",
                    "legendgroup": "one",
                    "mode": "markers",
                    "line": None  # Remove line params since we use mode="markers"
                },
                "y2": None,
            }

        The function will add a line color to "y1" and "y2" based on the ``colors`` parameter.
        It will also add a name to "y2", since none was given. The "name" of "y1" will be preserved.
        The output ``fig`` will have one line each for each of "y1" and "y2", each plot against the
        entire "time" column.
    grouping_x_col: `str`
        Which column to use to group columns in ``grouping_y_col_style_dict``.
        "group" in our example.
    grouping_x_col_values: `list` [`int`] or None
        Which values to use for grouping. If None, uses all the unique values in
        ``df`` [``grouping_x_col``].
        In our example, specifying ``grouping_x_col_values == [1, 2]`` would plot
        separate lines corresponding to ``group==1`` and ``group==2``.
    grouping_y_col_style_dict: `dict` [`str`, `dict` or None]
        The column(s) to plot on the y-axis, and how to style them.
        These columns are plotted against partial x-axis. For each ``grouping_x_col_values``
        an element in this dictionary produces one line.

            - key : `str`
                column name in ``df``
            - value : `dict` or None
                Optional styling options, passed as kwargs to `go.Scatter`.
                If None, uses the default: line labeled by the ``grouping_x_col_values``,
                ``grouping_x_col`` and column name.
                If a name is given, it is augmented with the ``grouping_x_col_values``.
                If line color is not given, it is added according to ``colors``.
                All the lines sharing same ``grouping_x_col_values`` have the same color.
                See reference page for `plotly.graph_objs.Scatter` for options
                (e.g. color, mode, width/size, opacity).
                https://plotly.com/python/reference/#scatter.

        For example::

            grouping_y_col_style_dict={
                "y3": {
                    "line": {
                        "color": "blue"
                    }
                },
                "y4": {
                    "name": "y4_name",
                    "line": {
                        "width": 2,
                        "dash": "dot"
                    }
                },
                "y5": None,
            }

        The function will add a line color to "y4" and "y5" based on the ``colors`` parameter.
        The line color of "y3" will be "blue" as specified. We also preserve the given line
        properties of "y4".

        The function adds a name to "y3" and "y5", since none was given. The given "name" of
        "y4" will be augmented with ``grouping_x_col_values``.

        Each element of ``grouping_y_col_style_dict`` gets one line for each
        ``grouping_x_col_values``. In our example, there will be 2 lines corresponding to
        "y3", named "1_group_y3" and "2_group_y3". "1_group_y3" is plotted against
        "time = [dt(2018, 1, 1), dt(2018, 1, 3)]", corresponding to ``group==1``.
        "2_group_y3" is plotted against "time = [dt(2018, 1, 2)]", corresponding to
        ``group==2``. Similarly the "y4" lines are named "1_y4_name" and "2_y4_name".
    colors: [`str`, `list` [`str`]], default ``DEFAULT_PLOTLY_COLORS``
        Which colors to use to build a color palette for plotting.
        This can be a list of RGB colors or a `str` from ``PLOTLY_SCALES``.
        Required number of colors equals sum of the length of ``y_col_style_dict``
        and length of ``grouping_x_col_values``.
        See `~greykite.common.viz.colors_utils.get_color_palette` for details.
    xlabel : `str` or None, default None
        x-axis label. If None, default is ``x_col``.
    ylabel : `str` or None, default ``VALUE_COL``
        y-axis label
    title : `str` or None, default None
        Plot title. If None, default is based on axis labels.
    showlegend : `bool`, default True
        Whether to show the legend.

    Returns
    -------
    fig : `plotly.graph_objs.Figure`
        Interactive plotly graph of one or more columns
        in ``df`` against ``x_col``.

        See `~greykite.common.viz.timeseries_plotting.plot_forecast_vs_actual`
        return value for how to plot the figure and add customization.

    Raises
    ------
    ValueError
        If some ``grouping_x_col_values`` do not appear in ``df[grouping_x_col]``.
    """
    available_grouping_x_col_values = np.unique(df[grouping_x_col])
    if grouping_x_col_values is None:
        grouping_x_col_values = available_grouping_x_col_values
    else:
        missing_grouping_x_col_values = (
            set(grouping_x_col_values) - set(available_grouping_x_col_values))
        if len(missing_grouping_x_col_values) > 0:
            raise ValueError(
                f"Following 'grouping_x_col_values' are missing in '{grouping_x_col}' column: "
                f"{missing_grouping_x_col_values}")

    # Chooses the color palette: one color per full-axis column, then one per group.
    n_color = len(y_col_style_dict) + len(grouping_x_col_values)
    color_palette = get_color_palette(num=n_color, colors=colors)

    # Adds a line color to each full-axis column if it is not specified.
    # Builds new dictionaries so the caller's ``y_col_style_dict`` and the
    # style dictionaries inside it are not mutated.
    full_y_col_style_dict = {}
    for color_num, (column, style_dict) in enumerate(y_col_style_dict.items()):
        style_dict = {} if style_dict is None else dict(style_dict)
        default_color = {"color": color_palette[color_num]}
        style_dict["line"] = update_dictionary(
            default_color,
            overwrite_dict=style_dict.get("line"))
        full_y_col_style_dict[column] = style_dict

    # Standardizes dataset for the full-axis figure (one row per x value, sorted).
    df_standardized = df.drop_duplicates(subset=[x_col]).sort_values(by=x_col)
    # This figure plots the whole xaxis vs yaxis values
    fig = plot_multivariate(
        df=df_standardized,
        x_col=x_col,
        y_col_style_dict=full_y_col_style_dict,
        xlabel=xlabel,
        ylabel=ylabel,
        title=title,
        showlegend=showlegend)
    data = fig.data
    layout = fig.layout

    # These figures plot the sliced xaxis vs yaxis values; their traces are
    # appended to ``data`` while the layout of the full-axis figure is kept.
    for color_num, grouping_x_col_value in enumerate(grouping_x_col_values, len(y_col_style_dict)):
        default_color = {"color": color_palette[color_num]}
        sliced_y_col_style_dict = {}
        for column, style_dict in grouping_y_col_style_dict.items():
            if style_dict is None:
                style_dict = {}
            # Updates colors if it is not specified
            line_dict = update_dictionary(
                default_color,
                overwrite_dict=style_dict.get("line"))
            # Augments names with grouping_x_col_value
            name = style_dict.get("name")
            if name is None:
                updated_name = f"{grouping_x_col_value}_{grouping_x_col}_{column}"
            else:
                updated_name = f"{grouping_x_col_value}_{name}"
            sliced_y_col_style_dict[column] = update_dictionary(
                style_dict,
                overwrite_dict={"name": updated_name, "line": line_dict})

        df_sliced = df[df[grouping_x_col] == grouping_x_col_value]
        sliced_fig = plot_multivariate(
            df=df_sliced,
            x_col=x_col,
            y_col_style_dict=sliced_y_col_style_dict)
        data = data + sliced_fig.data

    fig = go.Figure(data=data, layout=layout)
    return fig
def fit(
        self,
        X,
        y=None,
        time_col=cst.TIME_COL,
        value_col=cst.VALUE_COL,
        **fit_params):
    """Fits ``Silverkite`` forecast model.

    Parameters
    ----------
    X: `pandas.DataFrame`
        Input timeseries, with timestamp column,
        value column, and any additional regressors.
        The value column is the response, included in
        ``X`` to allow transformation by `sklearn.pipeline`.
    y: ignored
        The original timeseries values, ignored.
        (The ``y`` for fitting is included in ``X``).
    time_col: `str`
        Time column name in ``X``.
    value_col: `str`
        Value column name in ``X``.
    fit_params: `dict`
        Additional parameters for the null model.

    Returns
    -------
    self : self
        Fitted model is stored in ``self.model_dict``.
    """
    # ``fit_algorithm_dict`` receives its defaults here rather than in
    # ``__init__`` so that the estimator stays compatible with sklearn grid search.
    self.fit_algorithm_dict = update_dictionary(
        {
            "fit_algorithm": "linear",
            "fit_algorithm_params": None,
        },
        overwrite_dict=self.fit_algorithm_dict)

    # The null model is fitted first.
    super().fit(
        X=X,
        y=y,
        time_col=time_col,
        value_col=value_col,
        **fit_params)

    algorithm_settings = self.fit_algorithm_dict
    self.model_dict = self.silverkite.forecast(
        df=X,
        time_col=time_col,
        value_col=value_col,
        origin_for_time_vars=self.origin_for_time_vars,
        extra_pred_cols=self.extra_pred_cols,
        train_test_thresh=self.train_test_thresh,
        training_fraction=self.training_fraction,
        fit_algorithm=algorithm_settings["fit_algorithm"],
        fit_algorithm_params=algorithm_settings["fit_algorithm_params"],
        daily_event_df_dict=self.daily_event_df_dict,
        fs_components_df=self.fs_components_df,
        autoreg_dict=self.autoreg_dict,
        lagged_regressor_dict=self.lagged_regressor_dict,
        changepoints_dict=self.changepoints_dict,
        seasonality_changepoints_dict=self.seasonality_changepoints_dict,
        changepoint_detector=self.changepoint_detector,
        min_admissible_value=self.min_admissible_value,
        max_admissible_value=self.max_admissible_value,
        uncertainty_dict=self.uncertainty_dict,
        normalize_method=self.normalize_method,
        adjust_anomalous_dict=self.adjust_anomalous_dict,
        impute_dict=self.impute_dict,
        regression_weight_col=self.regression_weight_col,
        forecast_horizon=self.forecast_horizon,
        simulation_based=self.simulation_based)

    # Sets attributes based on ``self.model_dict``.
    super().finish_fit()
    return self