def get_intermediate_plot(study: Study, topx=10, num_trials_threshold=30, color_scale=None) -> "go.Figure":
    """
    Get Intermediate Plot
    =====================
    Plot the intermediate values of the lowest-valued trials of a `Study`.

    One line-and-marker trace is drawn per selected trial, with the reported
    step on the x axis and the intermediate value on the y axis.

    Parameters
    ----------
    study : optuna.study.Study
        Study whose trials are plotted.
    topx : int, default 10
        Zero-based rank, in ascending order of trial value, of the trial whose
        value is used as the cutoff. Only trials with a value strictly below
        that cutoff are kept. If the study has `topx` or fewer valued trials,
        no cutoff is applied.
    num_trials_threshold : int, default 30
        Trials that reported this many intermediate values or more are skipped.
    color_scale : sequence, optional
        Marker colors, indexed by the position of the trial among the selected
        trials. Colors are reused cyclically when there are more trials than
        colors. When omitted or empty, no color is set explicitly.

    Returns
    -------
    go.Figure
        Figure with one trace per selected trial, or an empty figure (with the
        layout only) when nothing can be plotted.
    """
    layout = go.Layout(
        title="Intermediate Values Plot",
        xaxis={"title": "Step"},
        yaxis={"title": "Intermediate Value"},
    )

    # Cutoff value deciding which trials are plotted. It stays None (no cutoff)
    # when there are not enough valued trials to define it; missing values are
    # dropped so that the cutoff can never be NaN.
    threshold = None
    df = study.trials_dataframe()
    if "value" in df.columns:
        ranked_values = df["value"].dropna().sort_values(ascending=True)
        if len(ranked_values) > topx:
            threshold = ranked_values.iloc[topx]

    target_state = [TrialState.PRUNED, TrialState.COMPLETE, TrialState.RUNNING]
    trials = [
        trial
        for trial in study.trials
        if trial.state in target_state
        # Compare against None explicitly: a value of 0.0 is a valid result.
        and trial.value is not None
        and (threshold is None or trial.value < threshold)
        and len(trial.intermediate_values) < num_trials_threshold
    ]

    if not trials:
        _logger.warning("Study instance does not contain trials.")
        return go.Figure(data=[], layout=layout)

    has_colors = color_scale is not None and len(color_scale) > 0

    traces = []
    for i, trial in enumerate(trials):
        if not trial.intermediate_values:
            continue

        sorted_intermediate_values = sorted(trial.intermediate_values.items())

        # NOTE(review): the symbol numbering (i, then 2*i from the fifth trial
        # on) is kept from the original; confirm the resulting codes are valid
        # Plotly symbols when many trials are plotted.
        marker = {
            "maxdisplayed": 10,
            "symbol": i if i < 4 else i * 2,
            "size": 10,
        }
        if has_colors:
            marker["color"] = color_scale[i % len(color_scale)]

        traces.append(
            go.Scatter(
                x=tuple(step for step, _ in sorted_intermediate_values),
                y=tuple(value for _, value in sorted_intermediate_values),
                mode="lines+markers",
                marker=marker,
                name="Trial{}".format(trial.number),
            )
        )

    if not traces:
        _logger.warning(
            "You need to set up the pruning feature to utilize `plot_intermediate_values()`"
        )
        return go.Figure(data=[], layout=layout)

    return go.Figure(data=traces, layout=layout)
def create_log_file(study: Study, log_file: str):
    """
    Create Log File
    ===============
    Write the trials of a `Study` to a text file as a rendered pandas
    `DataFrame`.

    The table holds the `number`, `params`, `value` and `state` attributes of
    the trials. Any existing content of the file is overwritten.

    Parameters
    ----------
    study : optuna.study.Study
        Study whose trials are written out.
    log_file : str
        Path of the file receiving the study results.

    Returns
    -------
    None
    """
    columns = ('number', 'params', 'value', 'state')
    table = study.trials_dataframe(attrs=columns).to_string()

    with open(log_file, 'w') as handle:
        # Same output as print(): the rendered table plus a trailing newline.
        handle.write(table + "\n")