def main():
    """Run the experiment and write the agents, snapshot, and colony plots.

    ``run_experiment`` is called with two start locations. Every plot helper is
    given OUT_DIR as its output location, and that directory is created first
    if it is missing.
    """
    # exist_ok=True avoids the race between an exists() check and makedirs()
    os.makedirs(OUT_DIR, exist_ok=True)

    data, experiment_config = run_experiment(
        start_locations=[[0.3, 0.3], [0.5, 0.5]],
    )

    # extract data
    multibody_config = experiment_config['environment']['multibody']
    agents = {
        timepoint: time_data['agents']
        for timepoint, time_data in data.items()}
    fields = {
        timepoint: time_data['fields']
        for timepoint, time_data in data.items()}

    # agents plot
    agents_settings = {'agents_key': 'agents'}
    plot_agents_multigen(data, agents_settings, OUT_DIR, 'agents')

    # snapshot plot
    snapshot_data = {
        'agents': agents,
        'fields': fields,
        'config': multibody_config,
    }
    snapshot_config = {
        'out_dir': OUT_DIR,
        'filename': 'agents_snapshots',
    }
    plot_snapshots(snapshot_data, snapshot_config)

    # colony metrics plot: flatten the 'colony_global' store (with the shared
    # time axis attached) into a path timeseries before plotting
    embedded_ts = timeseries_from_data(data)
    colony_metrics_ts = embedded_ts['colony_global']
    colony_metrics_ts['time'] = embedded_ts['time']
    path_ts = path_timeseries_from_embedded_timeseries(colony_metrics_ts)
    fig = plot_colony_metrics(path_ts)
    fig.savefig(os.path.join(OUT_DIR, 'colonies'))
def plot_colony_metrics(self, data, out_dir):
    """Plot colony-level metrics from raw simulation data and save the figure.

    ``self`` is not used. The 'colony_global' store of ``data`` gets the shared
    'time' vector attached, is flattened to a path timeseries, and is plotted.
    The figure is saved as ``<out_dir>/colonies``. Nothing is returned.
    """
    embedded = timeseries_from_data(data)
    colony_global = embedded['colony_global']
    colony_global['time'] = embedded['time']
    flat = path_timeseries_from_embedded_timeseries(colony_global)

    # NOTE(review): presumably this is a method, so the bare name below resolves
    # to a module-level plot_colony_metrics(path_ts) rather than to this method.
    figure = plot_colony_metrics(flat)
    figure.savefig(os.path.join(out_dir, 'colonies'))
def test_flagella_metabolism(seed=1):
    """Regression test of FlagellaExpressionMetabolism against stored reference data.

    Seeds the stdlib and numpy RNGs and simulates the compartment (with 'divide'
    disabled) for total_time 60. Non-numeric stores are stripped, and the
    flattened timeseries is compared with the reference CSV named after the
    compartment.
    """
    random.seed(seed)
    np.random.seed(seed)

    compartment = FlagellaExpressionMetabolism({'divide': False})
    start_state = get_flagella_metabolism_initial_state()
    compartment_name = compartment.name

    sim_settings = {
        'total_time': 60,
        'emit_step': COMPARTMENT_TIMESTEP,
        'initial_state': start_state,
    }
    timeseries = simulate_compartment_in_experiment(compartment, sim_settings)

    # Strip non-numerical stores so the result can be compared value by value.
    # Each entry is a path into the nested timeseries; deletion order is fixed.
    non_numeric_paths = (
        ('chromosome',),
        ('ribosomes',),
        ('dimensions',),
        ('boundary', 'divide'),
        ('fields',),
        ('null',),
    )
    for *parents, leaf in non_numeric_paths:
        store = timeseries
        for key in parents:
            store = store[key]
        del store[leaf]

    path_timeseries = path_timeseries_from_embedded_timeseries(timeseries)

    # To regenerate the reference data, temporarily enable:
    # out_dir = os.path.join(COMPARTMENT_OUT_DIR, compartment_name)
    # if not os.path.exists(out_dir):
    #     os.makedirs(out_dir)
    # save_flat_timeseries(path_timeseries, out_dir)

    reference_file = os.path.join(REFERENCE_DATA_DIR, compartment_name + '.csv')
    reference = load_timeseries(reference_file)
    assert_timeseries_close(path_timeseries, reference)
def format_data_for_colony_metrics(data):
    """Return the 'colony_global' store of ``data`` as a flat path timeseries.

    The 'time' vector from the embedded timeseries is attached to the
    'colony_global' sub-dict in place before flattening.
    """
    embedded = timeseries_from_data(data)
    colony_global = embedded['colony_global']
    colony_global['time'] = embedded['time']
    return path_timeseries_from_embedded_timeseries(colony_global)
def plot_simulation_output(
        timeseries_raw,
        settings: Optional[Dict[str, Any]] = None,
        out_dir=None,
        filename='simulation',
):
    '''
    Plot simulation output, with rows organized into separate columns.

    Arguments::
        timeseries_raw (dict): This can be obtained from simulation output
            with convert_to_timeseries(). Must contain a top-level 'time' key.
        settings (dict): Accepts the following keys:

            * **column_width** (:py:class:`int`): the width (inches) of
              each column in the figure
            * **max_rows** (:py:class:`int`): ports with more states
              than this number of states get wrapped into a new column
            * **remove_zeros** (:py:class:`bool`): if True, timeseries
              with all zeros get removed
            * **remove_flat** (:py:class:`bool`): if True, timeseries
              with all the same value get removed (takes precedence over
              remove_zeros)
            * **remove_first_timestep** (:py:class:`bool`): if True,
              skips the first timestep
            * **skip_ports** (:py:class:`list`): entire ports that won't
              be plotted
            * **show_state** (:py:class:`list`): with
              ``[('port_id', 'state_id')]`` for all states that will be
              highlighted, even if they are otherwise to be removed.
              NOTE: this key is currently not read by this function.
              TODO: Obsolete?
        out_dir (str): if given, the figure is saved there via
            _save_fig_to_dir; otherwise it is only returned.
        filename (str): file name passed to _save_fig_to_dir.

    Returns:
        The matplotlib figure.
    '''
    int_or_float = (int, np.int32, np.int64, float, np.float32, np.float64)
    settings = settings or {}
    plot_fontsize = 8
    plt.rc('font', size=plot_fontsize)
    plt.rc('axes', titlesize=plot_fontsize)

    # get settings
    column_width = settings.get('column_width', 3)
    max_rows = int(settings.get('max_rows', 25))
    remove_zeros = settings.get('remove_zeros', True)
    remove_flat = settings.get('remove_flat', False)
    skip_ports = settings.get('skip_ports', [])
    remove_first_timestep = settings.get('remove_first_timestep', False)

    # make a flat 'path' timeseries, with keys being path
    top_level = list(timeseries_raw.keys())
    timeseries = path_timeseries_from_embedded_timeseries(timeseries_raw)
    time_vec = timeseries.pop('time')
    if remove_first_timestep:
        time_vec = time_vec[1:]

    # remove select states from timeseries
    removed_states = set()
    for path, series in timeseries.items():
        if path[0] in skip_ports:
            removed_states.add(path)
        elif remove_flat:
            if series.count(series[0]) == len(series):
                removed_states.add(path)
        elif remove_zeros:
            if all(v == 0 for v in series):
                removed_states.add(path)
    for path in removed_states:
        del timeseries[path]

    # count the remaining states in each top-level port
    port_lengths = {}
    for path in timeseries.keys():
        if path[0] in top_level:
            port_lengths[path[0]] = port_lengths.get(path[0], 0) + 1

    # Split each port into columns of at most max_rows states, and remember
    # the index of each port's first column.
    columns = []
    port_start_col = {}
    for port, n_states in port_lengths.items():
        port_start_col[port] = len(columns)
        n_full, remainder = divmod(n_states, max_rows)
        columns.extend([max_rows] * n_full)
        if remainder:
            columns.append(remainder)

    # make figure and plot
    n_cols = len(columns)
    n_rows = max(columns)
    fig = plt.figure(figsize=(n_cols * column_width, n_rows * column_width / 3))
    grid = plt.GridSpec(n_rows, n_cols)

    for port in port_lengths:
        # Start each port at the top of its own first column. Without this
        # reset, a state skipped below would shift every following port into
        # the previous port's cells.
        row_idx = 0
        col_idx = port_start_col[port]

        # get this port's timeseries, keyed by the path below the port
        port_timeseries = {}
        for path, ts in timeseries.items():
            if path[0] == port:
                next_path = path[1:]
                if any(isinstance(item, tuple) for item in next_path):
                    next_path = tuple(
                        item[0] if isinstance(item, tuple) else item
                        for item in next_path)
                port_timeseries[next_path] = ts

        for state_id, series in sorted(port_timeseries.items()):
            if remove_first_timestep:
                series = series[1:]
            # not enough data points -- this state likely did not exist
            # throughout the entire simulation
            if len(series) != len(time_vec):
                continue

            ax = fig.add_subplot(grid[row_idx, col_idx])  # grid is (row, column)

            if not all(isinstance(state, int_or_float) for state in series):
                # the series is not purely ints/floats, so it is labelled, not plotted
                ax.title.set_text(
                    str(port) + ': ' + str(state_id) + ' (non numeric)')
            else:
                # plot line at zero if series touches or crosses the zero line
                if any(x == 0.0 for x in series) or (
                        any(x < 0.0 for x in series)
                        and any(x > 0.0 for x in series)):
                    zero_line = [0 for _ in time_vec]
                    ax.plot(time_vec, zero_line, 'k--')

                # plot the series
                ax.plot(time_vec, series)
                if isinstance(state_id, tuple):
                    # one line per path element; str() so non-string keys
                    # do not make join() raise TypeError
                    state_id = '\n'.join(str(key) for key in state_id)
                ax.title.set_text(str(port) + '\n' + str(state_id))

            if row_idx == columns[col_idx] - 1:
                # last row of this column: show x axis and move to next column
                set_axes(ax, True)
                ax.set_xlabel('time (s)')
                row_idx = 0
                col_idx += 1
            else:
                set_axes(ax)
                row_idx += 1
            ax.set_xlim([time_vec[0], time_vec[-1]])

    if out_dir:
        fig.subplots_adjust(wspace=column_width / 3, hspace=column_width / 3)
        _save_fig_to_dir(fig, filename, out_dir)
    return fig
def plot_simulation_output(timeseries_raw, settings=None, out_dir='out', filename='simulation'):
    '''
    plot simulation output, with rows organized into separate columns.

    Requires:
        - timeseries_raw (dict). This can be obtained from simulation output
          with convert_to_timeseries(). Must contain a top-level 'time' key.
        - settings (dict, optional) with: {
            'max_rows': (int) ports with more states than this number of states get wrapped into a new column
            'remove_zeros': (bool) if True, timeseries with all zeros get removed
            'remove_flat': (bool) if True, timeseries with all the same value get removed
                (takes precedence over remove_zeros)
            'remove_first_timestep': (bool) if True, skips the first timestep
            'skip_ports': (list) entire ports that won't be plotted
            'show_state': (list) with [('port_id', 'state_id')]
                for all states that will be highlighted, even if they are otherwise to be removed
                NOTE: this key is currently not read by this function.
            }
        - out_dir (str): directory the figure is saved into; created if missing
        - filename (str): file name of the saved figure inside out_dir

    Returns:
        the matplotlib figure (also saved to out_dir/filename)
    '''
    # None instead of a shared mutable {} default
    if settings is None:
        settings = {}

    plot_fontsize = 8
    plt.rc('font', size=plot_fontsize)
    plt.rc('axes', titlesize=plot_fontsize)

    # Scalar types treated as plottable. The numpy abstract bases also cover
    # np.float32 and the smaller int widths that the old check rejected.
    numeric_types = (int, float, np.integer, np.floating)

    # get settings
    max_rows = int(settings.get('max_rows', 25))
    remove_zeros = settings.get('remove_zeros', True)
    remove_flat = settings.get('remove_flat', False)
    skip_ports = settings.get('skip_ports', [])
    remove_first_timestep = settings.get('remove_first_timestep', False)

    # make a flat 'path' timeseries, with keys being path
    top_level = list(timeseries_raw.keys())
    timeseries = path_timeseries_from_embedded_timeseries(timeseries_raw)
    time_vec = timeseries.pop('time')
    if remove_first_timestep:
        time_vec = time_vec[1:]

    # remove select states from timeseries
    removed_states = set()
    for path, series in timeseries.items():
        if path[0] in skip_ports:
            removed_states.add(path)
        elif remove_flat:
            if series.count(series[0]) == len(series):
                removed_states.add(path)
        elif remove_zeros:
            if all(v == 0 for v in series):
                removed_states.add(path)
    for path in removed_states:
        del timeseries[path]

    # count the remaining states in each top-level port
    port_lengths = {}
    for path in timeseries.keys():
        if path[0] in top_level:
            port_lengths[path[0]] = port_lengths.get(path[0], 0) + 1

    # Split each port into columns of at most max_rows states, and remember
    # the index of each port's first column.
    columns = []
    port_start_col = {}
    for port, n_states in port_lengths.items():
        port_start_col[port] = len(columns)
        n_full, remainder = divmod(n_states, max_rows)
        columns.extend([max_rows] * n_full)
        if remainder:
            columns.append(remainder)

    # make figure and plot
    n_cols = len(columns)
    n_rows = max(columns)
    fig = plt.figure(figsize=(n_cols * 3, n_rows * 1))
    grid = plt.GridSpec(n_rows, n_cols)

    for port in port_lengths:
        # Start each port at the top of its own first column, so a state
        # skipped below cannot shift later ports into this port's cells.
        row_idx = 0
        col_idx = port_start_col[port]

        # get this port's states (== rather than `is` for key comparison)
        port_timeseries = {
            path[1:]: ts for path, ts in timeseries.items() if path[0] == port}

        for state_id, series in sorted(port_timeseries.items()):
            if remove_first_timestep:
                series = series[1:]
            # not enough data points -- this state likely did not exist
            # throughout the entire simulation
            if len(series) != len(time_vec):
                continue

            ax = fig.add_subplot(grid[row_idx, col_idx])  # grid is (row, column)

            if not all(isinstance(state, numeric_types) for state in series):
                # the series is not purely ints/floats, so it is labelled, not plotted
                ax.title.set_text(str(port) + ': ' + str(state_id) + ' (non numeric)')
            else:
                # plot line at zero if series touches or crosses the zero line
                if any(x == 0.0 for x in series) or (
                        any(x < 0.0 for x in series)
                        and any(x > 0.0 for x in series)):
                    zero_line = [0 for _ in time_vec]
                    ax.plot(time_vec, zero_line, 'k--')

                # plot the series
                ax.plot(time_vec, series)
                ax.title.set_text(str(port) + ': ' + str(state_id))

            if row_idx == columns[col_idx] - 1:
                # last row of this column: show x axis and move to next column
                set_axes(ax, True)
                ax.set_xlabel('time (s)')
                row_idx = 0
                col_idx += 1
            else:
                set_axes(ax)
                row_idx += 1
            ax.set_xlim([time_vec[0], time_vec[-1]])

    # save figure; create the target directory so savefig does not fail on a
    # missing out_dir (an empty out_dir means the current directory)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig_path = os.path.join(out_dir, filename)
    fig.subplots_adjust(wspace=0.8, hspace=1.0)
    fig.savefig(fig_path, bbox_inches='tight')
    return fig