class GPSTrainingGUI(object):
    """
    GUI shown while GPS training runs.

    The figure contains an action panel (stop/reset/go/fail), an action
    status textbox, a status textbox, an algorithm output textbox, a cost
    plot, a ground truth cost plot, a 3D trajectory visualizer and, when
    config['image_on'] is set, an image visualizer.
    """

    def __init__(self, hyperparams):
        """
        Build the figure, create every GUI component and start the
        background thread that animates the 'Calculating...' status text.

        Args:
            hyperparams: Dictionary read for 'log_filename', 'conditions',
                'info', 'experiment_name' and, optionally,
                'target_filename'.
        """
        self._hyperparams = hyperparams
        self._log_filename = self._hyperparams['log_filename']
        if 'target_filename' in self._hyperparams:
            self._target_filename = self._hyperparams['target_filename']
        else:
            self._target_filename = None

        # GPS Training Status.
        # Modes: run, wait, end, request, process.
        self.mode = config['initial_mode']
        self.request = None  # Requests: stop, reset, go, fail, None.
        self.err_msg = None
        self._colors = {
            'run': 'cyan',
            'wait': 'orange',
            'end': 'red',

            'stop': 'red',
            'reset': 'yellow',
            'go': 'green',
            'fail': 'magenta',
        }
        self._first_update = True

        # Actions.
        actions_arr = [
            Action('stop', 'stop', self.request_stop, axis_pos=0),
            Action('reset', 'reset', self.request_reset, axis_pos=1),
            Action('go', 'go', self.request_go, axis_pos=2),
            Action('fail', 'fail', self.request_fail, axis_pos=3),
        ]

        # Setup figure.
        plt.ion()
        plt.rcParams['toolbar'] = 'None'
        # Clear every matplotlib keyboard shortcut.
        for key in plt.rcParams:
            if key.startswith('keymap.'):
                plt.rcParams[key] = ''

        self._fig = plt.figure(figsize=config['figsize'])
        self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99,
                                  top=0.99, wspace=0, hspace=0)

        # Assign GUI component locations.
        self._gs = gridspec.GridSpec(16, 8)
        self._gs_action_panel = self._gs[0:1, 0:8]
        self._gs_action_output = self._gs[1:2, 0:4]
        self._gs_status_output = self._gs[2:3, 0:4]
        self._gs_cost_plotter = self._gs[1:3, 4:8]
        self._gs_gt_cost_plotter = self._gs[4:6, 4:8]
        self._gs_algthm_output = self._gs[3:9, 0:4]
        if config['image_on']:
            self._gs_traj_visualizer = self._gs[9:16, 0:4]
            self._gs_image_visualizer = self._gs[9:16, 4:8]
        else:
            self._gs_traj_visualizer = self._gs[9:16, 0:8]

        # Create GUI components.
        self._action_panel = ActionPanel(self._fig, self._gs_action_panel,
                                         1, 4, actions_arr)
        self._action_output = Textbox(self._fig, self._gs_action_output,
                                      border_on=True)
        self._status_output = Textbox(self._fig, self._gs_status_output,
                                      border_on=False)
        self._algthm_output = Textbox(
            self._fig, self._gs_algthm_output,
            max_display_size=config['algthm_output_max_display_size'],
            log_filename=self._log_filename,
            fontsize=config['algthm_output_fontsize'],
            font_family='monospace')
        self._cost_plotter = MeanPlotter(self._fig, self._gs_cost_plotter,
                                         color='blue', label='mean cost')
        self._gt_cost_plotter = MeanPlotter(self._fig,
                                            self._gs_gt_cost_plotter,
                                            color='red',
                                            label='ground truth cost')
        self._traj_visualizer = Plotter3D(
            self._fig, self._gs_traj_visualizer,
            num_plots=self._hyperparams['conditions'])
        if config['image_on']:
            self._image_visualizer = ImageVisualizer(
                self._fig, self._gs_image_visualizer,
                cropsize=config['image_size'],
                rostopic=config['image_topic'],
                show_overlay_buttons=True)

        # Setup GUI components.
        self._algthm_output.log_text('\n')
        self.set_output_text(self._hyperparams['info'])
        if config['initial_mode'] == 'run':
            self.run_mode()
        else:
            self.wait_mode()

        # Setup 3D Trajectory Visualizer plot titles and legends
        for m in range(self._hyperparams['conditions']):
            self._traj_visualizer.set_title(m, 'Condition %d' % (m))
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='green',
                                         label='Trajectory Samples')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='blue',
                                         label='Policy Samples')
        self._traj_visualizer.add_legend(linestyle='None', marker='x',
                                         color=(0.5, 0, 0),
                                         label='LG Controller Means')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='red',
                                         label='LG Controller Distributions')

        self._fig.canvas.draw()

        # Display calculating thread
        def display_calculating(delay, run_event):
            # Cycles the status text through one, two and three dots while
            # run_event is set; blocks on the event while it is cleared.
            while True:
                if not run_event.is_set():
                    run_event.wait()
                if run_event.is_set():
                    self.set_status_text('Calculating.')
                    time.sleep(delay)
                if run_event.is_set():
                    self.set_status_text('Calculating..')
                    time.sleep(delay)
                if run_event.is_set():
                    self.set_status_text('Calculating...')
                    time.sleep(delay)

        self._calculating_run = threading.Event()
        self._calculating_thread = threading.Thread(
            target=display_calculating, args=(1, self._calculating_run))
        # Daemon thread: the endless loop must not keep the process alive.
        self._calculating_thread.daemon = True
        self._calculating_thread.start()

    # GPS Training functions
    def request_stop(self, event=None):
        """ Action callback: request 'stop'. The event is ignored. """
        self.request_mode('stop')

    def request_reset(self, event=None):
        """ Action callback: request 'reset'. The event is ignored. """
        self.request_mode('reset')

    def request_go(self, event=None):
        """ Action callback: request 'go'. The event is ignored. """
        self.request_mode('go')

    def request_fail(self, event=None):
        """ Action callback: request 'fail'. The event is ignored. """
        self.request_mode('fail')

    def request_mode(self, request):
        """
        Sets the request mode (stop, reset, go, fail). The request is read by
        gps_main before sampling, and the appropriate action is taken.
        """
        self.mode = 'request'
        self.request = request
        self.set_action_text(self.request + ' requested')
        self.set_action_bgcolor(self._colors[self.request], alpha=0.2)

    def process_mode(self):
        """
        Completes the current request, after it is first read by gps_main.
        Displays visual confirmation that the request was processed,
        displays any error messages, and then switches into mode 'run' or
        'wait'.
        """
        self.mode = 'process'
        self.set_action_text(self.request + ' processed')
        self.set_action_bgcolor(self._colors[self.request], alpha=1.0)
        if self.err_msg:
            self.set_action_text(self.request + ' processed' + '\nERROR: ' +
                                 self.err_msg)
            self.err_msg = None
            # Keep the error on screen a little longer than a plain
            # confirmation.
            time.sleep(1.0)
        else:
            time.sleep(0.5)
        if self.request in ('stop', 'reset', 'fail'):
            self.wait_mode()
        elif self.request == 'go':
            self.run_mode()
        self.request = None

    def wait_mode(self):
        """ Switch to mode 'wait' and show it in the action textbox. """
        self.mode = 'wait'
        self.set_action_text('waiting')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def run_mode(self):
        """ Switch to mode 'run' and show it in the action textbox. """
        self.mode = 'run'
        self.set_action_text('running')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def end_mode(self):
        """ Switch to mode 'end' and show it in the action textbox. """
        self.mode = 'end'
        self.set_action_text('ended')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def estop(self, event=None):
        """ Placeholder: only reports that estop is not implemented. """
        self.set_action_text('estop: NOT IMPLEMENTED')

    # GUI functions
    def _redraw_cost_ticklabels(self):
        """ Redraw the overflow ticklabels of both cost plotters. """
        self._cost_plotter.draw_ticklabels()
        self._gt_cost_plotter.draw_ticklabels()

    def set_action_text(self, text):
        """ Set the text of the action textbox. """
        self._action_output.set_text(text)
        self._redraw_cost_ticklabels()

    def set_action_bgcolor(self, color, alpha=1.0):
        """ Set the background color and alpha of the action textbox. """
        self._action_output.set_bgcolor(color, alpha)
        self._redraw_cost_ticklabels()

    def set_status_text(self, text):
        """ Set the text of the status textbox. """
        self._status_output.set_text(text)
        self._redraw_cost_ticklabels()

    def set_output_text(self, text):
        """ Replace the text of the algorithm output textbox. """
        self._algthm_output.set_text(text)
        self._redraw_cost_ticklabels()

    def append_output_text(self, text):
        """ Append text to the algorithm output textbox. """
        self._algthm_output.append_text(text)
        self._redraw_cost_ticklabels()

    def start_display_calculating(self):
        """ Let the background thread animate the 'Calculating' text. """
        self._calculating_run.set()

    def stop_display_calculating(self):
        """ Pause the 'Calculating' animation. """
        self._calculating_run.clear()

    def set_image_overlays(self, condition):
        """
        Sets up the image visualizer with what images to overlay if
        "overlay_initial_image" or "overlay_target_image" is pressed.

        Does nothing when images are disabled in config or when no target
        file was given.
        """
        if not config['image_on'] or not self._target_filename:
            return
        initial_image = load_data_from_npz(
            self._target_filename, config['image_overlay_actuator'],
            str(condition), 'initial', 'image', default=None)
        target_image = load_data_from_npz(
            self._target_filename, config['image_overlay_actuator'],
            str(condition), 'target', 'image', default=None)
        self._image_visualizer.set_initial_image(
            initial_image, alpha=config['image_overlay_alpha'])
        self._image_visualizer.set_target_image(
            target_image, alpha=config['image_overlay_alpha'])

    # Iteration update functions
    def update(self, itr, algorithm, agent, traj_sample_lists,
               pol_sample_lists):
        """
        After each iteration, update the iteration data output, the cost
        plot, and the 3D trajectory visualizations (if end effector points
        exist).
        """
        if self._first_update:
            policy_titles = pol_sample_lists is not None
            self._output_column_titles(algorithm, policy_titles)
            self._first_update = False

        costs = [np.mean(np.sum(algorithm.prev[m].cs, axis=1))
                 for m in range(algorithm.M)]
        if algorithm._hyperparams['ioc']:
            # With IOC the text output reports the ground truth costs, which
            # also get their own plot.
            gt_costs = [np.mean(np.sum(algorithm.prev[m].cgt, axis=1))
                        for m in range(algorithm.M)]
            self._update_iteration_data(itr, algorithm, gt_costs,
                                        pol_sample_lists)
            self._gt_cost_plotter.update(gt_costs, t=itr)
        else:
            self._update_iteration_data(itr, algorithm, costs,
                                        pol_sample_lists)
        self._cost_plotter.update(costs, t=itr)
        if END_EFFECTOR_POINTS in agent.x_data_types:
            self._update_trajectory_visualizations(algorithm, agent,
                                                   traj_sample_lists,
                                                   pol_sample_lists)

        self._fig.canvas.draw()
        self._fig.canvas.flush_events()  # Fixes bug in Qt4Agg backend

    def _output_column_titles(self, algorithm, policy_titles=False):
        """
        Setup iteration data column titles: iteration, average cost, and for
        each condition the mean cost over samples, step size, linear
        Gaussian controller entropies, and initial/final KL divergences for
        BADMM.
        """
        self.set_output_text(self._hyperparams['experiment_name'])
        if policy_titles:
            condition_titles = '%3s | %8s %12s' % ('', '', '')
            itr_data_fields = '%3s | %8s %12s' % ('itr', 'avg_cost',
                                                  'avg_pol_cost')
        else:
            condition_titles = '%3s | %8s' % ('', '')
            itr_data_fields = '%3s | %8s' % ('itr', 'avg_cost')
        for m in range(algorithm.M):
            condition_titles += ' | %8s %9s %-7d' % ('', 'condition', m)
            itr_data_fields += ' | %8s %8s %8s' % (' cost ', ' step ',
                                                   'entropy ')
            if algorithm.prev[0].pol_info is not None:
                condition_titles += ' %8s %8s' % ('', '')
                itr_data_fields += ' %8s %8s' % ('kl_div_i', 'kl_div_f')
            if (algorithm._hyperparams['ioc'] and
                    not algorithm._hyperparams['learning_from_prior']):
                condition_titles += ' %8s' % ('')
                itr_data_fields += ' %8s' % ('kl_div')
            if algorithm._hyperparams['learning_from_prior']:
                condition_titles += ' %8s' % ('')
                itr_data_fields += ' %8s' % ('mean_dist')
            if policy_titles:
                condition_titles += ' %8s %8s %8s' % ('', '', '')
                itr_data_fields += ' %8s %8s %8s' % ('pol_cost', 'kl_div_i',
                                                     'kl_div_f')
        self.append_output_text(condition_titles)
        self.append_output_text(itr_data_fields)

    def _update_iteration_data(self, itr, algorithm, costs,
                               pol_sample_lists):
        """
        Update iteration data information: iteration, average cost, and for
        each condition the mean cost over samples, step size, linear
        Gaussian controller entropies, and initial/final KL divergences for
        BADMM.
        """
        avg_cost = np.mean(costs)
        if pol_sample_lists is not None:
            pol_costs = [
                np.mean([np.sum(algorithm.cost[m].eval(s)[0])
                         for s in pol_sample_lists[m]])
                for m in range(algorithm.M)
            ]
            itr_data = '%3d | %8.2f %12.2f' % (itr, avg_cost,
                                               np.mean(pol_costs))
        else:
            itr_data = '%3d | %8.2f' % (itr, avg_cost)
        for m in range(algorithm.M):
            cost = costs[m]
            step = algorithm.prev[m].step_mult * algorithm.base_kl_step
            entropy = 2 * np.sum(np.log(np.diagonal(
                algorithm.prev[m].traj_distr.chol_pol_covar,
                axis1=1, axis2=2)))
            itr_data += ' | %8.2f %8.2f %8.2f' % (cost, step, entropy)
            if algorithm.prev[0].pol_info is not None:
                kl_div_i = algorithm.prev[m].pol_info.prev_kl[0]
                kl_div_f = algorithm.prev[m].pol_info.prev_kl[-1]
                itr_data += ' %8.2f %8.2f' % (kl_div_i, kl_div_f)
            if (algorithm._hyperparams['ioc'] and
                    not algorithm._hyperparams['learning_from_prior']):
                itr_data += ' %8.2f' % (algorithm.kl_div[itr][m])
            if algorithm._hyperparams['learning_from_prior']:
                itr_data += ' %8.2f' % (algorithm.dists_to_target[itr][m])
            if pol_sample_lists is not None:
                kl_div_i = algorithm.cur[m].pol_info.init_kl.mean()
                kl_div_f = algorithm.cur[m].pol_info.prev_kl.mean()
                itr_data += ' %8.2f %8.2f %8.2f' % (pol_costs[m], kl_div_i,
                                                    kl_div_f)
        self.append_output_text(itr_data)

    def _update_trajectory_visualizations(self, algorithm, agent,
                                          traj_sample_lists,
                                          pol_sample_lists):
        """
        Update 3D trajectory visualizations information: the trajectory
        samples, policy samples, and linear Gaussian controller means and
        covariances.
        """
        xlim, ylim, zlim = self._calculate_3d_axis_limits(
            traj_sample_lists, pol_sample_lists)
        for m in range(algorithm.M):
            self._traj_visualizer.clear(m)
            self._traj_visualizer.set_lim(i=m, xlim=xlim, ylim=ylim,
                                          zlim=zlim)
            self._update_samples_plots(traj_sample_lists, m, 'green',
                                       'Trajectory Samples')
            self._update_linear_gaussian_controller_plots(algorithm, agent,
                                                          m)
            if pol_sample_lists:
                self._update_samples_plots(pol_sample_lists, m, 'blue',
                                           'Policy Samples')
        self._traj_visualizer.draw()  # this must be called explicitly

    @staticmethod
    def _merge_sample_lists(traj_sample_lists, pol_sample_lists):
        """
        Return a new list holding the trajectory sample lists followed by
        the policy sample lists (if any). Neither argument is modified.
        """
        merged = list(traj_sample_lists)
        if pol_sample_lists:
            merged.extend(pol_sample_lists)
        return merged

    def _calculate_3d_axis_limits(self, traj_sample_lists,
                                  pol_sample_lists):
        """
        Calculate the 3D axis limits shared between trajectory plots, based
        on the minimum and maximum xyz values across all samples.

        Returns:
            The (xlim, ylim, zlim) values from buffered_axis_limits.
        """
        all_eept = np.empty((0, 3))
        # Work on a fresh list: extending traj_sample_lists in place would
        # grow the caller's list on every update.
        sample_lists = self._merge_sample_lists(traj_sample_lists,
                                                pol_sample_lists)
        for sample_list in sample_lists:
            for sample in sample_list.get_samples():
                ee_pt = sample.get(END_EFFECTOR_POINTS)
                # Columns are consumed three at a time (one xyz point each).
                for i in range(ee_pt.shape[1] // 3):
                    ee_pt_i = ee_pt[:, 3 * i + 0:3 * i + 3]
                    all_eept = np.r_[all_eept, ee_pt_i]
        min_xyz = np.amin(all_eept, axis=0)
        max_xyz = np.amax(all_eept, axis=0)
        xlim = buffered_axis_limits(min_xyz[0], max_xyz[0],
                                    buffer_factor=1.25)
        ylim = buffered_axis_limits(min_xyz[1], max_xyz[1],
                                    buffer_factor=1.25)
        zlim = buffered_axis_limits(min_xyz[2], max_xyz[2],
                                    buffer_factor=1.25)
        return xlim, ylim, zlim

    def _update_linear_gaussian_controller_plots(self, algorithm, agent, m):
        """
        Update the linear Gaussian controller plots with iteration data, for
        the mean and covariances of the end effector points.
        """
        # Calculate mean and covariance for end effector points
        eept_idx = agent.get_idx_x(END_EFFECTOR_POINTS)
        start, end = eept_idx[0], eept_idx[-1]
        mu, sigma = algorithm.traj_opt.forward(algorithm.prev[m].traj_distr,
                                               algorithm.prev[m].traj_info)
        mu_eept = mu[:, start:end + 1]
        sigma_eept = sigma[:, start:end + 1, start:end + 1]

        # Linear Gaussian Controller Distributions (Red)
        for i in range(mu_eept.shape[1] // 3):
            mu = mu_eept[:, 3 * i + 0:3 * i + 3]
            sigma = sigma_eept[:, 3 * i + 0:3 * i + 3, 3 * i + 0:3 * i + 3]
            self._traj_visualizer.plot_3d_gaussian(
                i=m, mu=mu, sigma=sigma, edges=100, linestyle='-',
                linewidth=1.0, color='red', alpha=0.15,
                label='LG Controller Distributions')

        # Linear Gaussian Controller Means (Dark Red)
        for i in range(mu_eept.shape[1] // 3):
            mu = mu_eept[:, 3 * i + 0:3 * i + 3]
            self._traj_visualizer.plot_3d_points(
                i=m, points=mu, linestyle='None', marker='x',
                markersize=5.0, markeredgewidth=1.0, color=(0.5, 0, 0),
                alpha=1.0, label='LG Controller Means')

    def _update_samples_plots(self, sample_lists, m, color, label):
        """
        Update the samples plots with iteration data, for the trajectory
        samples and the policy samples.
        """
        samples = sample_lists[m].get_samples()
        for sample in samples:
            ee_pt = sample.get(END_EFFECTOR_POINTS)
            for i in range(ee_pt.shape[1] // 3):
                ee_pt_i = ee_pt[:, 3 * i + 0:3 * i + 3]
                self._traj_visualizer.plot_3d_points(m, ee_pt_i, color=color,
                                                     label=label)

    def save_figure(self, filename):
        """ Save the whole GUI figure to filename. """
        self._fig.savefig(filename)
def __init__(self, hyperparams):
    """
    Build the training GUI: figure, layout, components, legends and the
    background thread that animates the 'Calculating...' status text.

    Args:
        hyperparams: Dictionary read for 'log_filename', 'conditions',
            'info' and, optionally, 'target_filename'.
    """
    self._hyperparams = hyperparams
    self._log_filename = self._hyperparams['log_filename']
    # None when no target file was supplied.
    self._target_filename = self._hyperparams.get('target_filename')

    # GPS Training Status.
    # Modes: run, wait, end, request, process.
    self.mode = config['initial_mode']
    # Requests: stop, reset, go, fail, None.
    self.request = None
    self.err_msg = None
    self._colors = dict(
        run='cyan', wait='orange', end='red',
        stop='red', reset='yellow', go='green', fail='magenta',
    )
    self._first_update = True

    # Actions: one button per request, laid out left to right.
    callbacks = (
        ('stop', self.request_stop),
        ('reset', self.request_reset),
        ('go', self.request_go),
        ('fail', self.request_fail),
    )
    actions_arr = [
        Action(name, name, callback, axis_pos=position)
        for position, (name, callback) in enumerate(callbacks)
    ]

    # Setup figure.
    plt.ion()
    plt.rcParams['toolbar'] = 'None'
    # Clear every matplotlib keyboard shortcut.
    for rc_key in plt.rcParams:
        if rc_key.startswith('keymap.'):
            plt.rcParams[rc_key] = ''

    self._fig = plt.figure(figsize=config['figsize'])
    self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                              wspace=0, hspace=0)

    # Assign GUI component locations.
    grid = gridspec.GridSpec(16, 8)
    self._gs = grid
    self._gs_action_panel = grid[0:1, 0:8]
    self._gs_action_output = grid[1:2, 0:4]
    self._gs_status_output = grid[2:3, 0:4]
    self._gs_cost_plotter = grid[1:3, 4:8]
    self._gs_gt_cost_plotter = grid[4:6, 4:8]
    self._gs_algthm_output = grid[3:9, 0:4]
    if config['image_on']:
        # The bottom rows are shared with the image visualizer.
        self._gs_traj_visualizer = grid[9:16, 0:4]
        self._gs_image_visualizer = grid[9:16, 4:8]
    else:
        self._gs_traj_visualizer = grid[9:16, 0:8]

    # Create GUI components.
    fig = self._fig
    self._action_panel = ActionPanel(fig, self._gs_action_panel, 1, 4,
                                     actions_arr)
    self._action_output = Textbox(fig, self._gs_action_output,
                                  border_on=True)
    self._status_output = Textbox(fig, self._gs_status_output,
                                  border_on=False)
    self._algthm_output = Textbox(
        fig,
        self._gs_algthm_output,
        max_display_size=config['algthm_output_max_display_size'],
        log_filename=self._log_filename,
        fontsize=config['algthm_output_fontsize'],
        font_family='monospace',
    )
    self._cost_plotter = MeanPlotter(fig, self._gs_cost_plotter,
                                     color='blue', label='mean cost')
    self._gt_cost_plotter = MeanPlotter(fig, self._gs_gt_cost_plotter,
                                        color='red',
                                        label='ground truth cost')
    self._traj_visualizer = Plotter3D(
        fig, self._gs_traj_visualizer,
        num_plots=self._hyperparams['conditions'])
    if config['image_on']:
        self._image_visualizer = ImageVisualizer(
            fig,
            self._gs_image_visualizer,
            cropsize=config['image_size'],
            rostopic=config['image_topic'],
            show_overlay_buttons=True,
        )

    # Setup GUI components.
    self._algthm_output.log_text('\n')
    self.set_output_text(self._hyperparams['info'])
    if config['initial_mode'] == 'run':
        self.run_mode()
    else:
        self.wait_mode()

    # Setup 3D Trajectory Visualizer plot titles and legends
    for condition in range(self._hyperparams['conditions']):
        self._traj_visualizer.set_title(condition,
                                        'Condition %d' % (condition))
    legend_entries = (
        ('-', 'None', 'green', 'Trajectory Samples'),
        ('-', 'None', 'blue', 'Policy Samples'),
        ('None', 'x', (0.5, 0, 0), 'LG Controller Means'),
        ('-', 'None', 'red', 'LG Controller Distributions'),
    )
    for linestyle, marker, color, label in legend_entries:
        self._traj_visualizer.add_legend(linestyle=linestyle, marker=marker,
                                         color=color, label=label)

    self._fig.canvas.draw()

    # Display calculating thread
    def display_calculating(delay, run_event):
        # Cycle the status text through one, two and three dots while the
        # event is set; block on the event while it is cleared.
        frames = ('Calculating.', 'Calculating..', 'Calculating...')
        while True:
            if not run_event.is_set():
                run_event.wait()
            for frame in frames:
                if run_event.is_set():
                    self.set_status_text(frame)
                    time.sleep(delay)

    self._calculating_run = threading.Event()
    worker = threading.Thread(target=display_calculating,
                              args=(1, self._calculating_run))
    # Daemon thread: the endless loop must not keep the process alive.
    worker.daemon = True
    self._calculating_thread = worker
    self._calculating_thread.start()
class GPSTrainingGUI(object):
    """
    GPS Training GUI class.

    The figure contains an action axis (stop/reset/go/fail), an action
    status output, a status output, an algorithm output, a cost plot, a 3D
    trajectory visualizer and an image visualizer.
    """

    def __init__(self, hyperparams):
        """
        Merge the configuration, build the figure and create every GUI
        component.

        Args:
            hyperparams: Dictionary of overrides applied on top of copies of
                common_config and gps_training_config.
        """
        self._hyperparams = copy.deepcopy(common_config)
        self._hyperparams.update(copy.deepcopy(gps_training_config))
        self._hyperparams.update(hyperparams)

        self._log_filename = self._hyperparams['log_filename']
        if 'target_filename' in self._hyperparams:
            self._target_filename = self._hyperparams['target_filename']
        else:
            self._target_filename = ''

        # GPS Training Status.
        self.mode = 'run'  # Modes: run, wait, end, request, process.
        self.request = None  # Requests: stop, reset, go, fail, None.
        self.err_msg = None
        self._colors = {
            'run': 'cyan',
            'wait': 'orange',
            'end': 'red',

            'stop': 'red',
            'reset': 'yellow',
            'go': 'green',
            'fail': 'magenta',
        }
        self._first_update = True

        # Actions.
        actions_arr = [
            Action('stop', 'stop', self.request_stop, axis_pos=0),
            Action('reset', 'reset', self.request_reset, axis_pos=1),
            Action('go', 'go', self.request_go, axis_pos=2),
            Action('fail', 'fail', self.request_fail, axis_pos=3),
        ]
        self._actions = {action._key: action for action in actions_arr}
        # items() instead of iteritems(): works on Python 2 and Python 3.
        for key, action in self._actions.items():
            if key in self._hyperparams['keyboard_bindings']:
                action._kb = self._hyperparams['keyboard_bindings'][key]
            if key in self._hyperparams['ps3_bindings']:
                action._pb = self._hyperparams['ps3_bindings'][key]

        # GUI Components.
        plt.ion()
        plt.rcParams['toolbar'] = 'None'
        # Remove 's' keyboard shortcut for saving.
        plt.rcParams['keymap.save'] = ''

        self._fig = plt.figure(figsize=(12, 12))
        self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99,
                                  top=0.99, wspace=0, hspace=0)

        # Assign GUI component locations.
        self._gs = gridspec.GridSpec(16, 8)
        self._gs_action_axis = self._gs[0:2, 0:8]
        self._gs_action_output = self._gs[2:3, 0:4]
        self._gs_status_output = self._gs[3:4, 0:4]
        self._gs_cost_plotter = self._gs[2:4, 4:8]
        self._gs_algthm_output = self._gs[4:8, 0:8]
        self._gs_traj_visualizer = self._gs[8:16, 0:4]
        self._gs_image_visualizer = self._gs[8:16, 4:8]

        # Create GUI components.
        self._action_axis = ActionAxis(
            self._fig, self._gs_action_axis, 1, 4, self._actions,
            ps3_process_rate=self._hyperparams['ps3_process_rate'],
            ps3_topic=self._hyperparams['ps3_topic'],
            ps3_button=self._hyperparams['ps3_button'],
            inverted_ps3_button=self._hyperparams['inverted_ps3_button'])
        self._action_output = OutputAxis(self._fig, self._gs_action_output,
                                         border_on=True)
        self._status_output = OutputAxis(self._fig, self._gs_status_output,
                                         border_on=False)
        self._algthm_output = OutputAxis(self._fig, self._gs_algthm_output,
                                         max_display_size=15,
                                         log_filename=self._log_filename,
                                         fontsize=10,
                                         font_family='monospace')
        self._cost_plotter = MeanPlotter(self._fig, self._gs_cost_plotter,
                                         color='blue', label='mean cost')
        self._traj_visualizer = Plotter3D(
            self._fig, self._gs_traj_visualizer,
            num_plots=self._hyperparams['conditions'])
        self._image_visualizer = ImageVisualizer(
            self._fig, self._gs_image_visualizer, cropsize=(240, 240),
            rostopic=self._hyperparams['image_topic'],
            show_overlay_buttons=True)

        # Setup GUI components.
        self._algthm_output.log_text('\n')
        self.set_output_text(self._hyperparams['info'])
        self.run_mode()

        # WARNING: Make sure the legend values in UPDATE match the below
        # linestyles/markers and colors
        for m in range(self._hyperparams['conditions']):
            self._traj_visualizer.set_title(m, 'Condition %d' % (m))
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='green',
                                         label='Trajectory Samples')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='blue',
                                         label='Policy Samples')
        self._traj_visualizer.add_legend(linestyle='None', marker='x',
                                         color=(0.5, 0, 0),
                                         label='LG Controller Means')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='red',
                                         label='LG Controller Distributions')

        self._fig.canvas.draw()

    # GPS Training Functions.
    def request_stop(self, event=None):
        """ Action callback: request 'stop'. The event is ignored. """
        self.request_mode('stop')

    def request_reset(self, event=None):
        """ Action callback: request 'reset'. The event is ignored. """
        self.request_mode('reset')

    def request_go(self, event=None):
        """ Action callback: request 'go'. The event is ignored. """
        self.request_mode('go')

    def request_fail(self, event=None):
        """ Action callback: request 'fail'. The event is ignored. """
        self.request_mode('fail')

    def request_mode(self, request):
        """
        Store the request (stop, reset, go, fail), switch to mode 'request'
        and show the pending request in the action output.
        """
        self.mode = 'request'
        self.request = request
        self.set_action_text(self.request + ' requested')
        self.set_action_bgcolor(self._colors[self.request], alpha=0.2)

    def process_mode(self):
        """
        Mark the current request as processed, display any error message,
        then switch to mode 'wait' (stop, reset, fail) or 'run' (go) and
        clear the request.
        """
        self.mode = 'process'
        self.set_action_text(self.request + ' processed')
        self.set_action_bgcolor(self._colors[self.request], alpha=1.0)
        if self.err_msg:
            self.set_action_text(self.request + ' processed' + '\nERROR: ' +
                                 self.err_msg)
            self.err_msg = None
            # Keep the error on screen a little longer than a plain
            # confirmation.
            time.sleep(1.0)
        else:
            time.sleep(0.5)
        if self.request in ('stop', 'reset', 'fail'):
            self.wait_mode()
        elif self.request == 'go':
            self.run_mode()
        self.request = None

    def wait_mode(self):
        """ Switch to mode 'wait' and show it in the action output. """
        self.mode = 'wait'
        self.set_action_text('waiting')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def run_mode(self):
        """ Switch to mode 'run' and show it in the action output. """
        self.mode = 'run'
        self.set_action_text('running')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def end_mode(self):
        """ Switch to mode 'end' and show it in the action output. """
        self.mode = 'end'
        self.set_action_text('ended')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def estop(self, event=None):
        """ Placeholder: only reports that estop is not implemented. """
        self.set_action_text('estop: NOT IMPLEMENTED')

    # GUI functions.
    def set_action_text(self, text):
        """ Set the text of the action output. """
        self._action_output.set_text(text)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels

    def set_action_bgcolor(self, color, alpha=1.0):
        """ Set the background color and alpha of the action output. """
        self._action_output.set_bgcolor(color, alpha)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels

    def set_status_text(self, text):
        """ Set the text of the status output. """
        self._status_output.set_text(text)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels

    def set_output_text(self, text):
        """ Replace the text of the algorithm output. """
        self._algthm_output.set_text(text)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels

    def append_output_text(self, text):
        """ Append text to the algorithm output. """
        self._algthm_output.append_text(text)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels

    def set_image_overlays(self, condition):
        """
        Load the initial and target images stored for the given condition
        and hand them to the image visualizer as overlays. Does nothing when
        no target file was given.
        """
        if not self._target_filename:
            return
        initial_image = load_data_from_npz(
            self._target_filename, self._hyperparams['image_actuator'],
            str(condition), 'initial', 'image', default=np.zeros((1, 1, 3)))
        target_image = load_data_from_npz(
            self._target_filename, self._hyperparams['image_actuator'],
            str(condition), 'target', 'image', default=np.zeros((1, 1, 3)))
        self._image_visualizer.set_initial_image(initial_image, alpha=0.3)
        self._image_visualizer.set_target_image(target_image, alpha=0.3)

    def update(self, itr, algorithm, agent, traj_sample_lists,
               pol_sample_lists):
        """
        After each iteration, update the cost plot, the iteration data
        output and, if end effector points exist, the 3D trajectory
        visualizations.
        """
        # Plot Costs
        if algorithm.M == 1:
            # Update plot with each sample's cost (summed over time).
            costs = np.sum(algorithm.prev[0].cs, axis=1)
        else:
            # Update plot with each condition's mean sample cost (summed
            # over time).
            costs = [np.mean(np.sum(algorithm.prev[m].cs, axis=1))
                     for m in range(algorithm.M)]
        self._cost_plotter.update(costs, t=itr)

        # Setup iteration data column titles (first call only).
        if self._first_update:
            self.set_output_text(self._hyperparams['experiment_name'])
            condition_titles = '%3s | %8s' % ('', '')
            itr_data_fields = '%3s | %8s' % ('itr', 'avg_cost')
            for m in range(algorithm.M):
                condition_titles += ' | %8s %9s %-7d' % ('', 'condition', m)
                itr_data_fields += ' | %8s %8s %8s' % (' cost ', ' step ',
                                                       'entropy ')
                if algorithm.prev[0].pol_info is not None:
                    condition_titles += ' %8s %8s' % ('', '')
                    itr_data_fields += ' %8s %8s' % ('kl_div_i', 'kl_div_f')
            self.append_output_text(condition_titles)
            self.append_output_text(itr_data_fields)
            self._first_update = False

        # Print Iteration Data
        avg_cost = np.mean(costs)
        itr_data = '%3d | %8.2f' % (itr, avg_cost)
        for m in range(algorithm.M):
            cost = costs[m]
            step = algorithm.prev[m].step_mult
            entropy = 2 * np.sum(np.log(np.diagonal(
                algorithm.prev[m].traj_distr.chol_pol_covar,
                axis1=1, axis2=2)))
            itr_data += ' | %8.2f %8.2f %8.2f' % (cost, step, entropy)
            if algorithm.prev[0].pol_info is not None:
                kl_div_i = algorithm.prev[m].pol_info.prev_kl[0]
                kl_div_f = algorithm.prev[m].pol_info.prev_kl[-1]
                itr_data += ' %8.2f %8.2f' % (kl_div_i, kl_div_f)
        self.append_output_text(itr_data)

        if END_EFFECTOR_POINTS not in agent.x_data_types:
            # Skip plotting samples.
            self._traj_visualizer.draw()  # this must be called explicitly
            self._fig.canvas.draw()
            self._fig.canvas.flush_events()  # Fixes bug in Qt4Agg backend
            return
        # TODO(xinyutan) - this assumes that END_EFFECTOR_POINTS are in the
        # sample, which is not true for box2d. quick fix is above.

        # Calculate xlim, ylim, zlim for 3D visualizations from
        # traj_sample_lists and pol_sample_lists (this clips off LQG
        # means/distributions that are not in the area of interest)
        all_eept = np.empty((0, 3))
        if pol_sample_lists:
            sample_lists = traj_sample_lists + pol_sample_lists
        else:
            sample_lists = traj_sample_lists
        for sample_list in sample_lists:
            for sample in sample_list.get_samples():
                ee_pt = sample.get(END_EFFECTOR_POINTS)
                # Floor division keeps the count an int on Python 3 too.
                for i in range(ee_pt.shape[1] // 3):
                    ee_pt_i = ee_pt[:, 3 * i + 0:3 * i + 3]
                    all_eept = np.r_[all_eept, ee_pt_i]
        min_xyz = np.amin(all_eept, axis=0)
        max_xyz = np.amax(all_eept, axis=0)
        xlim = (min_xyz[0], max_xyz[0])
        ylim = (min_xyz[1], max_xyz[1])
        zlim = (min_xyz[2], max_xyz[2])

        # Plot 3D Visualizations
        for m in range(algorithm.M):
            # Clear previous plots
            self._traj_visualizer.clear(m)
            self._traj_visualizer.set_lim(i=m, xlim=xlim, ylim=ylim,
                                          zlim=zlim)

            # Linear Gaussian Controller Distributions (Red)
            mu, sigma = algorithm.traj_opt.forward(
                algorithm.prev[m].traj_distr, algorithm.prev[m].traj_info)
            eept_idx = agent.get_idx_x(END_EFFECTOR_POINTS)
            start, end = eept_idx[0], eept_idx[-1]
            mu_eept = mu[:, start:end + 1]
            sigma_eept = sigma[:, start:end + 1, start:end + 1]
            for i in range(mu_eept.shape[1] // 3):
                mu = mu_eept[:, 3 * i + 0:3 * i + 3]
                sigma = sigma_eept[:, 3 * i + 0:3 * i + 3,
                                   3 * i + 0:3 * i + 3]
                self._traj_visualizer.plot_3d_gaussian(
                    i=m, mu=mu, sigma=sigma, edges=100, linestyle='-',
                    linewidth=1.0, color='red', alpha=0.15,
                    label='LG Controller Distributions')

            # Linear Gaussian Controller Means (Dark Red)
            for i in range(mu_eept.shape[1] // 3):
                mu = mu_eept[:, 3 * i + 0:3 * i + 3]
                self._traj_visualizer.plot_3d_points(
                    i=m, points=mu, linestyle='None', marker='x',
                    markersize=5.0, markeredgewidth=1.0, color=(0.5, 0, 0),
                    alpha=1.0, label='LG Controller Means')

            # Trajectory Samples (Green)
            traj_samples = traj_sample_lists[m].get_samples()
            for sample in traj_samples:
                ee_pt = sample.get(END_EFFECTOR_POINTS)
                for i in range(ee_pt.shape[1] // 3):
                    ee_pt_i = ee_pt[:, 3 * i + 0:3 * i + 3]
                    self._traj_visualizer.plot_3d_points(
                        m, ee_pt_i, color='green',
                        label='Trajectory Samples')

            # Policy Samples (Blue)
            # Same truthiness guard as the axis-limit computation, so an
            # empty list is skipped instead of being indexed.
            if pol_sample_lists:
                pol_samples = pol_sample_lists[m].get_samples()
                for sample in pol_samples:
                    ee_pt = sample.get(END_EFFECTOR_POINTS)
                    for i in range(ee_pt.shape[1] // 3):
                        ee_pt_i = ee_pt[:, 3 * i + 0:3 * i + 3]
                        self._traj_visualizer.plot_3d_points(
                            m, ee_pt_i, color='blue',
                            label='Policy Samples')
        self._traj_visualizer.draw()  # this must be called explicitly
        self._fig.canvas.draw()
        self._fig.canvas.flush_events()  # Fixes bug in Qt4Agg backend

    def save_figure(self, filename):
        """ Save the whole GUI figure to filename. """
        self._fig.savefig(filename)
def __init__(self, hyperparams):
    """
    Merge the configuration, build the figure and create every GUI
    component.

    Args:
        hyperparams: Dictionary of overrides applied on top of copies of
            common_config and gps_training_config.
    """
    self._hyperparams = copy.deepcopy(common_config)
    self._hyperparams.update(copy.deepcopy(gps_training_config))
    self._hyperparams.update(hyperparams)

    self._log_filename = self._hyperparams['log_filename']
    # Empty string when no target file was supplied.
    self._target_filename = self._hyperparams.get('target_filename', '')

    # GPS Training Status.
    self.mode = 'run'  # Modes: run, wait, end, request, process.
    self.request = None  # Requests: stop, reset, go, fail, None.
    self.err_msg = None
    self._colors = {
        'run': 'cyan',
        'wait': 'orange',
        'end': 'red',

        'stop': 'red',
        'reset': 'yellow',
        'go': 'green',
        'fail': 'magenta',
    }
    self._first_update = True

    # Actions.
    actions_arr = [
        Action('stop', 'stop', self.request_stop, axis_pos=0),
        Action('reset', 'reset', self.request_reset, axis_pos=1),
        Action('go', 'go', self.request_go, axis_pos=2),
        Action('fail', 'fail', self.request_fail, axis_pos=3),
    ]
    self._actions = {action._key: action for action in actions_arr}
    keyboard_bindings = self._hyperparams['keyboard_bindings']
    ps3_bindings = self._hyperparams['ps3_bindings']
    # items() instead of iteritems(): works on Python 2 and Python 3.
    for key, action in self._actions.items():
        if key in keyboard_bindings:
            action._kb = keyboard_bindings[key]
        if key in ps3_bindings:
            action._pb = ps3_bindings[key]

    # GUI Components.
    plt.ion()
    plt.rcParams['toolbar'] = 'None'
    # Remove 's' keyboard shortcut for saving.
    plt.rcParams['keymap.save'] = ''

    self._fig = plt.figure(figsize=(12, 12))
    self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                              wspace=0, hspace=0)

    # Assign GUI component locations.
    self._gs = gridspec.GridSpec(16, 8)
    self._gs_action_axis = self._gs[0:2, 0:8]
    self._gs_action_output = self._gs[2:3, 0:4]
    self._gs_status_output = self._gs[3:4, 0:4]
    self._gs_cost_plotter = self._gs[2:4, 4:8]
    self._gs_algthm_output = self._gs[4:8, 0:8]
    self._gs_traj_visualizer = self._gs[8:16, 0:4]
    self._gs_image_visualizer = self._gs[8:16, 4:8]

    # Create GUI components.
    self._action_axis = ActionAxis(
        self._fig, self._gs_action_axis, 1, 4, self._actions,
        ps3_process_rate=self._hyperparams['ps3_process_rate'],
        ps3_topic=self._hyperparams['ps3_topic'],
        ps3_button=self._hyperparams['ps3_button'],
        inverted_ps3_button=self._hyperparams['inverted_ps3_button'])
    self._action_output = OutputAxis(self._fig, self._gs_action_output,
                                     border_on=True)
    self._status_output = OutputAxis(self._fig, self._gs_status_output,
                                     border_on=False)
    self._algthm_output = OutputAxis(self._fig, self._gs_algthm_output,
                                     max_display_size=15,
                                     log_filename=self._log_filename,
                                     fontsize=10, font_family='monospace')
    self._cost_plotter = MeanPlotter(self._fig, self._gs_cost_plotter,
                                     color='blue', label='mean cost')
    self._traj_visualizer = Plotter3D(
        self._fig, self._gs_traj_visualizer,
        num_plots=self._hyperparams['conditions'])
    self._image_visualizer = ImageVisualizer(
        self._fig, self._gs_image_visualizer, cropsize=(240, 240),
        rostopic=self._hyperparams['image_topic'],
        show_overlay_buttons=True)

    # Setup GUI components.
    self._algthm_output.log_text('\n')
    self.set_output_text(self._hyperparams['info'])
    self.run_mode()

    # WARNING: Make sure the legend values in UPDATE match the below
    # linestyles/markers and colors
    for m in range(self._hyperparams['conditions']):
        self._traj_visualizer.set_title(m, 'Condition %d' % (m))
    self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                     color='green',
                                     label='Trajectory Samples')
    self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                     color='blue', label='Policy Samples')
    self._traj_visualizer.add_legend(linestyle='None', marker='x',
                                     color=(0.5, 0, 0),
                                     label='LG Controller Means')
    self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                     color='red',
                                     label='LG Controller Distributions')

    self._fig.canvas.draw()
def __init__(self, hyperparams):
    """
    Construct the GPS training GUI window and all of its components.

    Stores the hyperparameters, registers the four request actions,
    configures the matplotlib figure, creates the text boxes, plotters and
    visualizers, and finally starts a daemon thread that animates the
    'Calculating...' status text while its run event is set.

    Args:
        hyperparams: Mapping that provides 'log_filename', 'conditions'
            and 'info', and optionally 'target_filename'.
    """
    self._hyperparams = hyperparams
    self._log_filename = self._hyperparams['log_filename']
    self._target_filename = (
        self._hyperparams['target_filename']
        if 'target_filename' in self._hyperparams else None
    )

    # GPS Training Status.
    # Modes: run, wait, end, request, process.
    self.mode = config['initial_mode']
    # Requests: stop, reset, go, fail, None.
    self.request = None
    self.err_msg = None
    self._colors = {
        'run': 'cyan',
        'wait': 'orange',
        'end': 'red',
        'stop': 'red',
        'reset': 'yellow',
        'go': 'green',
        'fail': 'magenta',
    }
    self._first_update = True

    # Actions, in the order they appear on the action panel.
    panel_actions = [
        Action('stop', 'stop', self.request_stop, axis_pos=0),
        Action('reset', 'reset', self.request_reset, axis_pos=1),
        Action('go', 'go', self.request_go, axis_pos=2),
        Action('fail', 'fail', self.request_fail, axis_pos=3),
    ]

    # Setup figure: interactive mode, no toolbar, and every matplotlib
    # 'keymap.*' shortcut blanked out.
    plt.ion()
    plt.rcParams['toolbar'] = 'None'
    for rc_key in plt.rcParams:
        if not rc_key.startswith('keymap.'):
            continue
        plt.rcParams[rc_key] = ''

    self._fig = plt.figure(figsize=config['figsize'])
    self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                              wspace=0, hspace=0)

    # Assign GUI component locations on a 16x8 grid.
    self._gs = gridspec.GridSpec(16, 8)
    grid = self._gs
    self._gs_action_panel = grid[0:1, 0:8]
    self._gs_action_output = grid[1:2, 0:4]
    self._gs_status_output = grid[2:3, 0:4]
    self._gs_cost_plotter = grid[1:3, 4:8]
    self._gs_gt_cost_plotter = grid[4:6, 4:8]
    self._gs_algthm_output = grid[3:9, 0:4]
    if not config['image_on']:
        # Without the image view the trajectory plot spans the full width.
        self._gs_traj_visualizer = grid[9:16, 0:8]
    else:
        self._gs_traj_visualizer = grid[9:16, 0:4]
        self._gs_image_visualizer = grid[9:16, 4:8]

    # Create GUI components.
    fig = self._fig
    self._action_panel = ActionPanel(fig, self._gs_action_panel, 1, 4,
                                     panel_actions)
    self._action_output = Textbox(fig, self._gs_action_output,
                                  border_on=True)
    self._status_output = Textbox(fig, self._gs_status_output,
                                  border_on=False)
    self._algthm_output = Textbox(
        fig,
        self._gs_algthm_output,
        max_display_size=config['algthm_output_max_display_size'],
        log_filename=self._log_filename,
        fontsize=config['algthm_output_fontsize'],
        font_family='monospace',
    )
    self._cost_plotter = MeanPlotter(fig, self._gs_cost_plotter,
                                     color='blue', label='mean cost')
    self._gt_cost_plotter = MeanPlotter(fig, self._gs_gt_cost_plotter,
                                        color='red',
                                        label='ground truth cost')
    self._traj_visualizer = Plotter3D(
        fig, self._gs_traj_visualizer,
        num_plots=self._hyperparams['conditions'])
    if config['image_on']:
        self._image_visualizer = ImageVisualizer(
            fig,
            self._gs_image_visualizer,
            cropsize=config['image_size'],
            rostopic=config['image_topic'],
            show_overlay_buttons=True,
        )

    # Setup GUI components.
    self._algthm_output.log_text('\n')
    self.set_output_text(self._hyperparams['info'])
    enter_initial_mode = (self.run_mode if config['initial_mode'] == 'run'
                          else self.wait_mode)
    enter_initial_mode()

    # Setup 3D Trajectory Visualizer plot titles and legends.
    for condition in range(self._hyperparams['conditions']):
        self._traj_visualizer.set_title(condition,
                                        'Condition %d' % (condition))
    legend_entries = (
        ('-', 'None', 'green', 'Trajectory Samples'),
        ('-', 'None', 'blue', 'Policy Samples'),
        ('None', 'x', (0.5, 0, 0), 'LG Controller Means'),
        ('-', 'None', 'red', 'LG Controller Distributions'),
    )
    for linestyle, marker, color, label in legend_entries:
        self._traj_visualizer.add_legend(linestyle=linestyle, marker=marker,
                                         color=color, label=label)

    self._fig.canvas.draw()

    # Display calculating thread.
    def display_calculating(delay, run_event):
        # Cycle through the three status frames for as long as the event
        # is set; block on the event whenever it is cleared.
        frames = ('Calculating.', 'Calculating..', 'Calculating...')
        while True:
            run_event.wait()
            for frame in frames:
                if run_event.is_set():
                    self.set_status_text(frame)
                    time.sleep(delay)

    self._calculating_run = threading.Event()
    self._calculating_thread = threading.Thread(
        target=display_calculating, args=(1, self._calculating_run))
    self._calculating_thread.daemon = True
    self._calculating_thread.start()
class GPSTrainingGUI(object):
    """
    Matplotlib GUI for GPS training.

    Shows an action panel (stop/reset/go/fail), action and status text
    boxes, the per-iteration algorithm output, mean-cost and ground-truth
    cost plots, 3D trajectory plots per condition and, when
    config['image_on'] is set, an image visualizer.
    """

    def __init__(self, hyperparams):
        """
        Create the figure and every GUI component.

        Args:
            hyperparams: Mapping that provides 'log_filename', 'conditions'
                and 'info', and optionally 'target_filename'.
        """
        self._hyperparams = hyperparams
        self._log_filename = self._hyperparams['log_filename']
        if 'target_filename' in self._hyperparams:
            self._target_filename = self._hyperparams['target_filename']
        else:
            self._target_filename = None

        # GPS Training Status.
        self.mode = config['initial_mode']  # Modes: run, wait, end, request, process.
        self.request = None  # Requests: stop, reset, go, fail, None.
        self.err_msg = None
        self._colors = {
            'run': 'cyan',
            'wait': 'orange',
            'end': 'red',
            'stop': 'red',
            'reset': 'yellow',
            'go': 'green',
            'fail': 'magenta',
        }
        self._first_update = True

        # Actions.
        actions_arr = [
            Action('stop', 'stop', self.request_stop, axis_pos=0),
            Action('reset', 'reset', self.request_reset, axis_pos=1),
            Action('go', 'go', self.request_go, axis_pos=2),
            Action('fail', 'fail', self.request_fail, axis_pos=3),
        ]

        # Setup figure.
        plt.ion()
        plt.rcParams['toolbar'] = 'None'
        # Blank out every matplotlib 'keymap.*' shortcut.
        for key in plt.rcParams:
            if key.startswith('keymap.'):
                plt.rcParams[key] = ''

        self._fig = plt.figure(figsize=config['figsize'])
        self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99,
                                  top=0.99, wspace=0, hspace=0)

        # Assign GUI component locations.
        self._gs = gridspec.GridSpec(16, 8)
        self._gs_action_panel = self._gs[0:1, 0:8]
        self._gs_action_output = self._gs[1:2, 0:4]
        self._gs_status_output = self._gs[2:3, 0:4]
        self._gs_cost_plotter = self._gs[1:3, 4:8]
        self._gs_gt_cost_plotter = self._gs[4:6, 4:8]
        self._gs_algthm_output = self._gs[3:9, 0:4]
        if config['image_on']:
            self._gs_traj_visualizer = self._gs[9:16, 0:4]
            self._gs_image_visualizer = self._gs[9:16, 4:8]
        else:
            self._gs_traj_visualizer = self._gs[9:16, 0:8]

        # Create GUI components.
        self._action_panel = ActionPanel(self._fig, self._gs_action_panel,
                                         1, 4, actions_arr)
        self._action_output = Textbox(self._fig, self._gs_action_output,
                                      border_on=True)
        self._status_output = Textbox(self._fig, self._gs_status_output,
                                      border_on=False)
        self._algthm_output = Textbox(
            self._fig, self._gs_algthm_output,
            max_display_size=config['algthm_output_max_display_size'],
            log_filename=self._log_filename,
            fontsize=config['algthm_output_fontsize'],
            font_family='monospace')
        self._cost_plotter = MeanPlotter(self._fig, self._gs_cost_plotter,
                                         color='blue', label='mean cost')
        self._gt_cost_plotter = MeanPlotter(self._fig,
                                            self._gs_gt_cost_plotter,
                                            color='red',
                                            label='ground truth cost')
        self._traj_visualizer = Plotter3D(
            self._fig, self._gs_traj_visualizer,
            num_plots=self._hyperparams['conditions'])
        if config['image_on']:
            self._image_visualizer = ImageVisualizer(
                self._fig, self._gs_image_visualizer,
                cropsize=config['image_size'],
                rostopic=config['image_topic'],
                show_overlay_buttons=True)

        # Setup GUI components.
        self._algthm_output.log_text('\n')
        self.set_output_text(self._hyperparams['info'])
        if config['initial_mode'] == 'run':
            self.run_mode()
        else:
            self.wait_mode()

        # Setup 3D Trajectory Visualizer plot titles and legends.
        for m in range(self._hyperparams['conditions']):
            self._traj_visualizer.set_title(m, 'Condition %d' % (m))
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='green',
                                         label='Trajectory Samples')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='blue',
                                         label='Policy Samples')
        self._traj_visualizer.add_legend(linestyle='None', marker='x',
                                         color=(0.5, 0, 0),
                                         label='LG Controller Means')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='red',
                                         label='LG Controller Distributions')

        self._fig.canvas.draw()

        # Display calculating thread.
        def display_calculating(delay, run_event):
            # Animate the status text while the event is set; block on the
            # event whenever it is cleared.
            while True:
                if not run_event.is_set():
                    run_event.wait()
                if run_event.is_set():
                    self.set_status_text('Calculating.')
                    time.sleep(delay)
                if run_event.is_set():
                    self.set_status_text('Calculating..')
                    time.sleep(delay)
                if run_event.is_set():
                    self.set_status_text('Calculating...')
                    time.sleep(delay)

        self._calculating_run = threading.Event()
        self._calculating_thread = threading.Thread(
            target=display_calculating, args=(1, self._calculating_run))
        self._calculating_thread.daemon = True
        self._calculating_thread.start()

    # GPS Training functions
    def request_stop(self, event=None):
        """Action callback: request 'stop'."""
        self.request_mode('stop')

    def request_reset(self, event=None):
        """Action callback: request 'reset'."""
        self.request_mode('reset')

    def request_go(self, event=None):
        """Action callback: request 'go'."""
        self.request_mode('go')

    def request_fail(self, event=None):
        """Action callback: request 'fail'."""
        self.request_mode('fail')

    def request_mode(self, request):
        """
        Sets the request mode (stop, reset, go, fail). The request is read
        by gps_main before sampling, and the appropriate action is taken.
        """
        self.mode = 'request'
        self.request = request
        self.set_action_text(self.request + ' requested')
        self.set_action_bgcolor(self._colors[self.request], alpha=0.2)

    def process_mode(self):
        """
        Completes the current request, after it is first read by gps_main.
        Displays visual confirmation that the request was processed,
        displays any error messages, and then switches into mode 'run' or
        'wait'.
        """
        self.mode = 'process'
        self.set_action_text(self.request + ' processed')
        self.set_action_bgcolor(self._colors[self.request], alpha=1.0)
        if self.err_msg:
            self.set_action_text(self.request + ' processed' +
                                 '\nERROR: ' + self.err_msg)
            self.err_msg = None
            # Hold the error on screen a little longer than a plain
            # confirmation.
            time.sleep(1.0)
        else:
            time.sleep(0.5)
        if self.request in ('stop', 'reset', 'fail'):
            self.wait_mode()
        elif self.request == 'go':
            self.run_mode()
        self.request = None

    def wait_mode(self):
        """Switch to mode 'wait' and show it in the action text box."""
        self.mode = 'wait'
        self.set_action_text('waiting')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def run_mode(self):
        """Switch to mode 'run' and show it in the action text box."""
        self.mode = 'run'
        self.set_action_text('running')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def end_mode(self):
        """Switch to mode 'end' and show it in the action text box."""
        self.mode = 'end'
        self.set_action_text('ended')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def estop(self, event=None):
        """Placeholder: only reports that estop is not implemented."""
        self.set_action_text('estop: NOT IMPLEMENTED')

    # GUI functions
    def set_action_text(self, text):
        """Set the action text box content."""
        self._action_output.set_text(text)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels
        self._gt_cost_plotter.draw_ticklabels()

    def set_action_bgcolor(self, color, alpha=1.0):
        """Set the action text box background color and alpha."""
        self._action_output.set_bgcolor(color, alpha)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels
        self._gt_cost_plotter.draw_ticklabels()

    def set_status_text(self, text):
        """Set the status text box content."""
        self._status_output.set_text(text)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels
        self._gt_cost_plotter.draw_ticklabels()

    def set_output_text(self, text):
        """Replace the algorithm output text."""
        self._algthm_output.set_text(text)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels
        self._gt_cost_plotter.draw_ticklabels()

    def append_output_text(self, text):
        """Append text to the algorithm output."""
        self._algthm_output.append_text(text)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels
        self._gt_cost_plotter.draw_ticklabels()

    def start_display_calculating(self):
        """Set the event that lets the 'Calculating' animation run."""
        self._calculating_run.set()

    def stop_display_calculating(self):
        """Clear the event, pausing the 'Calculating' animation."""
        self._calculating_run.clear()

    def set_image_overlays(self, condition):
        """
        Sets up the image visualizer with what images to overlay if
        "overlay_initial_image" or "overlay_target_image" is pressed.
        """
        # The image visualizer only exists when config['image_on'] is set.
        if not config['image_on'] or not self._target_filename:
            return
        initial_image = load_data_from_npz(
            self._target_filename, config['image_overlay_actuator'],
            str(condition), 'initial', 'image', default=None)
        target_image = load_data_from_npz(
            self._target_filename, config['image_overlay_actuator'],
            str(condition), 'target', 'image', default=None)
        self._image_visualizer.set_initial_image(
            initial_image, alpha=config['image_overlay_alpha'])
        self._image_visualizer.set_target_image(
            target_image, alpha=config['image_overlay_alpha'])

    # Iteration update functions
    def update(self, itr, algorithm, agent, traj_sample_lists,
               pol_sample_lists):
        """
        After each iteration, update the iteration data output, the cost
        plot, and the 3D trajectory visualizations (if end effector points
        exist).
        """
        if self._first_update:
            policy_titles = pol_sample_lists is not None
            self._output_column_titles(algorithm, policy_titles)
            self._first_update = False

        costs = [np.mean(np.sum(algorithm.prev[m].cs, axis=1))
                 for m in range(algorithm.M)]
        if algorithm._hyperparams['ioc']:
            # With IOC the table reports the ground truth costs, which are
            # also plotted separately.
            gt_costs = [np.mean(np.sum(algorithm.prev[m].cgt, axis=1))
                        for m in range(algorithm.M)]
            self._update_iteration_data(itr, algorithm, gt_costs,
                                        pol_sample_lists)
            self._gt_cost_plotter.update(gt_costs, t=itr)
        else:
            self._update_iteration_data(itr, algorithm, costs,
                                        pol_sample_lists)
        self._cost_plotter.update(costs, t=itr)
        if END_EFFECTOR_POINTS in agent.x_data_types:
            self._update_trajectory_visualizations(
                algorithm, agent, traj_sample_lists, pol_sample_lists)

        self._fig.canvas.draw()
        self._fig.canvas.flush_events()  # Fixes bug in Qt4Agg backend

    def _output_column_titles(self, algorithm, policy_titles=False):
        """
        Setup iteration data column titles: iteration, average cost, and
        for each condition the mean cost over samples, step size, linear
        Gaussian controller entropies, and initial/final KL divergences
        for BADMM.
        """
        self.set_output_text(self._hyperparams['experiment_name'])
        if policy_titles:
            condition_titles = '%3s | %8s %12s' % ('', '', '')
            itr_data_fields = '%3s | %8s %12s' % ('itr', 'avg_cost',
                                                  'avg_pol_cost')
        else:
            condition_titles = '%3s | %8s' % ('', '')
            itr_data_fields = '%3s | %8s' % ('itr', 'avg_cost')
        for m in range(algorithm.M):
            condition_titles += ' | %8s %9s %-7d' % ('', 'condition', m)
            itr_data_fields += ' | %8s %8s %8s' % (' cost ', ' step ',
                                                   'entropy ')
            if algorithm.prev[0].pol_info is not None:
                condition_titles += ' %8s %8s' % ('', '')
                itr_data_fields += ' %8s %8s' % ('kl_div_i', 'kl_div_f')
            if (algorithm._hyperparams['ioc'] and
                    not algorithm._hyperparams['learning_from_prior']):
                condition_titles += ' %8s' % ('')
                itr_data_fields += ' %8s' % ('kl_div')
            if algorithm._hyperparams['learning_from_prior']:
                condition_titles += ' %8s' % ('')
                itr_data_fields += ' %8s' % ('mean_dist')
            if policy_titles:
                condition_titles += ' %8s %8s %8s' % ('', '', '')
                itr_data_fields += ' %8s %8s %8s' % ('pol_cost', 'kl_div_i',
                                                     'kl_div_f')
        self.append_output_text(condition_titles)
        self.append_output_text(itr_data_fields)

    def _update_iteration_data(self, itr, algorithm, costs,
                               pol_sample_lists):
        """
        Update iteration data information: iteration, average cost, and
        for each condition the mean cost over samples, step size, linear
        Gaussian controller entropies, and initial/final KL divergences
        for BADMM.
        """
        avg_cost = np.mean(costs)
        if pol_sample_lists is not None:
            pol_costs = [
                np.mean([np.sum(algorithm.cost[m].eval(s)[0])
                         for s in pol_sample_lists[m]])
                for m in range(algorithm.M)
            ]
            itr_data = '%3d | %8.2f %12.2f' % (itr, avg_cost,
                                               np.mean(pol_costs))
        else:
            itr_data = '%3d | %8.2f' % (itr, avg_cost)
        for m in range(algorithm.M):
            cost = costs[m]
            step = algorithm.prev[m].step_mult * algorithm.base_kl_step
            entropy = 2 * np.sum(np.log(np.diagonal(
                algorithm.prev[m].traj_distr.chol_pol_covar,
                axis1=1, axis2=2)))
            itr_data += ' | %8.2f %8.2f %8.2f' % (cost, step, entropy)
            if algorithm.prev[0].pol_info is not None:
                kl_div_i = algorithm.prev[m].pol_info.prev_kl[0]
                kl_div_f = algorithm.prev[m].pol_info.prev_kl[-1]
                itr_data += ' %8.2f %8.2f' % (kl_div_i, kl_div_f)
            if (algorithm._hyperparams['ioc'] and
                    not algorithm._hyperparams['learning_from_prior']):
                itr_data += ' %8.2f' % (algorithm.kl_div[itr][m])
            if algorithm._hyperparams['learning_from_prior']:
                itr_data += ' %8.2f' % (algorithm.dists_to_target[itr][m])
            if pol_sample_lists is not None:
                kl_div_i = algorithm.cur[m].pol_info.init_kl.mean()
                kl_div_f = algorithm.cur[m].pol_info.prev_kl.mean()
                itr_data += ' %8.2f %8.2f %8.2f' % (pol_costs[m], kl_div_i,
                                                    kl_div_f)
        self.append_output_text(itr_data)

    def _update_trajectory_visualizations(self, algorithm, agent,
                                          traj_sample_lists,
                                          pol_sample_lists):
        """
        Update 3D trajectory visualizations information: the trajectory
        samples, policy samples, and linear Gaussian controller means and
        covariances.
        """
        xlim, ylim, zlim = self._calculate_3d_axis_limits(
            traj_sample_lists, pol_sample_lists)
        for m in range(algorithm.M):
            self._traj_visualizer.clear(m)
            self._traj_visualizer.set_lim(i=m, xlim=xlim, ylim=ylim,
                                          zlim=zlim)
            self._update_samples_plots(traj_sample_lists, m, 'green',
                                       'Trajectory Samples')
            self._update_linear_gaussian_controller_plots(algorithm, agent,
                                                          m)
            if pol_sample_lists:
                self._update_samples_plots(pol_sample_lists, m, 'blue',
                                           'Policy Samples')
        self._traj_visualizer.draw()  # this must be called explicitly

    @staticmethod
    def _xyz_slices(num_columns):
        """
        Return one column slice per complete xyz triple in num_columns.

        Floor division keeps this valid on Python 3, where '/' yields a
        float that range() rejects; trailing columns that do not form a
        full triple are ignored.
        """
        return [slice(3 * i, 3 * i + 3) for i in range(num_columns // 3)]

    @staticmethod
    def _combine_sample_lists(traj_sample_lists, pol_sample_lists):
        """
        Return a new list with the trajectory sample lists followed by the
        policy sample lists (if any), leaving both inputs unmodified.
        """
        combined = list(traj_sample_lists)
        if pol_sample_lists:
            combined.extend(pol_sample_lists)
        return combined

    def _calculate_3d_axis_limits(self, traj_sample_lists,
                                  pol_sample_lists):
        """
        Calculate the 3D axis limits shared between trajectory plots,
        based on the minimum and maximum xyz values across all samples.

        Returns:
            The (xlim, ylim, zlim) values produced by
            buffered_axis_limits for each axis.
        """
        # Start from an empty (0, 3) array so the stacked result has the
        # same shape and dtype handling as incremental concatenation.
        xyz_chunks = [np.empty((0, 3))]
        # Policy samples must contribute to the limits too, and the
        # caller's lists must not be extended as a side effect.
        sample_lists = self._combine_sample_lists(traj_sample_lists,
                                                  pol_sample_lists)
        for sample_list in sample_lists:
            for sample in sample_list.get_samples():
                ee_pt = sample.get(END_EFFECTOR_POINTS)
                for cols in self._xyz_slices(ee_pt.shape[1]):
                    xyz_chunks.append(ee_pt[:, cols])
        all_eept = np.concatenate(xyz_chunks, axis=0)
        min_xyz = np.amin(all_eept, axis=0)
        max_xyz = np.amax(all_eept, axis=0)
        xlim = buffered_axis_limits(min_xyz[0], max_xyz[0],
                                    buffer_factor=1.25)
        ylim = buffered_axis_limits(min_xyz[1], max_xyz[1],
                                    buffer_factor=1.25)
        zlim = buffered_axis_limits(min_xyz[2], max_xyz[2],
                                    buffer_factor=1.25)
        return xlim, ylim, zlim

    def _update_linear_gaussian_controller_plots(self, algorithm, agent, m):
        """
        Update the linear Gaussian controller plots with iteration data,
        for the mean and covariances of the end effector points.
        """
        # Calculate mean and covariance for end effector points.
        eept_idx = agent.get_idx_x(END_EFFECTOR_POINTS)
        start, end = eept_idx[0], eept_idx[-1]
        mu, sigma = algorithm.traj_opt.forward(algorithm.prev[m].traj_distr,
                                               algorithm.prev[m].traj_info)
        mu_eept = mu[:, start:end + 1]
        sigma_eept = sigma[:, start:end + 1, start:end + 1]
        point_slices = self._xyz_slices(mu_eept.shape[1])

        # Linear Gaussian Controller Distributions (Red)
        for cols in point_slices:
            self._traj_visualizer.plot_3d_gaussian(
                i=m, mu=mu_eept[:, cols], sigma=sigma_eept[:, cols, cols],
                edges=100, linestyle='-', linewidth=1.0, color='red',
                alpha=0.15, label='LG Controller Distributions')

        # Linear Gaussian Controller Means (Dark Red)
        for cols in point_slices:
            self._traj_visualizer.plot_3d_points(
                i=m, points=mu_eept[:, cols], linestyle='None', marker='x',
                markersize=5.0, markeredgewidth=1.0, color=(0.5, 0, 0),
                alpha=1.0, label='LG Controller Means')

    def _update_samples_plots(self, sample_lists, m, color, label):
        """
        Update the samples plots with iteration data, for the trajectory
        samples and the policy samples.
        """
        samples = sample_lists[m].get_samples()
        for sample in samples:
            ee_pt = sample.get(END_EFFECTOR_POINTS)
            for cols in self._xyz_slices(ee_pt.shape[1]):
                self._traj_visualizer.plot_3d_points(m, ee_pt[:, cols],
                                                     color=color,
                                                     label=label)

    def save_figure(self, filename):
        """Save the GUI figure to filename via Figure.savefig."""
        self._fig.savefig(filename)
def __init__(self, hyperparams, agent):
    """
    Construct the transfer learning GUI window and all of its components.

    Stores the agent and hyperparameters, registers the eleven panel
    actions, configures the matplotlib figure on an 18x8 grid, creates the
    text boxes, cost plotter and trajectory visualizer, and starts a
    daemon thread that animates the 'Calculating...' status text while its
    run event is set.

    Args:
        hyperparams: Mapping that provides 'log_filename', 'conditions'
            and 'info', and optionally 'target_filename'.
        agent: Agent object stored on the instance as self._agent.
    """
    self._agent = agent
    self._hyperparams = hyperparams
    self._log_filename = self._hyperparams['log_filename']
    self._target_filename = (
        self._hyperparams['target_filename']
        if 'target_filename' in self._hyperparams else None
    )

    # GPS Training Status.
    # Modes: run, wait, end, request, process.
    self.mode = config['initial_mode']
    # Requests: stop, reset, go, fail, None.
    self.request = None
    self.err_msg = None
    self._colors = {
        'run': 'cyan',
        'wait': 'orange',
        'end': 'red',
    }

    self._actuator_types = config['actuator_types']
    self._actuator_names = config['actuator_names']
    self._first_update = True
    # The first configured actuator is the one selected initially.
    self._actuator_number = 0
    self._actuator_type = self._actuator_types[self._actuator_number]

    # Placeholders until an initial/goal state is recorded.
    self._initial_position = ('unknown', 'unknown', 'unknown')
    self._target_position = ('unknown', 'unknown', 'unknown')

    # Actions, in the order they appear on the action panel.
    panel_actions = [
        Action('stop', 'stop', self.request_stop, axis_pos=0),
        Action('reset', 'reset', self.request_reset, axis_pos=1),
        Action('GCM go', 'go', self.request_go, axis_pos=2),
        Action('transfer learning', 'transfer_learning', self.request_tl,
               axis_pos=3),
        Action('set initial state', 'initstate', self.request_init_state,
               axis_pos=4),
        Action('set goal state', 'goalstate', self.request_goal_state,
               axis_pos=5),
        Action('test transfer learning', 'test_tl', self.request_test_tl,
               axis_pos=6),
        Action('generalize', 'generalize', self.request_generalize,
               axis_pos=7),
        Action('mti', 'move_to_initial', self.move_to_initial, axis_pos=8),
        Action('mtt', 'move_to_target', self.move_to_target, axis_pos=9),
        Action('rc', 'relax_controller', self.relax_controller,
               axis_pos=10),
    ]

    # Setup figure: interactive mode, no toolbar, and every matplotlib
    # 'keymap.*' shortcut blanked out.
    plt.ion()
    plt.rcParams['toolbar'] = 'None'
    for rc_key in plt.rcParams:
        if not rc_key.startswith('keymap.'):
            continue
        plt.rcParams[rc_key] = ''

    self._fig = plt.figure(figsize=config['figsize'])
    self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                              wspace=0, hspace=0)

    # Assign GUI component locations on an 18x8 grid.
    self._gs = gridspec.GridSpec(18, 8)
    grid = self._gs
    self._gs_action_panel = grid[0:4, 0:8]
    self._gs_action_output = grid[4:5, 0:4]
    self._gs_status_output = grid[5:6, 0:4]
    self._gs_cost_plotter = grid[4:10, 4:8]
    self._gs_algthm_output = grid[6:10, 0:4]
    self._gs_traj_visualizer = grid[10:18, 0:8]

    # Create GUI components.
    fig = self._fig
    self._action_panel = ActionPanel(fig, self._gs_action_panel, 3, 4,
                                     panel_actions)
    self._action_output = Textbox(fig, self._gs_action_output,
                                  border_on=True)
    self._status_output = Textbox(fig, self._gs_status_output,
                                  border_on=False)
    self._algthm_output = Textbox(
        fig,
        self._gs_algthm_output,
        max_display_size=config['algthm_output_max_display_size'],
        log_filename=self._log_filename,
        fontsize=config['algthm_output_fontsize'],
        font_family='monospace',
    )
    self._cost_plotter = MeanPlotter(fig, self._gs_cost_plotter,
                                     color='blue', label='mean cost')
    self._traj_visualizer = Plotter3D(
        fig, self._gs_traj_visualizer,
        num_plots=self._hyperparams['conditions'])

    # Setup GUI components.
    self._algthm_output.log_text('\n')
    self.set_output_text(self._hyperparams['info'])
    enter_initial_mode = (self.run_mode if config['initial_mode'] == 'run'
                          else self.wait_mode)
    enter_initial_mode()

    # Setup 3D Trajectory Visualizer plot titles and legends.
    for condition in range(self._hyperparams['conditions']):
        self._traj_visualizer.set_title(condition,
                                        'Condition %d' % (condition))
    legend_entries = (
        ('-', 'None', 'green', 'Trajectory Samples'),
        ('-', 'None', 'blue', 'Policy Samples'),
        ('None', 'x', (0.5, 0, 0), 'LG Controller Means'),
        ('-', 'None', 'red', 'LG Controller Distributions'),
    )
    for linestyle, marker, color, label in legend_entries:
        self._traj_visualizer.add_legend(linestyle=linestyle, marker=marker,
                                         color=color, label=label)

    self._fig.canvas.draw()

    # Display calculating thread.
    def display_calculating(delay, run_event):
        # Cycle through the three status frames for as long as the event
        # is set; block on the event whenever it is cleared.
        frames = ('Calculating.', 'Calculating..', 'Calculating...')
        while True:
            run_event.wait()
            for frame in frames:
                if run_event.is_set():
                    self.set_status_text(frame)
                    time.sleep(delay)

    self._calculating_run = threading.Event()
    self._calculating_thread = threading.Thread(
        target=display_calculating, args=(1, self._calculating_run))
    self._calculating_thread.daemon = True
    self._calculating_thread.start()
class TransferLearningGUI(object): def __init__(self, hyperparams, agent): self._agent = agent self._hyperparams = hyperparams self._log_filename = self._hyperparams['log_filename'] if 'target_filename' in self._hyperparams: self._target_filename = self._hyperparams['target_filename'] else: self._target_filename = None # GPS Training Status. self.mode = config['initial_mode'] # Modes: run, wait, end, request, process. self.request = None # Requests: stop, reset, go, fail, None. self.err_msg = None self._colors = { 'run': 'cyan', 'wait': 'orange', 'end': 'red', } self._actuator_types = config['actuator_types'] self._actuator_names = config['actuator_names'] self._first_update = True self._actuator_number = 0 self._actuator_type = self._actuator_types[self._actuator_number] self._initial_position = ('unknown', 'unknown', 'unknown') self._target_position = ('unknown', 'unknown', 'unknown') # Actions. actions_arr = [ Action('stop', 'stop', self.request_stop, axis_pos=0), Action('reset', 'reset', self.request_reset, axis_pos=1), Action('GCM go', 'go', self.request_go, axis_pos=2), Action('transfer learning', 'transfer_learning', self.request_tl, axis_pos=3), Action('set initial state', 'initstate', self.request_init_state, axis_pos=4), Action('set goal state', 'goalstate', self.request_goal_state, axis_pos=5), Action('test transfer learning', 'test_tl', self.request_test_tl, axis_pos=6), Action('generalize', 'generalize', self.request_generalize, axis_pos=7), Action('mti', 'move_to_initial', self.move_to_initial, axis_pos=8), Action('mtt', 'move_to_target', self.move_to_target, axis_pos=9), Action('rc', 'relax_controller', self.relax_controller, axis_pos=10), ] # Setup figure. plt.ion() plt.rcParams['toolbar'] = 'None' for key in plt.rcParams: if key.startswith('keymap.'): plt.rcParams[key] = '' self._fig = plt.figure(figsize=config['figsize']) self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99, wspace=0, hspace=0) # Assign GUI component locations. 
self._gs = gridspec.GridSpec(18, 8) self._gs_action_panel = self._gs[0:4, 0:8] self._gs_action_output = self._gs[4:5, 0:4] self._gs_status_output = self._gs[5:6, 0:4] self._gs_cost_plotter = self._gs[4:10, 4:8] self._gs_algthm_output = self._gs[6:10, 0:4] self._gs_traj_visualizer = self._gs[10:18, 0:8] # Create GUI components. self._action_panel = ActionPanel(self._fig, self._gs_action_panel, 3, 4, actions_arr) self._action_output = Textbox(self._fig, self._gs_action_output, border_on=True) self._status_output = Textbox(self._fig, self._gs_status_output, border_on=False) self._algthm_output = Textbox(self._fig, self._gs_algthm_output, max_display_size=config['algthm_output_max_display_size'], log_filename=self._log_filename, fontsize=config['algthm_output_fontsize'], font_family='monospace') self._cost_plotter = MeanPlotter(self._fig, self._gs_cost_plotter, color='blue', label='mean cost') self._traj_visualizer = Plotter3D(self._fig, self._gs_traj_visualizer, num_plots=self._hyperparams['conditions']) # Setup GUI components. 
self._algthm_output.log_text('\n') self.set_output_text(self._hyperparams['info']) if config['initial_mode'] == 'run': self.run_mode() else: self.wait_mode() # Setup 3D Trajectory Visualizer plot titles and legends for m in range(self._hyperparams['conditions']): self._traj_visualizer.set_title(m, 'Condition %d' % (m)) self._traj_visualizer.add_legend(linestyle='-', marker='None', color='green', label='Trajectory Samples') self._traj_visualizer.add_legend(linestyle='-', marker='None', color='blue', label='Policy Samples') self._traj_visualizer.add_legend(linestyle='None', marker='x', color=(0.5, 0, 0), label='LG Controller Means') self._traj_visualizer.add_legend(linestyle='-', marker='None', color='red', label='LG Controller Distributions') self._fig.canvas.draw() # Display calculating thread def display_calculating(delay, run_event): while True: if not run_event.is_set(): run_event.wait() if run_event.is_set(): self.set_status_text('Calculating.') time.sleep(delay) if run_event.is_set(): self.set_status_text('Calculating..') time.sleep(delay) if run_event.is_set(): self.set_status_text('Calculating...') time.sleep(delay) self._calculating_run = threading.Event() self._calculating_thread = threading.Thread(target=display_calculating, args=(1, self._calculating_run)) self._calculating_thread.daemon = True self._calculating_thread.start() # GPS Training functions def request_stop(self, event=None): self.request_mode('stop') self._agent.sample( self._agent.policy, self._agent.condition, verbose=None, save=False, noisy=False, use_TfController=True, timeout=0, reset=False) def request_reset(self, event=None): self.request_mode('reset') self._agent.reset(self._agent.condition) def request_go(self, event=None): self.request_mode('go') def request_tl(self, event=None): self.request_mode('transfer_learning') def request_test_tl(self, event=None): self.request_mode('test_tl') def request_generalize(self, event=None): self.request_mode('generalize') def move_to_initial(self, 
event=None): ja = self._initial_position[0] self.set_action_status_message('move_to_initial', 'requested') self._agent.reset_arm(self._actuator_type, JOINT_SPACE, ja.T) self.set_action_status_message('move_to_initial', 'completed', message='initial position: %s' % str(ja)) def move_to_target(self, event=None): ja = self._target_position[0] self.set_action_status_message('move_to_target', 'requested') self._agent.reset_arm(self._actuator_type, JOINT_SPACE, ja.T) self.set_action_status_message('move_to_target', 'completed', message='target position: %s' % str(ja)) def relax_controller(self, event=None): self.set_action_status_message('relax_controller', 'requested') self._agent.relax_arm(self._actuator_type) self.set_action_status_message('relax_controller', 'completed', message='actuator name: %s' % self._actuator_name) def request_goal_state(self, event=None): self.request_mode('goalstate') sample = self._agent.get_data(arm=self._actuator_type) ja = sample.get(JOINT_ANGLES) ee_pos = sample.get(END_EFFECTOR_POSITIONS) ee_rot = sample.get(END_EFFECTOR_ROTATIONS) ee_tgt = np.ndarray.flatten( get_ee_points(self._agent._hyperparams['ee_points'], ee_pos, ee_rot).T ) self._agent._hyperparams['ee_points_tgt'] = [ee_tgt] self._target_position = (ja, ee_pos, ee_rot) self._agent._target_ja = [ja] def request_init_state(self, event=None): self.request_mode('initstate') sample = self._agent.get_data(arm=self._actuator_type) ja = sample.get(JOINT_ANGLES) ee_pos = sample.get(END_EFFECTOR_POSITIONS) ee_rot = sample.get(END_EFFECTOR_ROTATIONS) self._initial_position = (ja, ee_pos, ee_rot) reset_condition = { TRIAL_ARM: { 'mode': JOINT_SPACE, 'data': ja[0], }, } res_cons = [] res_cons.append(reset_condition) self._agent._initial_ja = [ja] def request_mode(self, request): """ Sets the request mode (stop, reset, go, fail). The request is read by gps_main before sampling, and the appropriate action is taken. 
""" self.mode = 'request' self.request = request self.set_action_text(self.request + ' requested') #self.set_action_bgcolor(self._colors[self.request], alpha=0.2) def process_mode(self): """ Completes the current request, after it is first read by gps_main. Displays visual confirmation that the request was processed, displays any error messages, and then switches into mode 'run' or 'wait'. """ self.mode = 'process' self.set_action_text(self.request + ' processed') self.set_action_bgcolor(self._colors[self.request], alpha=1.0) if self.err_msg: self.set_action_text(self.request + ' processed' + '\nERROR: ' + self.err_msg) self.err_msg = None time.sleep(1.0) else: time.sleep(0.5) if self.request in ('stop', 'reset', 'fail'): self.wait_mode() elif self.request == 'go': self.run_mode() self.request = None def wait_mode(self): self.mode = 'wait' self.set_action_text('waiting') self.set_action_bgcolor(self._colors[self.mode], alpha=1.0) def run_mode(self): self.mode = 'run' self.set_action_text('running') self.set_action_bgcolor(self._colors[self.mode], alpha=1.0) def end_mode(self): self.mode = 'end' self.set_action_text('ended') self.set_action_bgcolor(self._colors[self.mode], alpha=1.0) def estop(self, event=None): self.set_action_text('estop: NOT IMPLEMENTED') # GUI functions def set_action_text(self, text): self._action_output.set_text(text) self._cost_plotter.draw_ticklabels() # redraw overflow ticklabels def set_action_bgcolor(self, color, alpha=1.0): self._action_output.set_bgcolor(color, alpha) self._cost_plotter.draw_ticklabels() # redraw overflow ticklabels def set_status_text(self, text): self._status_output.set_text(text) self._cost_plotter.draw_ticklabels() # redraw overflow ticklabels def set_output_text(self, text): self._algthm_output.set_text(text) self._cost_plotter.draw_ticklabels() # redraw overflow ticklabels def append_output_text(self, text): self._algthm_output.append_text(text) self._cost_plotter.draw_ticklabels() # redraw overflow ticklabels def 
start_display_calculating(self): self._calculating_run.set() def stop_display_calculating(self): self._calculating_run.clear() def set_action_status_message(self, action, status, message=None): text = action + ': ' + status if message: text += '\n\n' + message self.set_action_text(text) if status == 'requested': self.set_action_bgcolor('yellow') elif status == 'completed': self.set_action_bgcolor('green') elif status == 'failed': self.set_action_bgcolor('red') def set_image_overlays(self, condition): """ Sets up the image visualizer with what images to overlay if "overlay_initial_image" or "overlay_target_image" is pressed. """ if not config['image_on'] or not self._target_filename: return initial_image = load_data_from_npz(self._target_filename, config['image_overlay_actuator'], str(condition), 'initial', 'image', default=None) target_image = load_data_from_npz(self._target_filename, config['image_overlay_actuator'], str(condition), 'target', 'image', default=None) self._image_visualizer.set_initial_image(initial_image, alpha=config['image_overlay_alpha']) self._image_visualizer.set_target_image(target_image, alpha=config['image_overlay_alpha']) # Iteration update functions def update(self, itr, algorithm, agent, traj_sample_lists, pol_sample_lists): """ After each iteration, update the iteration data output, the cost plot, and the 3D trajectory visualizations (if end effector points exist). 
""" if self._first_update: self._output_column_titles(algorithm) self._first_update = False costs = [np.mean(np.sum(algorithm.prev[m].cs, axis=1)) for m in range(algorithm.M)] self._update_iteration_data(itr, algorithm, costs, pol_sample_lists) self._cost_plotter.update(costs, t=itr) if END_EFFECTOR_POINTS in agent.x_data_types: self._update_trajectory_visualizations(algorithm, agent, traj_sample_lists, pol_sample_lists) self._fig.canvas.draw() self._fig.canvas.flush_events() # Fixes bug in Qt4Agg backend def _output_column_titles(self, algorithm, policy_titles=False): """ Setup iteration data column titles: iteration, average cost, and for each condition the mean cost over samples, step size, linear Guassian controller entropies, and initial/final KL divergences for BADMM. """ self.set_output_text(self._hyperparams['experiment_name']) condition_titles = '%3s | %8s' % ('', '') itr_data_fields = '%3s | %8s' % ('itr', 'avg_cost') for m in range(algorithm.M): condition_titles += ' | %8s %9s %-7d' % ('', 'condition', m) itr_data_fields += ' | %8s %8s %8s' % (' cost ', ' step ', 'entropy ') self.append_output_text(condition_titles) self.append_output_text(itr_data_fields) def _update_iteration_data(self, itr, algorithm, costs, pol_sample_lists): """ Update iteration data information: iteration, average cost, and for each condition the mean cost over samples, step size, linear Guassian controller entropies, and initial/final KL divergences for BADMM. 
""" avg_cost = np.mean(costs) if pol_sample_lists is not None: test_idx = algorithm._hyperparams['test_conditions'] print("tet_idx: ", test_idx) # pol_sample_lists is a list of singletons samples = [sl[0] for sl in pol_sample_lists] print("samples len: ", len(samples)) print("algorithm.costs: ", len(algorithm.cost)) print("costs: ", algorithm.cost[-1].eval(samples[0])) #print("algorithm.cost.eval: ", len(algorithm.cost[-1].eval)) pol_costs = [np.sum(algorithm.cost[idx].eval(s)[0]) for s, idx in zip(samples, test_idx)] itr_data = '%3d | %8.2f %12.2f' % (itr, avg_cost, np.mean(pol_costs)) else: itr_data = '%3d | %8.2f' % (itr, avg_cost) for m in range(algorithm.M): cost = costs[m] step = algorithm.prev[m].step_mult * algorithm.base_kl_step entropy = 2*np.sum(np.log(np.diagonal(algorithm.prev[m].traj_distr.chol_pol_covar, axis1=1, axis2=2))) itr_data += ' | %8.2f %8.2f %8.2f' % (cost, step, entropy) self.append_output_text(itr_data) def _update_trajectory_visualizations(self, algorithm, agent, traj_sample_lists, pol_sample_lists): """ Update 3D trajectory visualizations information: the trajectory samples, policy samples, and linear Gaussian controller means and covariances. """ xlim, ylim, zlim = self._calculate_3d_axis_limits(traj_sample_lists, pol_sample_lists) for m in range(algorithm.M): self._traj_visualizer.clear(m) self._traj_visualizer.set_lim(i=m, xlim=xlim, ylim=ylim, zlim=zlim) self._update_linear_gaussian_controller_plots(algorithm, agent, m) self._update_samples_plots(traj_sample_lists, m, 'green', 'Trajectory Samples') if pol_sample_lists: self._update_samples_plots(pol_sample_lists, m, 'blue', 'Policy Samples') self._traj_visualizer.draw() # this must be called explicitly def _calculate_3d_axis_limits(self, traj_sample_lists, pol_sample_lists): """ Calculate the 3D axis limits shared between trajectory plots, based on the minimum and maximum xyz values across all samples. 
""" all_eept = np.empty((0, 3)) sample_lists = traj_sample_lists if pol_sample_lists: sample_lists += traj_sample_lists for sample_list in sample_lists: for sample in sample_list.get_samples(): ee_pt = sample.get(END_EFFECTOR_POINTS) for i in range(ee_pt.shape[1]/3): ee_pt_i = ee_pt[:, 3*i+0:3*i+3] all_eept = np.r_[all_eept, ee_pt_i] min_xyz = np.amin(all_eept, axis=0) max_xyz = np.amax(all_eept, axis=0) xlim = buffered_axis_limits(min_xyz[0], max_xyz[0], buffer_factor=1.25) ylim = buffered_axis_limits(min_xyz[1], max_xyz[1], buffer_factor=1.25) zlim = buffered_axis_limits(min_xyz[2], max_xyz[2], buffer_factor=1.25) return xlim, ylim, zlim def _update_linear_gaussian_controller_plots(self, algorithm, agent, m): """ Update the linear Guassian controller plots with iteration data, for the mean and covariances of the end effector points. """ # Calculate mean and covariance for end effector points eept_idx = agent.get_idx_x(END_EFFECTOR_POINTS) start, end = eept_idx[0], eept_idx[-1] mu, sigma = algorithm.forward(algorithm.prev[m].traj_distr, algorithm.prev[m].traj_info) mu_eept, sigma_eept = mu[:, start:end+1], sigma[:, start:end+1, start:end+1] # Linear Gaussian Controller Distributions (Red) for i in range(mu_eept.shape[1]/3): mu, sigma = mu_eept[:, 3*i+0:3*i+3], sigma_eept[:, 3*i+0:3*i+3, 3*i+0:3*i+3] self._traj_visualizer.plot_3d_gaussian(i=m, mu=mu, sigma=sigma, edges=100, linestyle='-', linewidth=1.0, color='red', alpha=0.15, label='LG Controller Distributions') # Linear Gaussian Controller Means (Dark Red) for i in range(mu_eept.shape[1]/3): mu = mu_eept[:, 3*i+0:3*i+3] self._traj_visualizer.plot_3d_points(i=m, points=mu, linestyle='None', marker='x', markersize=5.0, markeredgewidth=1.0, color=(0.5, 0, 0), alpha=1.0, label='LG Controller Means') def _update_samples_plots(self, sample_lists, m, color, label): """ Update the samples plots with iteration data, for the trajectory samples and the policy samples. 
""" samples = sample_lists[m].get_samples() for sample in samples: ee_pt = sample.get(END_EFFECTOR_POINTS) for i in range(ee_pt.shape[1]/3): ee_pt_i = ee_pt[:, 3*i+0:3*i+3] self._traj_visualizer.plot_3d_points(m, ee_pt_i, color=color, label=label) def save_figure(self, filename): self._fig.savefig(filename)
# Tail of the realtime plotter demo: create the plotter in grid cell 3 and
# hand the demo function to run_demo.
# NOTE(review): fig, gs, run_demo and demo_realtime_plotter are defined
# earlier in the script, outside this excerpt.
realtime_plotter = RealtimePlotter(fig, gs[3],
        labels=['i', 'j', 'i+j', 'i-j', 'mean'],
        alphas=[0.15, 0.15, 0.15, 0.15, 1.0])
run_demo(demo_realtime_plotter)

# Mean Plotter
def demo_mean_plotter():
    # Endless demo loop: random-walk two integers and feed
    # [i, j, i + j, i - j] to the mean plotter once per second.
    i, j = 0, 0
    while True:
        i += random.randint(-10, 10)
        j += random.randint(-10, 10)
        data = [i, j, i + j, i - j]
        mean_plotter.update(data)
        time.sleep(1)
# The demo function looks up mean_plotter when it runs, so the plotter is
# created before the demo is handed to run_demo.
mean_plotter = MeanPlotter(fig, gs[4])
run_demo(demo_mean_plotter)

# Plotter 3D
def demo_plotter_3d():
    # Endless demo loop: append one random integer xyz column per second
    # (starting from the origin) and redraw the whole path each time.
    xyzs = np.zeros((3, 1))
    while True:
        plotter_3d.clear_all()
        xyz = np.random.randint(-10, 10, size=3).reshape((3,1))
        xyzs = np.append(xyzs, xyz, axis=1)
        xs, ys, zs = xyzs    # rows of the (3, N) array
        plotter_3d.plot(0, xs, ys, zs)
        plotter_3d.draw()  # this must be called explicitly
        time.sleep(1)
plotter_3d = Plotter3D(fig, gs[5], num_plots=1, rows=1, cols=1)