class GPSTrainingGUI(object):
    """
    Matplotlib window shown while GPS training runs.

    The window holds an action panel (stop / reset / go / fail), an action
    status textbox, a status textbox, an algorithm output textbox, two cost
    plots (mean cost and ground truth cost), a 3D trajectory visualizer with
    one plot per condition and, when config['image_on'] is set, an image
    visualizer.

    Keys read from `hyperparams`: 'log_filename', 'conditions', 'info',
    'experiment_name' and (optionally) 'target_filename'. Layout and display
    settings are read from the module-level `config` dictionary.
    """

    def __init__(self, hyperparams):
        """
        Build the figure, create every GUI component and start the daemon
        thread that animates the 'Calculating...' status text.
        """
        self._hyperparams = hyperparams
        self._log_filename = self._hyperparams['log_filename']
        if 'target_filename' in self._hyperparams:
            self._target_filename = self._hyperparams['target_filename']
        else:
            self._target_filename = None

        # GPS Training Status.
        self.mode = config['initial_mode']  # Modes: run, wait, end, request, process.
        self.request = None  # Requests: stop, reset, go, fail, None.
        self.err_msg = None
        self._colors = {
            'run': 'cyan',
            'wait': 'orange',
            'end': 'red',
            'stop': 'red',
            'reset': 'yellow',
            'go': 'green',
            'fail': 'magenta',
        }
        self._first_update = True

        # Actions.
        actions_arr = [
            Action('stop', 'stop', self.request_stop, axis_pos=0),
            Action('reset', 'reset', self.request_reset, axis_pos=1),
            Action('go', 'go', self.request_go, axis_pos=2),
            Action('fail', 'fail', self.request_fail, axis_pos=3),
        ]

        # Setup figure. The toolbar is hidden and every matplotlib 'keymap.*'
        # shortcut is cleared.
        plt.ion()
        plt.rcParams['toolbar'] = 'None'
        for key in plt.rcParams:
            if key.startswith('keymap.'):
                plt.rcParams[key] = ''
        self._fig = plt.figure(figsize=config['figsize'])
        self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99,
                                  top=0.99, wspace=0, hspace=0)

        # Assign GUI component locations.
        self._gs = gridspec.GridSpec(16, 8)
        self._gs_action_panel = self._gs[0:1, 0:8]
        self._gs_action_output = self._gs[1:2, 0:4]
        self._gs_status_output = self._gs[2:3, 0:4]
        self._gs_cost_plotter = self._gs[1:3, 4:8]
        self._gs_gt_cost_plotter = self._gs[4:6, 4:8]
        self._gs_algthm_output = self._gs[3:9, 0:4]
        if config['image_on']:
            self._gs_traj_visualizer = self._gs[9:16, 0:4]
            self._gs_image_visualizer = self._gs[9:16, 4:8]
        else:
            self._gs_traj_visualizer = self._gs[9:16, 0:8]

        # Create GUI components.
        self._action_panel = ActionPanel(self._fig, self._gs_action_panel,
                                         1, 4, actions_arr)
        self._action_output = Textbox(self._fig, self._gs_action_output,
                                      border_on=True)
        self._status_output = Textbox(self._fig, self._gs_status_output,
                                      border_on=False)
        self._algthm_output = Textbox(
            self._fig, self._gs_algthm_output,
            max_display_size=config['algthm_output_max_display_size'],
            log_filename=self._log_filename,
            fontsize=config['algthm_output_fontsize'],
            font_family='monospace')
        self._cost_plotter = MeanPlotter(self._fig, self._gs_cost_plotter,
                                         color='blue', label='mean cost')
        self._gt_cost_plotter = MeanPlotter(self._fig,
                                            self._gs_gt_cost_plotter,
                                            color='red',
                                            label='ground truth cost')
        self._traj_visualizer = Plotter3D(
            self._fig, self._gs_traj_visualizer,
            num_plots=self._hyperparams['conditions'])
        if config['image_on']:
            self._image_visualizer = ImageVisualizer(
                self._fig, self._gs_image_visualizer,
                cropsize=config['image_size'],
                rostopic=config['image_topic'],
                show_overlay_buttons=True)

        # Setup GUI components.
        self._algthm_output.log_text('\n')
        self.set_output_text(self._hyperparams['info'])
        if config['initial_mode'] == 'run':
            self.run_mode()
        else:
            self.wait_mode()

        # Setup 3D Trajectory Visualizer plot titles and legends
        for m in range(self._hyperparams['conditions']):
            self._traj_visualizer.set_title(m, 'Condition %d' % (m))
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='green',
                                         label='Trajectory Samples')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='blue',
                                         label='Policy Samples')
        self._traj_visualizer.add_legend(linestyle='None', marker='x',
                                         color=(0.5, 0, 0),
                                         label='LG Controller Means')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='red',
                                         label='LG Controller Distributions')

        self._fig.canvas.draw()

        # Display calculating thread
        def display_calculating(delay, run_event):
            # Cycles the status text through one, two and three dots while
            # run_event is set, and blocks on the event while it is cleared.
            while True:
                if not run_event.is_set():
                    run_event.wait()
                if run_event.is_set():
                    self.set_status_text('Calculating.')
                    time.sleep(delay)
                if run_event.is_set():
                    self.set_status_text('Calculating..')
                    time.sleep(delay)
                if run_event.is_set():
                    self.set_status_text('Calculating...')
                    time.sleep(delay)

        self._calculating_run = threading.Event()
        self._calculating_thread = threading.Thread(
            target=display_calculating, args=(1, self._calculating_run))
        # Daemon thread: it never returns, so it must not block interpreter
        # exit.
        self._calculating_thread.daemon = True
        self._calculating_thread.start()

    # GPS Training functions
    def request_stop(self, event=None):
        """Action callback: record a 'stop' request (`event` is unused)."""
        self.request_mode('stop')

    def request_reset(self, event=None):
        """Action callback: record a 'reset' request (`event` is unused)."""
        self.request_mode('reset')

    def request_go(self, event=None):
        """Action callback: record a 'go' request (`event` is unused)."""
        self.request_mode('go')

    def request_fail(self, event=None):
        """Action callback: record a 'fail' request (`event` is unused)."""
        self.request_mode('fail')

    def request_mode(self, request):
        """
        Sets the request mode (stop, reset, go, fail). The request is read by
        gps_main before sampling, and the appropriate action is taken.
        """
        self.mode = 'request'
        self.request = request
        self.set_action_text(self.request + ' requested')
        self.set_action_bgcolor(self._colors[self.request], alpha=0.2)

    def process_mode(self):
        """
        Completes the current request, after it is first read by gps_main.
        Displays visual confirmation that the request was processed, displays
        any error messages, and then switches into mode 'run' or 'wait'.
        """
        self.mode = 'process'
        self.set_action_text(self.request + ' processed')
        self.set_action_bgcolor(self._colors[self.request], alpha=1.0)
        if self.err_msg:
            self.set_action_text(self.request + ' processed' + '\nERROR: ' +
                                 self.err_msg)
            self.err_msg = None
            # Error messages stay on screen twice as long as a plain
            # confirmation.
            time.sleep(1.0)
        else:
            time.sleep(0.5)
        if self.request in ('stop', 'reset', 'fail'):
            self.wait_mode()
        elif self.request == 'go':
            self.run_mode()
        self.request = None

    def wait_mode(self):
        """Switch to mode 'wait' and show it in the action textbox."""
        self.mode = 'wait'
        self.set_action_text('waiting')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def run_mode(self):
        """Switch to mode 'run' and show it in the action textbox."""
        self.mode = 'run'
        self.set_action_text('running')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def end_mode(self):
        """Switch to mode 'end' and show it in the action textbox."""
        self.mode = 'end'
        self.set_action_text('ended')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def estop(self, event=None):
        """Placeholder: only reports that estop is not implemented."""
        self.set_action_text('estop: NOT IMPLEMENTED')

    # GUI functions
    def _redraw_cost_ticklabels(self):
        """Redraw the overflow tick labels of both cost plots."""
        self._cost_plotter.draw_ticklabels()
        self._gt_cost_plotter.draw_ticklabels()

    def set_action_text(self, text):
        """Replace the text of the action textbox."""
        self._action_output.set_text(text)
        self._redraw_cost_ticklabels()

    def set_action_bgcolor(self, color, alpha=1.0):
        """Set the background color and alpha of the action textbox."""
        self._action_output.set_bgcolor(color, alpha)
        self._redraw_cost_ticklabels()

    def set_status_text(self, text):
        """Replace the text of the status textbox."""
        self._status_output.set_text(text)
        self._redraw_cost_ticklabels()

    def set_output_text(self, text):
        """Replace the text of the algorithm output textbox."""
        self._algthm_output.set_text(text)
        self._redraw_cost_ticklabels()

    def append_output_text(self, text):
        """Append text to the algorithm output textbox."""
        self._algthm_output.append_text(text)
        self._redraw_cost_ticklabels()

    def start_display_calculating(self):
        """Let the background thread animate the 'Calculating' status."""
        self._calculating_run.set()

    def stop_display_calculating(self):
        """Pause the 'Calculating' status animation."""
        self._calculating_run.clear()

    def set_image_overlays(self, condition):
        """
        Sets up the image visualizer with what images to overlay if
        "overlay_initial_image" or "overlay_target_image" is pressed.

        Does nothing when images are disabled or no target file is known.
        """
        if not config['image_on'] or not self._target_filename:
            return
        initial_image = load_data_from_npz(
            self._target_filename, config['image_overlay_actuator'],
            str(condition), 'initial', 'image', default=None)
        target_image = load_data_from_npz(
            self._target_filename, config['image_overlay_actuator'],
            str(condition), 'target', 'image', default=None)
        self._image_visualizer.set_initial_image(
            initial_image, alpha=config['image_overlay_alpha'])
        self._image_visualizer.set_target_image(
            target_image, alpha=config['image_overlay_alpha'])

    # Iteration update functions
    def update(self, itr, algorithm, agent, traj_sample_lists,
               pol_sample_lists):
        """
        After each iteration, update the iteration data output, the cost plot,
        and the 3D trajectory visualizations (if end effector points exist).

        When algorithm._hyperparams['ioc'] is set, the ground truth costs are
        reported in the iteration data and also drawn on the ground truth
        cost plot.
        """
        if self._first_update:
            policy_titles = pol_sample_lists is not None
            self._output_column_titles(algorithm, policy_titles)
            self._first_update = False

        costs = [np.mean(np.sum(algorithm.prev[m].cs, axis=1))
                 for m in range(algorithm.M)]
        if algorithm._hyperparams['ioc']:
            gt_costs = [np.mean(np.sum(algorithm.prev[m].cgt, axis=1))
                        for m in range(algorithm.M)]
            self._update_iteration_data(itr, algorithm, gt_costs,
                                        pol_sample_lists)
            self._gt_cost_plotter.update(gt_costs, t=itr)
        else:
            self._update_iteration_data(itr, algorithm, costs,
                                        pol_sample_lists)
        self._cost_plotter.update(costs, t=itr)
        if END_EFFECTOR_POINTS in agent.x_data_types:
            self._update_trajectory_visualizations(
                algorithm, agent, traj_sample_lists, pol_sample_lists)

        self._fig.canvas.draw()
        self._fig.canvas.flush_events()  # Fixes bug in Qt4Agg backend

    def _output_column_titles(self, algorithm, policy_titles=False):
        """
        Setup iteration data column titles: iteration, average cost, and for
        each condition the mean cost over samples, step size, linear Gaussian
        controller entropies, and initial/final KL divergences for BADMM.
        """
        self.set_output_text(self._hyperparams['experiment_name'])
        if policy_titles:
            condition_titles = '%3s | %8s %12s' % ('', '', '')
            itr_data_fields = '%3s | %8s %12s' % ('itr', 'avg_cost',
                                                  'avg_pol_cost')
        else:
            condition_titles = '%3s | %8s' % ('', '')
            itr_data_fields = '%3s | %8s' % ('itr', 'avg_cost')
        for m in range(algorithm.M):
            condition_titles += ' | %8s %9s %-7d' % ('', 'condition', m)
            itr_data_fields += ' | %8s %8s %8s' % (' cost ', ' step ',
                                                   'entropy ')
            if algorithm.prev[0].pol_info is not None:
                condition_titles += ' %8s %8s' % ('', '')
                itr_data_fields += ' %8s %8s' % ('kl_div_i', 'kl_div_f')
            if (algorithm._hyperparams['ioc'] and
                    not algorithm._hyperparams['learning_from_prior']):
                condition_titles += ' %8s' % ('')
                itr_data_fields += ' %8s' % ('kl_div')
            if algorithm._hyperparams['learning_from_prior']:
                condition_titles += ' %8s' % ('')
                itr_data_fields += ' %8s' % ('mean_dist')
            if policy_titles:
                condition_titles += ' %8s %8s %8s' % ('', '', '')
                itr_data_fields += ' %8s %8s %8s' % ('pol_cost', 'kl_div_i',
                                                     'kl_div_f')
        self.append_output_text(condition_titles)
        self.append_output_text(itr_data_fields)

    def _update_iteration_data(self, itr, algorithm, costs,
                               pol_sample_lists):
        """
        Update iteration data information: iteration, average cost, and for
        each condition the mean cost over samples, step size, linear Gaussian
        controller entropies, and initial/final KL divergences for BADMM.
        """
        avg_cost = np.mean(costs)
        if pol_sample_lists is not None:
            pol_costs = [
                np.mean([np.sum(algorithm.cost[m].eval(s)[0])
                         for s in pol_sample_lists[m]])
                for m in range(algorithm.M)
            ]
            itr_data = '%3d | %8.2f %12.2f' % (itr, avg_cost,
                                               np.mean(pol_costs))
        else:
            itr_data = '%3d | %8.2f' % (itr, avg_cost)
        for m in range(algorithm.M):
            cost = costs[m]
            step = algorithm.prev[m].step_mult * algorithm.base_kl_step
            entropy = 2 * np.sum(np.log(np.diagonal(
                algorithm.prev[m].traj_distr.chol_pol_covar,
                axis1=1, axis2=2)))
            itr_data += ' | %8.2f %8.2f %8.2f' % (cost, step, entropy)
            if algorithm.prev[0].pol_info is not None:
                kl_div_i = algorithm.prev[m].pol_info.prev_kl[0]
                kl_div_f = algorithm.prev[m].pol_info.prev_kl[-1]
                itr_data += ' %8.2f %8.2f' % (kl_div_i, kl_div_f)
            if (algorithm._hyperparams['ioc'] and
                    not algorithm._hyperparams['learning_from_prior']):
                itr_data += ' %8.2f' % (algorithm.kl_div[itr][m])
            if algorithm._hyperparams['learning_from_prior']:
                itr_data += ' %8.2f' % (algorithm.dists_to_target[itr][m])
            if pol_sample_lists is not None:
                kl_div_i = algorithm.cur[m].pol_info.init_kl.mean()
                kl_div_f = algorithm.cur[m].pol_info.prev_kl.mean()
                itr_data += ' %8.2f %8.2f %8.2f' % (pol_costs[m], kl_div_i,
                                                    kl_div_f)
        self.append_output_text(itr_data)

    def _update_trajectory_visualizations(self, algorithm, agent,
                                          traj_sample_lists,
                                          pol_sample_lists):
        """
        Update 3D trajectory visualizations information: the trajectory
        samples, policy samples, and linear Gaussian controller means and
        covariances.
        """
        xlim, ylim, zlim = self._calculate_3d_axis_limits(
            traj_sample_lists, pol_sample_lists)
        for m in range(algorithm.M):
            self._traj_visualizer.clear(m)
            self._traj_visualizer.set_lim(i=m, xlim=xlim, ylim=ylim,
                                          zlim=zlim)
            self._update_samples_plots(traj_sample_lists, m, 'green',
                                       'Trajectory Samples')
            self._update_linear_gaussian_controller_plots(algorithm, agent, m)
            if pol_sample_lists:
                self._update_samples_plots(pol_sample_lists, m, 'blue',
                                           'Policy Samples')
        self._traj_visualizer.draw()  # this must be called explicitly

    @staticmethod
    def _combined_sample_lists(traj_sample_lists, pol_sample_lists):
        """
        Return a new list holding the trajectory sample lists followed by the
        policy sample lists (if any). Neither argument is modified.
        """
        combined = list(traj_sample_lists)
        if pol_sample_lists:
            combined.extend(pol_sample_lists)
        return combined

    @staticmethod
    def _split_xyz(points):
        """
        Split a 2-D array whose columns are consecutive xyz triples into a
        list of three-column slices, one per point. Trailing columns that do
        not fill a whole triple are ignored.
        """
        # Floor division keeps the count an integer under true division.
        num_points = points.shape[1] // 3
        return [points[:, 3 * i:3 * i + 3] for i in range(num_points)]

    def _calculate_3d_axis_limits(self, traj_sample_lists, pol_sample_lists):
        """
        Calculate the 3D axis limits shared between trajectory plots,
        based on the minimum and maximum xyz values across all samples
        (trajectory samples and, when given, policy samples).

        Returns the tuple (xlim, ylim, zlim).
        """
        # The leading empty (0, 3) array keeps the result a float array with
        # three columns even before any sample is added.
        chunks = [np.empty((0, 3))]
        sample_lists = self._combined_sample_lists(traj_sample_lists,
                                                   pol_sample_lists)
        for sample_list in sample_lists:
            for sample in sample_list.get_samples():
                ee_pt = sample.get(END_EFFECTOR_POINTS)
                chunks.extend(self._split_xyz(ee_pt))
        all_eept = np.concatenate(chunks, axis=0)
        min_xyz = np.amin(all_eept, axis=0)
        max_xyz = np.amax(all_eept, axis=0)
        xlim = buffered_axis_limits(min_xyz[0], max_xyz[0],
                                    buffer_factor=1.25)
        ylim = buffered_axis_limits(min_xyz[1], max_xyz[1],
                                    buffer_factor=1.25)
        zlim = buffered_axis_limits(min_xyz[2], max_xyz[2],
                                    buffer_factor=1.25)
        return xlim, ylim, zlim

    def _update_linear_gaussian_controller_plots(self, algorithm, agent, m):
        """
        Update the linear Gaussian controller plots with iteration data,
        for the mean and covariances of the end effector points.
        """
        # Calculate mean and covariance for end effector points
        eept_idx = agent.get_idx_x(END_EFFECTOR_POINTS)
        start, end = eept_idx[0], eept_idx[-1]
        mu, sigma = algorithm.traj_opt.forward(algorithm.prev[m].traj_distr,
                                               algorithm.prev[m].traj_info)
        mu_eept = mu[:, start:end + 1]
        sigma_eept = sigma[:, start:end + 1, start:end + 1]
        num_points = mu_eept.shape[1] // 3

        # Linear Gaussian Controller Distributions (Red)
        for i in range(num_points):
            lo, hi = 3 * i, 3 * i + 3
            self._traj_visualizer.plot_3d_gaussian(
                i=m, mu=mu_eept[:, lo:hi], sigma=sigma_eept[:, lo:hi, lo:hi],
                edges=100, linestyle='-', linewidth=1.0, color='red',
                alpha=0.15, label='LG Controller Distributions')

        # Linear Gaussian Controller Means (Dark Red)
        for mu_i in self._split_xyz(mu_eept):
            self._traj_visualizer.plot_3d_points(
                i=m, points=mu_i, linestyle='None', marker='x',
                markersize=5.0, markeredgewidth=1.0, color=(0.5, 0, 0),
                alpha=1.0, label='LG Controller Means')

    def _update_samples_plots(self, sample_lists, m, color, label):
        """
        Update the samples plots with iteration data, for the trajectory
        samples and the policy samples.
        """
        for sample in sample_lists[m].get_samples():
            ee_pt = sample.get(END_EFFECTOR_POINTS)
            for ee_pt_i in self._split_xyz(ee_pt):
                self._traj_visualizer.plot_3d_points(m, ee_pt_i, color=color,
                                                     label=label)

    def save_figure(self, filename):
        """Save the whole GUI figure to `filename`."""
        self._fig.savefig(filename)
def __init__(self, hyperparams):
    """
    Initialize the GPS training GUI state on `self`.

    Stores the hyperparameters, creates the figure and all GUI components
    (action panel, textboxes, cost plots, 3D trajectory visualizer and the
    optional image visualizer), enters the initial mode taken from
    config['initial_mode'], and starts the daemon thread that animates the
    'Calculating...' status text.
    """
    self._hyperparams = hyperparams
    self._log_filename = hyperparams['log_filename']
    self._target_filename = (hyperparams['target_filename']
                             if 'target_filename' in hyperparams else None)

    # Training status. Modes: run, wait, end, request, process.
    # Requests: stop, reset, go, fail, None.
    self.mode = config['initial_mode']
    self.request = None
    self.err_msg = None
    self._colors = dict(
        run='cyan', wait='orange', end='red', stop='red',
        reset='yellow', go='green', fail='magenta')
    self._first_update = True

    # One button per request; the list position is the button's axis slot.
    handlers = [
        ('stop', self.request_stop),
        ('reset', self.request_reset),
        ('go', self.request_go),
        ('fail', self.request_fail),
    ]
    actions_arr = [Action(name, name, callback, axis_pos=slot)
                   for slot, (name, callback) in enumerate(handlers)]

    # Figure: no toolbar, every matplotlib 'keymap.*' shortcut cleared.
    plt.ion()
    plt.rcParams['toolbar'] = 'None'
    keymap_keys = [key for key in plt.rcParams if key.startswith('keymap.')]
    for key in keymap_keys:
        plt.rcParams[key] = ''
    fig = plt.figure(figsize=config['figsize'])
    self._fig = fig
    fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                        wspace=0, hspace=0)

    # Grid cells for each component (16 rows by 8 columns).
    grid = gridspec.GridSpec(16, 8)
    self._gs = grid
    self._gs_action_panel = grid[0:1, 0:8]
    self._gs_action_output = grid[1:2, 0:4]
    self._gs_status_output = grid[2:3, 0:4]
    self._gs_cost_plotter = grid[1:3, 4:8]
    self._gs_gt_cost_plotter = grid[4:6, 4:8]
    self._gs_algthm_output = grid[3:9, 0:4]
    images_enabled = config['image_on']
    if not images_enabled:
        # Without images the trajectory plots take the full width.
        self._gs_traj_visualizer = grid[9:16, 0:8]
    else:
        self._gs_traj_visualizer = grid[9:16, 0:4]
        self._gs_image_visualizer = grid[9:16, 4:8]

    # Components, created in the same order as the grid cells above.
    self._action_panel = ActionPanel(fig, self._gs_action_panel, 1, 4,
                                     actions_arr)
    self._action_output = Textbox(fig, self._gs_action_output,
                                  border_on=True)
    self._status_output = Textbox(fig, self._gs_status_output,
                                  border_on=False)
    self._algthm_output = Textbox(
        fig, self._gs_algthm_output,
        max_display_size=config['algthm_output_max_display_size'],
        log_filename=self._log_filename,
        fontsize=config['algthm_output_fontsize'],
        font_family='monospace')
    self._cost_plotter = MeanPlotter(fig, self._gs_cost_plotter,
                                     color='blue', label='mean cost')
    self._gt_cost_plotter = MeanPlotter(fig, self._gs_gt_cost_plotter,
                                        color='red',
                                        label='ground truth cost')
    num_conditions = hyperparams['conditions']
    self._traj_visualizer = Plotter3D(fig, self._gs_traj_visualizer,
                                      num_plots=num_conditions)
    if config['image_on']:
        self._image_visualizer = ImageVisualizer(
            fig, self._gs_image_visualizer,
            cropsize=config['image_size'],
            rostopic=config['image_topic'],
            show_overlay_buttons=True)

    # Initial contents and mode.
    self._algthm_output.log_text('\n')
    self.set_output_text(hyperparams['info'])
    if config['initial_mode'] == 'run':
        self.run_mode()
    else:
        self.wait_mode()

    # Titles and legend entries of the 3D trajectory visualizer.
    for m in range(hyperparams['conditions']):
        self._traj_visualizer.set_title(m, 'Condition %d' % (m))
    legend_entries = (
        ('-', 'None', 'green', 'Trajectory Samples'),
        ('-', 'None', 'blue', 'Policy Samples'),
        ('None', 'x', (0.5, 0, 0), 'LG Controller Means'),
        ('-', 'None', 'red', 'LG Controller Distributions'),
    )
    for linestyle, marker, color, label in legend_entries:
        self._traj_visualizer.add_legend(linestyle=linestyle, marker=marker,
                                         color=color, label=label)

    fig.canvas.draw()

    def display_calculating(delay, run_event):
        # Blocks while the event is cleared; while it is set, cycles the
        # status text through one, two and three dots.
        frames = ('Calculating.', 'Calculating..', 'Calculating...')
        while True:
            if not run_event.is_set():
                run_event.wait()
            for frame in frames:
                if run_event.is_set():
                    self.set_status_text(frame)
                    time.sleep(delay)

    self._calculating_run = threading.Event()
    worker = threading.Thread(target=display_calculating,
                              args=(1, self._calculating_run))
    worker.daemon = True
    self._calculating_thread = worker
    worker.start()
class TargetSetupGUI(object):
    """
    Matplotlib window for recording initial/target poses and images.

    For each (target number, actuator) pair the GUI can store the current
    pose or camera image as the initial or target value in the npz file named
    by hyperparams['target_filename'], move the actuator to a stored pose,
    relax the controller and toggle PR2 mannequin mode.

    Keys read from `hyperparams`: 'log_filename' and 'target_filename'.
    Targets, actuators and display settings are read from the module-level
    `config` dictionary.
    """

    def __init__(self, hyperparams, agent):
        """
        Build the figure and GUI components, then load and display the stored
        data for target 0 of the first actuator.
        """
        self._hyperparams = hyperparams
        self._agent = agent

        self._log_filename = self._hyperparams['log_filename']
        self._target_filename = self._hyperparams['target_filename']

        self._num_targets = config['num_targets']
        self._actuator_types = config['actuator_types']
        self._actuator_names = config['actuator_names']
        self._num_actuators = len(self._actuator_types)

        # Target Setup Status.
        self._target_number = 0
        self._actuator_number = 0
        self._actuator_type = self._actuator_types[self._actuator_number]
        self._actuator_name = self._actuator_names[self._actuator_number]
        self._initial_position = ('unknown', 'unknown', 'unknown')
        self._target_position = ('unknown', 'unknown', 'unknown')
        self._initial_image = None
        self._target_image = None
        self._mannequin_mode = False
        self._mm_process = None

        # Actions.
        actions_arr = [
            Action('ptn', 'prev_target_number', self.prev_target_number, axis_pos=0),
            Action('ntn', 'next_target_number', self.next_target_number, axis_pos=1),
            Action('pat', 'prev_actuator_type', self.prev_actuator_type, axis_pos=2),
            Action('nat', 'next_actuator_type', self.next_actuator_type, axis_pos=3),

            Action('sip', 'set_initial_position', self.set_initial_position, axis_pos=4),
            Action('stp', 'set_target_position', self.set_target_position, axis_pos=5),
            Action('sii', 'set_initial_image', self.set_initial_image, axis_pos=6),
            Action('sti', 'set_target_image', self.set_target_image, axis_pos=7),

            Action('mti', 'move_to_initial', self.move_to_initial, axis_pos=8),
            Action('mtt', 'move_to_target', self.move_to_target, axis_pos=9),
            Action('rc', 'relax_controller', self.relax_controller, axis_pos=10),
            Action('mm', 'mannequin_mode', self.mannequin_mode, axis_pos=11),
        ]

        # Setup figure. The toolbar is hidden and every matplotlib 'keymap.*'
        # shortcut is cleared.
        plt.ion()
        plt.rcParams['toolbar'] = 'None'
        for key in plt.rcParams:
            if key.startswith('keymap.'):
                plt.rcParams[key] = ''

        self._fig = plt.figure(figsize=config['figsize'])
        self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99,
                                  top=0.99, wspace=0, hspace=0)

        # Assign GUI component locations.
        self._gs = gridspec.GridSpec(4, 4)
        self._gs_action_panel = self._gs[0:1, 0:4]
        if config['image_on']:
            self._gs_target_output = self._gs[1:3, 0:2]
            self._gs_initial_image_visualizer = self._gs[3:4, 0:1]
            self._gs_target_image_visualizer = self._gs[3:4, 1:2]
            self._gs_action_output = self._gs[1:2, 2:4]
            self._gs_image_visualizer = self._gs[2:4, 2:4]
        else:
            self._gs_target_output = self._gs[1:4, 0:2]
            self._gs_action_output = self._gs[1:4, 2:4]

        # Create GUI components.
        self._action_panel = ActionPanel(self._fig, self._gs_action_panel,
                                         3, 4, actions_arr)
        self._target_output = Textbox(
            self._fig, self._gs_target_output,
            log_filename=self._log_filename,
            fontsize=config['target_output_fontsize'])
        self._action_output = Textbox(self._fig, self._gs_action_output)
        if config['image_on']:
            self._initial_image_visualizer = ImageVisualizer(
                self._fig, self._gs_initial_image_visualizer)
            self._target_image_visualizer = ImageVisualizer(
                self._fig, self._gs_target_image_visualizer)
            self._image_visualizer = ImageVisualizer(
                self._fig, self._gs_image_visualizer,
                cropsize=config['image_size'],
                rostopic=config['image_topic'],
                show_overlay_buttons=True)

        # Setup GUI components.
        self.reload_positions()
        self.update_target_text()
        self.set_action_text('Press an action to begin.')
        self.set_action_bgcolor('white')

        self._fig.canvas.draw()

    # Target Setup Functions.
    def _step_target_number(self, action, delta):
        """Move the target number by `delta` (wrapping) and reload its data."""
        self.set_action_status_message(action, 'requested')
        self._target_number = ((self._target_number + delta) %
                               self._num_targets)
        self.reload_positions()
        self.update_target_text()
        self.set_action_text()
        self.set_action_status_message(
            action, 'completed',
            message='target number = %d' % self._target_number)

    def _step_actuator(self, action, delta):
        """Move the actuator index by `delta` (wrapping) and reload its data."""
        self.set_action_status_message(action, 'requested')
        self._actuator_number = ((self._actuator_number + delta) %
                                 self._num_actuators)
        self._actuator_type = self._actuator_types[self._actuator_number]
        self._actuator_name = self._actuator_names[self._actuator_number]
        self.reload_positions()
        self.update_target_text()
        self.set_action_text()
        self.set_action_status_message(
            action, 'completed',
            message='actuator name = %s' % self._actuator_name)

    def _record_position(self, action, kind):
        """
        Read the current pose from the agent and store it as the `kind`
        ('initial' or 'target') position, in memory and in the target file.
        """
        self.set_action_status_message(action, 'requested')
        try:
            sample = self._agent.get_data(arm=self._actuator_type)
        except TimeoutException:
            self.set_action_status_message(
                action, 'failed',
                message='TimeoutException while retrieving sample')
            return
        position = (sample.get(JOINT_ANGLES),
                    sample.get(END_EFFECTOR_POSITIONS),
                    sample.get(END_EFFECTOR_ROTATIONS))
        if kind == 'initial':
            self._initial_position = position
        else:
            self._target_position = position
        save_pose_to_npz(self._target_filename, self._actuator_name,
                         str(self._target_number), kind, position)
        self.update_target_text()
        self.set_action_status_message(
            action, 'completed',
            message='%s position =\n %s' % (kind,
                                            self.position_to_str(position)))

    def _record_image(self, action, kind):
        """
        Store the current camera image as the `kind` ('initial' or 'target')
        image, in memory and in the target file.

        Reports 'failed' when no image visualizer exists (it is only created
        when config['image_on'] is set) or when it has no current image. In
        that case the previously stored image is left untouched.
        """
        self.set_action_status_message(action, 'requested')
        visualizer = getattr(self, '_image_visualizer', None)
        image = None
        if visualizer is not None:
            image = visualizer.get_current_image()
        if image is None:
            self.set_action_status_message(action, 'failed',
                                           message='no image available')
            return
        if kind == 'initial':
            self._initial_image = image
        else:
            self._target_image = image
        save_data_to_npz(self._target_filename, self._actuator_name,
                         str(self._target_number), kind, 'image', image)
        self.update_target_text()
        self.set_action_status_message(
            action, 'completed',
            message='%s image =\n %s' % (kind, str(image)))

    def prev_target_number(self, event=None):
        """Action callback: select the previous target number."""
        self._step_target_number('prev_target_number', -1)

    def next_target_number(self, event=None):
        """Action callback: select the next target number."""
        self._step_target_number('next_target_number', 1)

    def prev_actuator_type(self, event=None):
        """Action callback: select the previous actuator."""
        self._step_actuator('prev_actuator_type', -1)

    def next_actuator_type(self, event=None):
        """Action callback: select the next actuator."""
        self._step_actuator('next_actuator_type', 1)

    def set_initial_position(self, event=None):
        """Action callback: save the current pose as the initial position."""
        self._record_position('set_initial_position', 'initial')

    def set_target_position(self, event=None):
        """Action callback: save the current pose as the target position."""
        self._record_position('set_target_position', 'target')

    def set_initial_image(self, event=None):
        """Action callback: save the current image as the initial image."""
        self._record_image('set_initial_image', 'initial')

    def set_target_image(self, event=None):
        """Action callback: save the current image as the target image."""
        self._record_image('set_target_image', 'target')

    def move_to_initial(self, event=None):
        """Action callback: reset the arm to the stored initial joint angles."""
        ja = self._initial_position[0]
        self.set_action_status_message('move_to_initial', 'requested')
        self._agent.reset_arm(self._actuator_type, JOINT_SPACE, ja.T)
        self.set_action_status_message(
            'move_to_initial', 'completed',
            message='initial position: %s' % str(ja))

    def move_to_target(self, event=None):
        """Action callback: reset the arm to the stored target joint angles."""
        ja = self._target_position[0]
        self.set_action_status_message('move_to_target', 'requested')
        self._agent.reset_arm(self._actuator_type, JOINT_SPACE, ja.T)
        self.set_action_status_message(
            'move_to_target', 'completed',
            message='target position: %s' % str(ja))

    def relax_controller(self, event=None):
        """Action callback: relax the arm of the selected actuator."""
        self.set_action_status_message('relax_controller', 'requested')
        self._agent.relax_arm(self._actuator_type)
        self.set_action_status_message(
            'relax_controller', 'completed',
            message='actuator name: %s' % self._actuator_name)

    def mannequin_mode(self, event=None):
        """
        Calls "roslaunch pr2_mannequin_mode pr2_mannequin_mode.launch"
        (only works for the PR2 robot).

        Toggles: the first call stops the GPSPR2Plugin controller and launches
        mannequin mode; the next call interrupts that launch and starts the
        controller again.
        """
        self.set_action_status_message('mannequin_mode', 'requested')
        if not self._mannequin_mode:
            subprocess.Popen([
                'rosrun', 'pr2_controller_manager',
                'pr2_controller_manager', 'stop', 'GPSPR2Plugin'
            ], stdout=DEVNULL)
            self._mm_process = subprocess.Popen([
                'roslaunch', 'pr2_mannequin_mode',
                'pr2_mannequin_mode.launch'
            ], stdout=DEVNULL)
            self._mannequin_mode = True
            self.set_action_status_message(
                'mannequin_mode', 'completed',
                message='mannequin mode toggled on')
        else:
            self._mm_process.send_signal(signal.SIGINT)
            subprocess.Popen([
                'rosrun', 'pr2_controller_manager',
                'pr2_controller_manager', 'start', 'GPSPR2Plugin'
            ], stdout=DEVNULL)
            self._mannequin_mode = False
            self.set_action_status_message(
                'mannequin_mode', 'completed',
                message='mannequin mode toggled off')

    # GUI functions.
    def update_target_text(self):
        """
        Refresh the target textbox and, when images are enabled, the image
        visualizers with the currently loaded positions and images.
        """
        np.set_printoptions(precision=3, suppress=True)
        text = (
            'target number = %s\n' % str(self._target_number) +
            'actuator name = %s\n' % str(self._actuator_name) +
            '\ninitial position\n%s' % self.position_to_str(self._initial_position) +
            '\ntarget position\n%s' % self.position_to_str(self._target_position) +
            '\ninitial image (left) =\n%s\n' % str(self._initial_image) +
            '\ntarget image (right) =\n%s\n' % str(self._target_image)
        )
        self._target_output.set_text(text)

        if config['image_on']:
            self._initial_image_visualizer.update(self._initial_image)
            self._target_image_visualizer.update(self._target_image)
            self._image_visualizer.set_initial_image(self._initial_image,
                                                     alpha=0.3)
            self._image_visualizer.set_target_image(self._target_image,
                                                    alpha=0.3)

    def position_to_str(self, position):
        """
        Format a (joint angles, end effector positions, end effector
        rotations) tuple as labelled text.
        """
        np.set_printoptions(precision=3, suppress=True)
        ja, ee_pos, ee_rot = position
        return ('joint angles =\n%s\n' % ja +
                'end effector positions =\n%s\n' % ee_pos +
                'end effector rotations =\n%s\n' % ee_rot)

    def set_action_status_message(self, action, status, message=None):
        """
        Show '<action>: <status>' (plus an optional message) in the action
        textbox, colored yellow / green / red for the statuses 'requested' /
        'completed' / 'failed'. Other statuses leave the color unchanged.
        """
        text = action + ': ' + status
        if message:
            text += '\n\n' + message
        self.set_action_text(text)
        if status == 'requested':
            self.set_action_bgcolor('yellow')
        elif status == 'completed':
            self.set_action_bgcolor('green')
        elif status == 'failed':
            self.set_action_bgcolor('red')

    def set_action_text(self, text=''):
        """Replace the text of the action textbox (empty by default)."""
        self._action_output.set_text(text)

    def set_action_bgcolor(self, color, alpha=1.0):
        """Set the background color and alpha of the action textbox."""
        self._action_output.set_bgcolor(color, alpha)

    def reload_positions(self):
        """
        Reloads the initial/target positions and images.
        This is called after the target number or actuator type have changed.
        """
        self._initial_position = load_pose_from_npz(
            self._target_filename, self._actuator_name,
            str(self._target_number), 'initial')
        self._target_position = load_pose_from_npz(
            self._target_filename, self._actuator_name,
            str(self._target_number), 'target')
        self._initial_image = load_data_from_npz(
            self._target_filename, self._actuator_name,
            str(self._target_number), 'initial', 'image', default=None)
        self._target_image = load_data_from_npz(
            self._target_filename, self._actuator_name,
            str(self._target_number), 'target', 'image', default=None)
def __init__(self, hyperparams, agent):
    """
    Build the target setup GUI.

    hyperparams: dict providing 'log_filename' and 'target_filename'.
    agent: stored on the instance as self._agent.

    Creates the figure, the action panel, the target/action text boxes and
    (when config['image_on'] is set) the image visualizers, then loads the
    stored positions and draws the canvas.
    """
    self._hyperparams = hyperparams
    self._agent = agent

    self._log_filename = hyperparams['log_filename']
    self._target_filename = hyperparams['target_filename']

    self._num_targets = config['num_targets']
    self._actuator_types = config['actuator_types']
    self._actuator_names = config['actuator_names']
    self._num_actuators = len(self._actuator_types)

    # Target Setup Status.
    self._target_number = 0
    self._actuator_number = 0
    self._actuator_type = self._actuator_types[self._actuator_number]
    self._actuator_name = self._actuator_names[self._actuator_number]
    self._initial_position = ('unknown', 'unknown', 'unknown')
    self._target_position = ('unknown', 'unknown', 'unknown')
    self._initial_image = None
    self._target_image = None
    self._mannequin_mode = False
    self._mm_process = None

    # Actions: (key, name, callback); the panel position is the table index.
    action_table = (
        ('ptn', 'prev_target_number', self.prev_target_number),
        ('ntn', 'next_target_number', self.next_target_number),
        ('pat', 'prev_actuator_type', self.prev_actuator_type),
        ('nat', 'next_actuator_type', self.next_actuator_type),
        ('sip', 'set_initial_position', self.set_initial_position),
        ('stp', 'set_target_position', self.set_target_position),
        ('sii', 'set_initial_image', self.set_initial_image),
        ('sti', 'set_target_image', self.set_target_image),
        ('mti', 'move_to_initial', self.move_to_initial),
        ('mtt', 'move_to_target', self.move_to_target),
        ('rc', 'relax_controller', self.relax_controller),
        ('mm', 'mannequin_mode', self.mannequin_mode),
    )
    actions_arr = [
        Action(key, name, callback, axis_pos=position)
        for position, (key, name, callback) in enumerate(action_table)
    ]

    # Setup figure: interactive mode, no toolbar, and every matplotlib
    # keymap shortcut cleared.
    plt.ion()
    plt.rcParams['toolbar'] = 'None'
    keymap_keys = [name for name in plt.rcParams if name.startswith('keymap.')]
    for name in keymap_keys:
        plt.rcParams[name] = ''

    self._fig = plt.figure(figsize=config['figsize'])
    self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                              wspace=0, hspace=0)

    # Assign GUI component locations on a 4x4 grid.
    grid = gridspec.GridSpec(4, 4)
    self._gs = grid
    self._gs_action_panel = grid[0:1, 0:4]
    if not config['image_on']:
        self._gs_target_output = grid[1:4, 0:2]
        self._gs_action_output = grid[1:4, 2:4]
    else:
        self._gs_target_output = grid[1:3, 0:2]
        self._gs_initial_image_visualizer = grid[3:4, 0:1]
        self._gs_target_image_visualizer = grid[3:4, 1:2]
        self._gs_action_output = grid[1:2, 2:4]
        self._gs_image_visualizer = grid[2:4, 2:4]

    # Create GUI components.
    fig = self._fig
    self._action_panel = ActionPanel(fig, self._gs_action_panel, 3, 4,
                                     actions_arr)
    self._target_output = Textbox(
        fig, self._gs_target_output,
        log_filename=self._log_filename,
        fontsize=config['target_output_fontsize'])
    self._action_output = Textbox(fig, self._gs_action_output)
    if config['image_on']:
        self._initial_image_visualizer = ImageVisualizer(
            fig, self._gs_initial_image_visualizer)
        self._target_image_visualizer = ImageVisualizer(
            fig, self._gs_target_image_visualizer)
        self._image_visualizer = ImageVisualizer(
            fig, self._gs_image_visualizer,
            cropsize=config['image_size'],
            rostopic=config['image_topic'],
            show_overlay_buttons=True)

    # Setup GUI components.
    self.reload_positions()
    self.update_target_text()
    self.set_action_text('Press an action to begin.')
    self.set_action_bgcolor('white')

    fig.canvas.draw()
class GPSTrainingGUI(object):
    """
    GPS Training GUI class.

    A single matplotlib figure holding an action panel (stop / reset / go /
    fail), action and status text boxes, a mean cost plot, the algorithm's
    iteration output, 3D trajectory plots and an image visualizer.
    """

    def __init__(self, hyperparams):
        """
        Build the figure and all GUI components, then enter 'run' mode.

        hyperparams: dict of overrides applied on top of deep copies of
            common_config and gps_training_config.
        """
        self._hyperparams = copy.deepcopy(common_config)
        self._hyperparams.update(copy.deepcopy(gps_training_config))
        self._hyperparams.update(hyperparams)

        self._log_filename = self._hyperparams['log_filename']
        if 'target_filename' in self._hyperparams:
            self._target_filename = self._hyperparams['target_filename']
        else:
            self._target_filename = ''

        # GPS Training Status.
        self.mode = 'run'       # Modes: run, wait, end, request, process.
        self.request = None     # Requests: stop, reset, go, fail, None.
        self.err_msg = None
        self._colors = {
            'run': 'cyan',
            'wait': 'orange',
            'end': 'red',

            'stop': 'red',
            'reset': 'yellow',
            'go': 'green',
            'fail': 'magenta',
        }
        self._first_update = True

        # Actions.
        actions_arr = [
            Action('stop', 'stop', self.request_stop, axis_pos=0),
            Action('reset', 'reset', self.request_reset, axis_pos=1),
            Action('go', 'go', self.request_go, axis_pos=2),
            Action('fail', 'fail', self.request_fail, axis_pos=3),
        ]
        self._actions = {action._key: action for action in actions_arr}
        # items() instead of the Python-2-only iteritems().
        for key, action in self._actions.items():
            if key in self._hyperparams['keyboard_bindings']:
                action._kb = self._hyperparams['keyboard_bindings'][key]
            if key in self._hyperparams['ps3_bindings']:
                action._pb = self._hyperparams['ps3_bindings'][key]

        # GUI Components.
        plt.ion()
        plt.rcParams['toolbar'] = 'None'
        # Remove 's' keyboard shortcut for saving.
        plt.rcParams['keymap.save'] = ''

        self._fig = plt.figure(figsize=(12, 12))
        self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                                  wspace=0, hspace=0)

        # Assign GUI component locations.
        self._gs = gridspec.GridSpec(16, 8)
        self._gs_action_axis = self._gs[0:2, 0:8]
        self._gs_action_output = self._gs[2:3, 0:4]
        self._gs_status_output = self._gs[3:4, 0:4]
        self._gs_cost_plotter = self._gs[2:4, 4:8]
        self._gs_algthm_output = self._gs[4:8, 0:8]
        self._gs_traj_visualizer = self._gs[8:16, 0:4]
        self._gs_image_visualizer = self._gs[8:16, 4:8]

        # Create GUI components.
        self._action_axis = ActionAxis(
            self._fig, self._gs_action_axis, 1, 4, self._actions,
            ps3_process_rate=self._hyperparams['ps3_process_rate'],
            ps3_topic=self._hyperparams['ps3_topic'],
            ps3_button=self._hyperparams['ps3_button'],
            inverted_ps3_button=self._hyperparams['inverted_ps3_button'])
        self._action_output = OutputAxis(self._fig, self._gs_action_output,
                                         border_on=True)
        self._status_output = OutputAxis(self._fig, self._gs_status_output,
                                         border_on=False)
        self._algthm_output = OutputAxis(
            self._fig, self._gs_algthm_output, max_display_size=15,
            log_filename=self._log_filename, fontsize=10,
            font_family='monospace')
        self._cost_plotter = MeanPlotter(self._fig, self._gs_cost_plotter,
                                         color='blue', label='mean cost')
        self._traj_visualizer = Plotter3D(
            self._fig, self._gs_traj_visualizer,
            num_plots=self._hyperparams['conditions'])
        self._image_visualizer = ImageVisualizer(
            self._fig, self._gs_image_visualizer, cropsize=(240, 240),
            rostopic=self._hyperparams['image_topic'],
            show_overlay_buttons=True)

        # Setup GUI components.
        self._algthm_output.log_text('\n')
        self.set_output_text(self._hyperparams['info'])
        self.run_mode()

        # WARNING: Make sure the legend values in UPDATE match the below
        # linestyles/markers and colors.
        for m in range(self._hyperparams['conditions']):
            self._traj_visualizer.set_title(m, 'Condition %d' % (m))
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                color='green', label='Trajectory Samples')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                color='blue', label='Policy Samples')
        self._traj_visualizer.add_legend(linestyle='None', marker='x',
                color=(0.5, 0, 0), label='LG Controller Means')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                color='red', label='LG Controller Distributions')

        self._fig.canvas.draw()

    # GPS Training Functions.
    def request_stop(self, event=None):
        """ Action callback: request 'stop'. The event argument is ignored. """
        self.request_mode('stop')

    def request_reset(self, event=None):
        """ Action callback: request 'reset'. The event argument is ignored. """
        self.request_mode('reset')

    def request_go(self, event=None):
        """ Action callback: request 'go'. The event argument is ignored. """
        self.request_mode('go')

    def request_fail(self, event=None):
        """ Action callback: request 'fail'. The event argument is ignored. """
        self.request_mode('fail')

    def request_mode(self, request):
        """
        Record a pending request (stop, reset, go, fail), switch to mode
        'request' and show it in the action output with a faint background.
        """
        self.mode = 'request'
        self.request = request
        self.set_action_text(self.request + ' requested')
        self.set_action_bgcolor(self._colors[self.request], alpha=0.2)

    def process_mode(self):
        """
        Complete the pending request: show '<request> processed' (with the
        error message, if self.err_msg is set), pause briefly so it can be
        seen, then switch to 'wait' (stop/reset/fail) or 'run' (go) and
        clear the request.
        """
        self.mode = 'process'
        self.set_action_text(self.request + ' processed')
        self.set_action_bgcolor(self._colors[self.request], alpha=1.0)
        if self.err_msg:
            self.set_action_text(self.request + ' processed' + '\nERROR: ' +
                                 self.err_msg)
            self.err_msg = None
            time.sleep(1.0)
        else:
            time.sleep(0.5)
        if self.request in ('stop', 'reset', 'fail'):
            self.wait_mode()
        elif self.request == 'go':
            self.run_mode()
        self.request = None

    def wait_mode(self):
        """ Switch to mode 'wait' and show it in the action output. """
        self.mode = 'wait'
        self.set_action_text('waiting')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def run_mode(self):
        """ Switch to mode 'run' and show it in the action output. """
        self.mode = 'run'
        self.set_action_text('running')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def end_mode(self):
        """ Switch to mode 'end' and show it in the action output. """
        self.mode = 'end'
        self.set_action_text('ended')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def estop(self, event=None):
        """ Placeholder: only reports that estop is not implemented. """
        self.set_action_text('estop: NOT IMPLEMENTED')

    # GUI functions.
    def set_action_text(self, text):
        """ Set the action output text. """
        self._action_output.set_text(text)
        self._cost_plotter.draw_ticklabels()    # redraw overflow ticklabels

    def set_action_bgcolor(self, color, alpha=1.0):
        """ Set the action output background colour and alpha. """
        self._action_output.set_bgcolor(color, alpha)
        self._cost_plotter.draw_ticklabels()    # redraw overflow ticklabels

    def set_status_text(self, text):
        """ Set the status output text. """
        self._status_output.set_text(text)
        self._cost_plotter.draw_ticklabels()    # redraw overflow ticklabels

    def set_output_text(self, text):
        """ Replace the algorithm output text. """
        self._algthm_output.set_text(text)
        self._cost_plotter.draw_ticklabels()    # redraw overflow ticklabels

    def append_output_text(self, text):
        """ Append text to the algorithm output. """
        self._algthm_output.append_text(text)
        self._cost_plotter.draw_ticklabels()    # redraw overflow ticklabels

    def set_image_overlays(self, condition):
        """
        Load the initial/target images stored for `condition` and register
        them as overlays on the image visualizer. Does nothing when no
        target file is configured (empty string or None).
        """
        if not self._target_filename:
            return
        initial_image = load_data_from_npz(self._target_filename,
                self._hyperparams['image_actuator'], str(condition),
                'initial', 'image', default=np.zeros((1, 1, 3)))
        target_image = load_data_from_npz(self._target_filename,
                self._hyperparams['image_actuator'], str(condition),
                'target', 'image', default=np.zeros((1, 1, 3)))
        self._image_visualizer.set_initial_image(initial_image, alpha=0.3)
        self._image_visualizer.set_target_image(target_image, alpha=0.3)

    @staticmethod
    def _point_blocks(ee_pt):
        """
        Yield each consecutive group of 3 columns of a 2D array (one group
        per end effector point). Trailing columns that do not fill a group
        of 3 are ignored.
        """
        # Floor division: plain '/' yields a float on Python 3 and range()
        # would raise TypeError.
        for i in range(ee_pt.shape[1] // 3):
            yield ee_pt[:, 3*i:3*i+3]

    def _draw(self):
        """ Redraw the trajectory visualizer and the figure canvas. """
        self._traj_visualizer.draw()    # this must be called explicitly
        self._fig.canvas.draw()
        self._fig.canvas.flush_events() # Fixes bug in Qt4Agg backend

    def update(self, itr, algorithm, agent, traj_sample_lists, pol_sample_lists):
        """
        After an iteration, update the cost plot, append one row of
        iteration data to the algorithm output, and (when the agent's state
        includes END_EFFECTOR_POINTS) redraw the 3D trajectory plots.

        pol_sample_lists may be None/empty, in which case no policy samples
        are plotted.
        """
        # Mean sample cost (summed over time) of every condition; this is
        # what the per-condition 'cost' column reports.
        condition_costs = [np.mean(np.sum(algorithm.prev[m].cs, axis=1))
                           for m in range(algorithm.M)]

        # Plot Costs
        if algorithm.M == 1:
            # Update plot with each sample's cost (summed over time).
            costs = np.sum(algorithm.prev[0].cs, axis=1)
        else:
            # Update plot with each condition's mean sample cost.
            costs = condition_costs
        self._cost_plotter.update(costs, t=itr)

        # Setup iteration data column titles.
        if self._first_update:
            self.set_output_text(self._hyperparams['experiment_name'])
            condition_titles = '%3s | %8s' % ('', '')
            itr_data_fields = '%3s | %8s' % ('itr', 'avg_cost')
            for m in range(algorithm.M):
                condition_titles += ' | %8s %9s %-7d' % ('', 'condition', m)
                itr_data_fields += ' | %8s %8s %8s' % ('  cost  ', '  step  ', 'entropy ')
                if algorithm.prev[0].pol_info is not None:
                    condition_titles += ' %8s %8s' % ('', '')
                    itr_data_fields += ' %8s %8s' % ('kl_div_i', 'kl_div_f')
            self.append_output_text(condition_titles)
            self.append_output_text(itr_data_fields)
            self._first_update = False

        # Print Iteration Data
        avg_cost = np.mean(costs)
        itr_data = '%3d | %8.2f' % (itr, avg_cost)
        for m in range(algorithm.M):
            # Use the condition mean; indexing `costs` would report the
            # first sample's cost when M == 1.
            cost = condition_costs[m]
            step = algorithm.prev[m].step_mult
            entropy = 2*np.sum(np.log(np.diagonal(
                algorithm.prev[m].traj_distr.chol_pol_covar, axis1=1, axis2=2)))
            itr_data += ' | %8.2f %8.2f %8.2f' % (cost, step, entropy)
            if algorithm.prev[0].pol_info is not None:
                kl_div_i = algorithm.prev[m].pol_info.prev_kl[0]
                kl_div_f = algorithm.prev[m].pol_info.prev_kl[-1]
                itr_data += ' %8.2f %8.2f' % (kl_div_i, kl_div_f)
        self.append_output_text(itr_data)

        if END_EFFECTOR_POINTS not in agent.x_data_types:
            # Skip plotting samples.
            # TODO(xinyutan) - the plotting below assumes END_EFFECTOR_POINTS
            # are in the sample, which is not true for box2d.
            self._draw()
            return

        # Calculate xlim, ylim, zlim for 3D visualizations from
        # traj_sample_lists and pol_sample_lists (this clips off LQG
        # means/distributions that are not in the area of interest).
        all_eept = np.empty((0, 3))
        if pol_sample_lists:
            sample_lists = traj_sample_lists + pol_sample_lists
        else:
            sample_lists = traj_sample_lists
        for sample_list in sample_lists:
            for sample in sample_list.get_samples():
                for ee_pt_i in self._point_blocks(sample.get(END_EFFECTOR_POINTS)):
                    all_eept = np.r_[all_eept, ee_pt_i]
        min_xyz = np.amin(all_eept, axis=0)
        max_xyz = np.amax(all_eept, axis=0)
        xlim = (min_xyz[0], max_xyz[0])
        ylim = (min_xyz[1], max_xyz[1])
        zlim = (min_xyz[2], max_xyz[2])

        # Plot 3D Visualizations
        for m in range(algorithm.M):
            # Clear previous plots
            self._traj_visualizer.clear(m)
            self._traj_visualizer.set_lim(i=m, xlim=xlim, ylim=ylim, zlim=zlim)

            # Linear Gaussian Controller Distributions (Red)
            mu, sigma = algorithm.traj_opt.forward(algorithm.prev[m].traj_distr,
                                                   algorithm.prev[m].traj_info)
            eept_idx = agent.get_idx_x(END_EFFECTOR_POINTS)
            start, end = eept_idx[0], eept_idx[-1]
            mu_eept = mu[:, start:end+1]
            sigma_eept = sigma[:, start:end+1, start:end+1]
            for i in range(mu_eept.shape[1] // 3):
                mu_i = mu_eept[:, 3*i:3*i+3]
                sigma_i = sigma_eept[:, 3*i:3*i+3, 3*i:3*i+3]
                self._traj_visualizer.plot_3d_gaussian(i=m, mu=mu_i,
                        sigma=sigma_i, edges=100, linestyle='-', linewidth=1.0,
                        color='red', alpha=0.15,
                        label='LG Controller Distributions')

            # Linear Gaussian Controller Means (Dark Red)
            for mu_i in self._point_blocks(mu_eept):
                self._traj_visualizer.plot_3d_points(i=m, points=mu_i,
                        linestyle='None', marker='x', markersize=5.0,
                        markeredgewidth=1.0, color=(0.5, 0, 0), alpha=1.0,
                        label='LG Controller Means')

            # Trajectory Samples (Green)
            for sample in traj_sample_lists[m].get_samples():
                for ee_pt_i in self._point_blocks(sample.get(END_EFFECTOR_POINTS)):
                    self._traj_visualizer.plot_3d_points(m, ee_pt_i,
                            color='green', label='Trajectory Samples')

            # Policy Samples (Blue)
            if pol_sample_lists is not None:
                for sample in pol_sample_lists[m].get_samples():
                    for ee_pt_i in self._point_blocks(sample.get(END_EFFECTOR_POINTS)):
                        self._traj_visualizer.plot_3d_points(m, ee_pt_i,
                                color='blue', label='Policy Samples')

        self._draw()

    def save_figure(self, filename):
        """ Save the whole figure to `filename`. """
        self._fig.savefig(filename)
def __init__(self, hyperparams):
    """
    Build the figure and all GUI components, then enter 'run' mode.

    hyperparams: dict of overrides applied on top of deep copies of
        common_config and gps_training_config.
    """
    self._hyperparams = copy.deepcopy(common_config)
    self._hyperparams.update(copy.deepcopy(gps_training_config))
    self._hyperparams.update(hyperparams)

    self._log_filename = self._hyperparams['log_filename']
    if 'target_filename' in self._hyperparams:
        self._target_filename = self._hyperparams['target_filename']
    else:
        self._target_filename = ''

    # GPS Training Status.
    self.mode = 'run'       # Modes: run, wait, end, request, process.
    self.request = None     # Requests: stop, reset, go, fail, None.
    self.err_msg = None
    self._colors = {
        'run': 'cyan',
        'wait': 'orange',
        'end': 'red',

        'stop': 'red',
        'reset': 'yellow',
        'go': 'green',
        'fail': 'magenta',
    }
    self._first_update = True

    # Actions.
    actions_arr = [
        Action('stop', 'stop', self.request_stop, axis_pos=0),
        Action('reset', 'reset', self.request_reset, axis_pos=1),
        Action('go', 'go', self.request_go, axis_pos=2),
        Action('fail', 'fail', self.request_fail, axis_pos=3),
    ]
    self._actions = {action._key: action for action in actions_arr}
    # items() instead of the Python-2-only iteritems(); it behaves the same
    # here and also runs on Python 3.
    for key, action in self._actions.items():
        if key in self._hyperparams['keyboard_bindings']:
            action._kb = self._hyperparams['keyboard_bindings'][key]
        if key in self._hyperparams['ps3_bindings']:
            action._pb = self._hyperparams['ps3_bindings'][key]

    # GUI Components.
    plt.ion()
    plt.rcParams['toolbar'] = 'None'
    # Remove 's' keyboard shortcut for saving.
    plt.rcParams['keymap.save'] = ''

    self._fig = plt.figure(figsize=(12, 12))
    self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                              wspace=0, hspace=0)

    # Assign GUI component locations.
    self._gs = gridspec.GridSpec(16, 8)
    self._gs_action_axis = self._gs[0:2, 0:8]
    self._gs_action_output = self._gs[2:3, 0:4]
    self._gs_status_output = self._gs[3:4, 0:4]
    self._gs_cost_plotter = self._gs[2:4, 4:8]
    self._gs_algthm_output = self._gs[4:8, 0:8]
    self._gs_traj_visualizer = self._gs[8:16, 0:4]
    self._gs_image_visualizer = self._gs[8:16, 4:8]

    # Create GUI components.
    self._action_axis = ActionAxis(
        self._fig, self._gs_action_axis, 1, 4, self._actions,
        ps3_process_rate=self._hyperparams['ps3_process_rate'],
        ps3_topic=self._hyperparams['ps3_topic'],
        ps3_button=self._hyperparams['ps3_button'],
        inverted_ps3_button=self._hyperparams['inverted_ps3_button'])
    self._action_output = OutputAxis(self._fig, self._gs_action_output,
                                     border_on=True)
    self._status_output = OutputAxis(self._fig, self._gs_status_output,
                                     border_on=False)
    self._algthm_output = OutputAxis(
        self._fig, self._gs_algthm_output, max_display_size=15,
        log_filename=self._log_filename, fontsize=10,
        font_family='monospace')
    self._cost_plotter = MeanPlotter(self._fig, self._gs_cost_plotter,
                                     color='blue', label='mean cost')
    self._traj_visualizer = Plotter3D(
        self._fig, self._gs_traj_visualizer,
        num_plots=self._hyperparams['conditions'])
    self._image_visualizer = ImageVisualizer(
        self._fig, self._gs_image_visualizer, cropsize=(240, 240),
        rostopic=self._hyperparams['image_topic'],
        show_overlay_buttons=True)

    # Setup GUI components.
    self._algthm_output.log_text('\n')
    self.set_output_text(self._hyperparams['info'])
    self.run_mode()

    # WARNING: Make sure the legend values in UPDATE match the below
    # linestyles/markers and colors.
    # A plain loop: the calls are made for their side effect only, so no
    # throwaway list is built.
    for m in range(self._hyperparams['conditions']):
        self._traj_visualizer.set_title(m, 'Condition %d' % (m))
    self._traj_visualizer.add_legend(linestyle='-', marker='None',
            color='green', label='Trajectory Samples')
    self._traj_visualizer.add_legend(linestyle='-', marker='None',
            color='blue', label='Policy Samples')
    self._traj_visualizer.add_legend(linestyle='None', marker='x',
            color=(0.5, 0, 0), label='LG Controller Means')
    self._traj_visualizer.add_legend(linestyle='-', marker='None',
            color='red', label='LG Controller Distributions')

    self._fig.canvas.draw()
def __init__(self, hyperparams):
    """
    Build the GPS training figure and its components, enter the configured
    initial mode, and start the daemon thread that animates the
    'Calculating...' status text.

    hyperparams: dict providing 'log_filename', 'conditions' and 'info',
        and optionally 'target_filename'.
    """
    self._hyperparams = hyperparams
    self._log_filename = hyperparams['log_filename']
    self._target_filename = (hyperparams['target_filename']
                             if 'target_filename' in hyperparams else None)

    # GPS Training Status.
    self.mode = config['initial_mode']  # Modes: run, wait, end, request, process.
    self.request = None                 # Requests: stop, reset, go, fail, None.
    self.err_msg = None
    self._colors = dict(
        run='cyan', wait='orange', end='red',
        stop='red', reset='yellow', go='green', fail='magenta',
    )
    self._first_update = True

    # Actions: one button per request; panel position is the table index.
    request_callbacks = (
        ('stop', self.request_stop),
        ('reset', self.request_reset),
        ('go', self.request_go),
        ('fail', self.request_fail),
    )
    actions_arr = [
        Action(name, name, callback, axis_pos=position)
        for position, (name, callback) in enumerate(request_callbacks)
    ]

    # Setup figure: interactive mode, no toolbar, every keymap shortcut cleared.
    plt.ion()
    plt.rcParams['toolbar'] = 'None'
    keymap_keys = [name for name in plt.rcParams if name.startswith('keymap.')]
    for name in keymap_keys:
        plt.rcParams[name] = ''

    self._fig = plt.figure(figsize=config['figsize'])
    self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                              wspace=0, hspace=0)

    # Assign GUI component locations on a 16x8 grid.
    grid = gridspec.GridSpec(16, 8)
    self._gs = grid
    self._gs_action_panel = grid[0:1, 0:8]
    self._gs_action_output = grid[1:2, 0:4]
    self._gs_status_output = grid[2:3, 0:4]
    self._gs_cost_plotter = grid[1:3, 4:8]
    self._gs_gt_cost_plotter = grid[4:6, 4:8]
    self._gs_algthm_output = grid[3:9, 0:4]
    if not config['image_on']:
        # Without images the trajectory plots take the full width.
        self._gs_traj_visualizer = grid[9:16, 0:8]
    else:
        self._gs_traj_visualizer = grid[9:16, 0:4]
        self._gs_image_visualizer = grid[9:16, 4:8]

    # Create GUI components.
    fig = self._fig
    self._action_panel = ActionPanel(fig, self._gs_action_panel, 1, 4,
                                     actions_arr)
    self._action_output = Textbox(fig, self._gs_action_output, border_on=True)
    self._status_output = Textbox(fig, self._gs_status_output, border_on=False)
    self._algthm_output = Textbox(
        fig, self._gs_algthm_output,
        max_display_size=config['algthm_output_max_display_size'],
        log_filename=self._log_filename,
        fontsize=config['algthm_output_fontsize'],
        font_family='monospace')
    self._cost_plotter = MeanPlotter(fig, self._gs_cost_plotter,
                                     color='blue', label='mean cost')
    self._gt_cost_plotter = MeanPlotter(fig, self._gs_gt_cost_plotter,
                                        color='red', label='ground truth cost')
    self._traj_visualizer = Plotter3D(fig, self._gs_traj_visualizer,
                                      num_plots=hyperparams['conditions'])
    if config['image_on']:
        self._image_visualizer = ImageVisualizer(
            fig, self._gs_image_visualizer,
            cropsize=config['image_size'],
            rostopic=config['image_topic'],
            show_overlay_buttons=True)

    # Setup GUI components.
    self._algthm_output.log_text('\n')
    self.set_output_text(hyperparams['info'])
    if config['initial_mode'] == 'run':
        self.run_mode()
    else:
        self.wait_mode()

    # Setup 3D Trajectory Visualizer plot titles and legends
    for m in range(hyperparams['conditions']):
        self._traj_visualizer.set_title(m, 'Condition %d' % (m))
    legend_entries = (
        ('-', 'None', 'green', 'Trajectory Samples'),
        ('-', 'None', 'blue', 'Policy Samples'),
        ('None', 'x', (0.5, 0, 0), 'LG Controller Means'),
        ('-', 'None', 'red', 'LG Controller Distributions'),
    )
    for linestyle, marker, color, label in legend_entries:
        self._traj_visualizer.add_legend(linestyle=linestyle, marker=marker,
                                         color=color, label=label)

    fig.canvas.draw()

    # Display calculating thread
    def animate_status(delay, run_event):
        # Block until the event is set, then step through the three
        # 'Calculating' texts; a stage is skipped as soon as the event has
        # been cleared.
        while True:
            if not run_event.is_set():
                run_event.wait()
            for dots in ('.', '..', '...'):
                if run_event.is_set():
                    self.set_status_text('Calculating' + dots)
                    time.sleep(delay)

    self._calculating_run = threading.Event()
    self._calculating_thread = threading.Thread(
        target=animate_status, args=(1, self._calculating_run))
    self._calculating_thread.daemon = True
    self._calculating_thread.start()
class GPSTrainingGUI(object): def __init__(self, hyperparams): self._hyperparams = hyperparams self._log_filename = self._hyperparams['log_filename'] if 'target_filename' in self._hyperparams: self._target_filename = self._hyperparams['target_filename'] else: self._target_filename = None # GPS Training Status. self.mode = config['initial_mode'] # Modes: run, wait, end, request, process. self.request = None # Requests: stop, reset, go, fail, None. self.err_msg = None self._colors = { 'run': 'cyan', 'wait': 'orange', 'end': 'red', 'stop': 'red', 'reset': 'yellow', 'go': 'green', 'fail': 'magenta', } self._first_update = True # Actions. actions_arr = [ Action('stop', 'stop', self.request_stop, axis_pos=0), Action('reset', 'reset', self.request_reset, axis_pos=1), Action('go', 'go', self.request_go, axis_pos=2), Action('fail', 'fail', self.request_fail, axis_pos=3), ] # Setup figure. plt.ion() plt.rcParams['toolbar'] = 'None' for key in plt.rcParams: if key.startswith('keymap.'): plt.rcParams[key] = '' self._fig = plt.figure(figsize=config['figsize']) self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99, wspace=0, hspace=0) # Assign GUI component locations. self._gs = gridspec.GridSpec(16, 8) self._gs_action_panel = self._gs[0:1, 0:8] self._gs_action_output = self._gs[1:2, 0:4] self._gs_status_output = self._gs[2:3, 0:4] self._gs_cost_plotter = self._gs[1:3, 4:8] self._gs_gt_cost_plotter = self._gs[4:6, 4:8] self._gs_algthm_output = self._gs[3:9, 0:4] if config['image_on']: self._gs_traj_visualizer = self._gs[9:16, 0:4] self._gs_image_visualizer = self._gs[9:16, 4:8] else: self._gs_traj_visualizer = self._gs[9:16, 0:8] # Create GUI components. 
self._action_panel = ActionPanel(self._fig, self._gs_action_panel, 1, 4, actions_arr) self._action_output = Textbox(self._fig, self._gs_action_output, border_on=True) self._status_output = Textbox(self._fig, self._gs_status_output, border_on=False) self._algthm_output = Textbox(self._fig, self._gs_algthm_output, max_display_size=config['algthm_output_max_display_size'], log_filename=self._log_filename, fontsize=config['algthm_output_fontsize'], font_family='monospace') self._cost_plotter = MeanPlotter(self._fig, self._gs_cost_plotter, color='blue', label='mean cost') self._gt_cost_plotter = MeanPlotter(self._fig, self._gs_gt_cost_plotter, color='red', label='ground truth cost') self._traj_visualizer = Plotter3D(self._fig, self._gs_traj_visualizer, num_plots=self._hyperparams['conditions']) if config['image_on']: self._image_visualizer = ImageVisualizer(self._fig, self._gs_image_visualizer, cropsize=config['image_size'], rostopic=config['image_topic'], show_overlay_buttons=True) # Setup GUI components. 
self._algthm_output.log_text('\n') self.set_output_text(self._hyperparams['info']) if config['initial_mode'] == 'run': self.run_mode() else: self.wait_mode() # Setup 3D Trajectory Visualizer plot titles and legends for m in range(self._hyperparams['conditions']): self._traj_visualizer.set_title(m, 'Condition %d' % (m)) self._traj_visualizer.add_legend(linestyle='-', marker='None', color='green', label='Trajectory Samples') self._traj_visualizer.add_legend(linestyle='-', marker='None', color='blue', label='Policy Samples') self._traj_visualizer.add_legend(linestyle='None', marker='x', color=(0.5, 0, 0), label='LG Controller Means') self._traj_visualizer.add_legend(linestyle='-', marker='None', color='red', label='LG Controller Distributions') self._fig.canvas.draw() # Display calculating thread def display_calculating(delay, run_event): while True: if not run_event.is_set(): run_event.wait() if run_event.is_set(): self.set_status_text('Calculating.') time.sleep(delay) if run_event.is_set(): self.set_status_text('Calculating..') time.sleep(delay) if run_event.is_set(): self.set_status_text('Calculating...') time.sleep(delay) self._calculating_run = threading.Event() self._calculating_thread = threading.Thread(target=display_calculating, args=(1, self._calculating_run)) self._calculating_thread.daemon = True self._calculating_thread.start() # GPS Training functions def request_stop(self, event=None): self.request_mode('stop') def request_reset(self, event=None): self.request_mode('reset') def request_go(self, event=None): self.request_mode('go') def request_fail(self, event=None): self.request_mode('fail') def request_mode(self, request): """ Sets the request mode (stop, reset, go, fail). The request is read by gps_main before sampling, and the appropriate action is taken. 
""" self.mode = 'request' self.request = request self.set_action_text(self.request + ' requested') self.set_action_bgcolor(self._colors[self.request], alpha=0.2) def process_mode(self): """ Completes the current request, after it is first read by gps_main. Displays visual confirmation that the request was processed, displays any error messages, and then switches into mode 'run' or 'wait'. """ self.mode = 'process' self.set_action_text(self.request + ' processed') self.set_action_bgcolor(self._colors[self.request], alpha=1.0) if self.err_msg: self.set_action_text(self.request + ' processed' + '\nERROR: ' + self.err_msg) self.err_msg = None time.sleep(1.0) else: time.sleep(0.5) if self.request in ('stop', 'reset', 'fail'): self.wait_mode() elif self.request == 'go': self.run_mode() self.request = None def wait_mode(self): self.mode = 'wait' self.set_action_text('waiting') self.set_action_bgcolor(self._colors[self.mode], alpha=1.0) def run_mode(self): self.mode = 'run' self.set_action_text('running') self.set_action_bgcolor(self._colors[self.mode], alpha=1.0) def end_mode(self): self.mode = 'end' self.set_action_text('ended') self.set_action_bgcolor(self._colors[self.mode], alpha=1.0) def estop(self, event=None): self.set_action_text('estop: NOT IMPLEMENTED') # GUI functions def set_action_text(self, text): self._action_output.set_text(text) self._cost_plotter.draw_ticklabels() # redraw overflow ticklabels self._gt_cost_plotter.draw_ticklabels() def set_action_bgcolor(self, color, alpha=1.0): self._action_output.set_bgcolor(color, alpha) self._cost_plotter.draw_ticklabels() # redraw overflow ticklabels self._gt_cost_plotter.draw_ticklabels() def set_status_text(self, text): self._status_output.set_text(text) self._cost_plotter.draw_ticklabels() # redraw overflow ticklabels self._gt_cost_plotter.draw_ticklabels() def set_output_text(self, text): self._algthm_output.set_text(text) self._cost_plotter.draw_ticklabels() # redraw overflow ticklabels 
self._gt_cost_plotter.draw_ticklabels() def append_output_text(self, text): self._algthm_output.append_text(text) self._cost_plotter.draw_ticklabels() # redraw overflow ticklabels self._gt_cost_plotter.draw_ticklabels() def start_display_calculating(self): self._calculating_run.set() def stop_display_calculating(self): self._calculating_run.clear() def set_image_overlays(self, condition): """ Sets up the image visualizer with what images to overlay if "overlay_initial_image" or "overlay_target_image" is pressed. """ if not config['image_on'] or not self._target_filename: return initial_image = load_data_from_npz(self._target_filename, config['image_overlay_actuator'], str(condition), 'initial', 'image', default=None) target_image = load_data_from_npz(self._target_filename, config['image_overlay_actuator'], str(condition), 'target', 'image', default=None) self._image_visualizer.set_initial_image(initial_image, alpha=config['image_overlay_alpha']) self._image_visualizer.set_target_image(target_image, alpha=config['image_overlay_alpha']) # Iteration update functions def update(self, itr, algorithm, agent, traj_sample_lists, pol_sample_lists): """ After each iteration, update the iteration data output, the cost plot, and the 3D trajectory visualizations (if end effector points exist). 
""" if self._first_update: policy_titles = pol_sample_lists != None self._output_column_titles(algorithm, policy_titles) self._first_update = False costs = [np.mean(np.sum(algorithm.prev[m].cs, axis=1)) for m in range(algorithm.M)] if algorithm._hyperparams['ioc']: gt_costs = [np.mean(np.sum(algorithm.prev[m].cgt, axis=1)) for m in range(algorithm.M)] self._update_iteration_data(itr, algorithm, gt_costs, pol_sample_lists) self._gt_cost_plotter.update(gt_costs, t=itr) else: self._update_iteration_data(itr, algorithm, costs, pol_sample_lists) self._cost_plotter.update(costs, t=itr) if END_EFFECTOR_POINTS in agent.x_data_types: self._update_trajectory_visualizations(algorithm, agent, traj_sample_lists, pol_sample_lists) self._fig.canvas.draw() self._fig.canvas.flush_events() # Fixes bug in Qt4Agg backend # import pdb; pdb.set_trace() def _output_column_titles(self, algorithm, policy_titles=False): """ Setup iteration data column titles: iteration, average cost, and for each condition the mean cost over samples, step size, linear Guassian controller entropies, and initial/final KL divergences for BADMM. 
""" self.set_output_text(self._hyperparams['experiment_name']) if policy_titles: condition_titles = '%3s | %8s %12s' % ('', '', '') itr_data_fields = '%3s | %8s %12s' % ('itr', 'avg_cost', 'avg_pol_cost') else: condition_titles = '%3s | %8s' % ('', '') itr_data_fields = '%3s | %8s' % ('itr', 'avg_cost') for m in range(algorithm.M): condition_titles += ' | %8s %9s %-7d' % ('', 'condition', m) itr_data_fields += ' | %8s %8s %8s' % (' cost ', ' step ', 'entropy ') if algorithm.prev[0].pol_info is not None: condition_titles += ' %8s %8s' % ('', '') itr_data_fields += ' %8s %8s' % ('kl_div_i', 'kl_div_f') if algorithm._hyperparams['ioc'] and not algorithm._hyperparams['learning_from_prior']: condition_titles += ' %8s' % ('') itr_data_fields += ' %8s' % ('kl_div') if algorithm._hyperparams['learning_from_prior']: condition_titles += ' %8s' % ('') itr_data_fields += ' %8s' % ('mean_dist') if policy_titles: condition_titles += ' %8s %8s %8s' % ('', '', '') itr_data_fields += ' %8s %8s %8s' % ('pol_cost', 'kl_div_i', 'kl_div_f') self.append_output_text(condition_titles) self.append_output_text(itr_data_fields) def _update_iteration_data(self, itr, algorithm, costs, pol_sample_lists): """ Update iteration data information: iteration, average cost, and for each condition the mean cost over samples, step size, linear Guassian controller entropies, and initial/final KL divergences for BADMM. 
""" avg_cost = np.mean(costs) if pol_sample_lists is not None: pol_costs = [np.mean([np.sum(algorithm.cost[m].eval(s)[0]) \ for s in pol_sample_lists[m]]) \ for m in range(algorithm.M)] itr_data = '%3d | %8.2f %12.2f' % (itr, avg_cost, np.mean(pol_costs)) else: itr_data = '%3d | %8.2f' % (itr, avg_cost) for m in range(algorithm.M): cost = costs[m] step = algorithm.prev[m].step_mult * algorithm.base_kl_step entropy = 2*np.sum(np.log(np.diagonal(algorithm.prev[m].traj_distr.chol_pol_covar, axis1=1, axis2=2))) itr_data += ' | %8.2f %8.2f %8.2f' % (cost, step, entropy) if algorithm.prev[0].pol_info is not None: kl_div_i = algorithm.prev[m].pol_info.prev_kl[0] kl_div_f = algorithm.prev[m].pol_info.prev_kl[-1] itr_data += ' %8.2f %8.2f' % (kl_div_i, kl_div_f) if algorithm._hyperparams['ioc'] and not algorithm._hyperparams['learning_from_prior']: itr_data += ' %8.2f' % (algorithm.kl_div[itr][m]) if algorithm._hyperparams['learning_from_prior']: itr_data += ' %8.2f' % (algorithm.dists_to_target[itr][m]) if pol_sample_lists is not None: kl_div_i = algorithm.cur[m].pol_info.init_kl.mean() kl_div_f = algorithm.cur[m].pol_info.prev_kl.mean() itr_data += ' %8.2f %8.2f %8.2f' % (pol_costs[m], kl_div_i, kl_div_f) self.append_output_text(itr_data) def _update_trajectory_visualizations(self, algorithm, agent, traj_sample_lists, pol_sample_lists): """ Update 3D trajectory visualizations information: the trajectory samples, policy samples, and linear Gaussian controller means and covariances. 
""" xlim, ylim, zlim = self._calculate_3d_axis_limits(traj_sample_lists, pol_sample_lists) for m in range(algorithm.M): self._traj_visualizer.clear(m) self._traj_visualizer.set_lim(i=m, xlim=xlim, ylim=ylim, zlim=zlim) self._update_samples_plots(traj_sample_lists, m, 'green', 'Trajectory Samples') self._update_linear_gaussian_controller_plots(algorithm, agent, m) if pol_sample_lists: self._update_samples_plots(pol_sample_lists, m, 'blue', 'Policy Samples') self._traj_visualizer.draw() # this must be called explicitly def _calculate_3d_axis_limits(self, traj_sample_lists, pol_sample_lists): """ Calculate the 3D axis limits shared between trajectory plots, based on the minimum and maximum xyz values across all samples. """ all_eept = np.empty((0, 3)) sample_lists = traj_sample_lists if pol_sample_lists: sample_lists += traj_sample_lists for sample_list in sample_lists: for sample in sample_list.get_samples(): ee_pt = sample.get(END_EFFECTOR_POINTS) for i in range(ee_pt.shape[1]/3): ee_pt_i = ee_pt[:, 3*i+0:3*i+3] all_eept = np.r_[all_eept, ee_pt_i] min_xyz = np.amin(all_eept, axis=0) max_xyz = np.amax(all_eept, axis=0) xlim = buffered_axis_limits(min_xyz[0], max_xyz[0], buffer_factor=1.25) ylim = buffered_axis_limits(min_xyz[1], max_xyz[1], buffer_factor=1.25) zlim = buffered_axis_limits(min_xyz[2], max_xyz[2], buffer_factor=1.25) return xlim, ylim, zlim def _update_linear_gaussian_controller_plots(self, algorithm, agent, m): """ Update the linear Guassian controller plots with iteration data, for the mean and covariances of the end effector points. 
""" # Calculate mean and covariance for end effector points eept_idx = agent.get_idx_x(END_EFFECTOR_POINTS) start, end = eept_idx[0], eept_idx[-1] mu, sigma = algorithm.traj_opt.forward(algorithm.prev[m].traj_distr, algorithm.prev[m].traj_info) mu_eept, sigma_eept = mu[:, start:end+1], sigma[:, start:end+1, start:end+1] # Linear Gaussian Controller Distributions (Red) for i in range(mu_eept.shape[1]/3): mu, sigma = mu_eept[:, 3*i+0:3*i+3], sigma_eept[:, 3*i+0:3*i+3, 3*i+0:3*i+3] self._traj_visualizer.plot_3d_gaussian(i=m, mu=mu, sigma=sigma, edges=100, linestyle='-', linewidth=1.0, color='red', alpha=0.15, label='LG Controller Distributions') # Linear Gaussian Controller Means (Dark Red) for i in range(mu_eept.shape[1]/3): mu = mu_eept[:, 3*i+0:3*i+3] self._traj_visualizer.plot_3d_points(i=m, points=mu, linestyle='None', marker='x', markersize=5.0, markeredgewidth=1.0, color=(0.5, 0, 0), alpha=1.0, label='LG Controller Means') def _update_samples_plots(self, sample_lists, m, color, label): """ Update the samples plots with iteration data, for the trajectory samples and the policy samples. """ samples = sample_lists[m].get_samples() for sample in samples: ee_pt = sample.get(END_EFFECTOR_POINTS) for i in range(ee_pt.shape[1]/3): ee_pt_i = ee_pt[:, 3*i+0:3*i+3] self._traj_visualizer.plot_3d_points(m, ee_pt_i, color=color, label=label) def save_figure(self, filename): self._fig.savefig(filename)
def __init__(self, hyperparams, agent):
    """
    Build the target setup GUI: state, actions, figure and components.

    Args:
        hyperparams: Overrides layered on top of copies of
            `common_config` and `target_setup_config`.
        agent: Agent object, stored as `self._agent`.
    """
    # Later updates win: common_config < target_setup_config < hyperparams.
    self._hyperparams = copy.deepcopy(common_config)
    self._hyperparams.update(copy.deepcopy(target_setup_config))
    self._hyperparams.update(hyperparams)
    self._agent = agent

    self._log_filename = self._hyperparams['log_filename']
    self._target_filename = self._hyperparams['target_filename']

    self._num_targets = self._hyperparams['num_targets']
    self._actuator_types = self._hyperparams['actuator_types']
    self._actuator_names = self._hyperparams['actuator_names']
    self._num_actuators = len(self._actuator_types)

    # Target Setup Status.
    self._target_number = 0
    self._actuator_number = 0
    self._actuator_type = self._actuator_types[self._actuator_number]
    self._actuator_name = self._actuator_names[self._actuator_number]
    self._initial_position = ('unknown', 'unknown', 'unknown')
    self._target_position = ('unknown', 'unknown', 'unknown')
    self._initial_image = None
    self._target_image = None
    self._mannequin_mode = False

    # Actions.
    actions_arr = [
        Action('ptn', 'prev_target_number', self.prev_target_number, axis_pos=0),
        Action('ntn', 'next_target_number', self.next_target_number, axis_pos=1),
        Action('pat', 'prev_actuator_type', self.prev_actuator_type, axis_pos=2),
        Action('nat', 'next_actuator_type', self.next_actuator_type, axis_pos=3),

        Action('sip', 'set_initial_position', self.set_initial_position, axis_pos=4),
        Action('stp', 'set_target_position', self.set_target_position, axis_pos=5),
        Action('sii', 'set_initial_image', self.set_initial_image, axis_pos=6),
        Action('sti', 'set_target_image', self.set_target_image, axis_pos=7),

        Action('mti', 'move_to_initial', self.move_to_initial, axis_pos=8),
        Action('mtt', 'move_to_target', self.move_to_target, axis_pos=9),
        Action('rc', 'relax_controller', self.relax_controller, axis_pos=10),
        Action('mm', 'mannequin_mode', self.mannequin_mode, axis_pos=11),
    ]

    #TODO: Is it possible to merge this code with
    #      GPSTrainingGUI.__init__?
    self._actions = {action._key: action for action in actions_arr}
    keyboard_bindings = self._hyperparams['keyboard_bindings']
    ps3_bindings = self._hyperparams['ps3_bindings']
    # items() rather than iteritems(): identical here on Python 2 and the
    # only spelling that exists on Python 3.
    for key, action in self._actions.items():
        if key in keyboard_bindings:
            action._kb = keyboard_bindings[key]
        if key in ps3_bindings:
            action._pb = ps3_bindings[key]

    # GUI Components.
    plt.ion()
    plt.rcParams['toolbar'] = 'None'
    # Remove 's' keyboard shortcut for saving.
    plt.rcParams['keymap.save'] = ''

    self._fig = plt.figure(figsize=(12, 12))
    self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                              wspace=0, hspace=0)

    # Assign GUI component locations.
    self._gs = gridspec.GridSpec(4, 4)
    self._gs_action_axis = self._gs[0:1, 0:4]
    self._gs_target_output = self._gs[1:3, 0:2]
    self._gs_initial_image_visualizer = self._gs[3:4, 0:1]
    self._gs_target_image_visualizer = self._gs[3:4, 1:2]
    self._gs_action_output = self._gs[1:2, 2:4]
    self._gs_image_visualizer = self._gs[2:4, 2:4]

    # Create GUI components.
    self._action_axis = ActionAxis(
        self._fig, self._gs_action_axis, 3, 4, self._actions,
        ps3_process_rate=self._hyperparams['ps3_process_rate'],
        ps3_topic=self._hyperparams['ps3_topic'],
        ps3_button=self._hyperparams['ps3_button'],
        inverted_ps3_button=self._hyperparams['inverted_ps3_button'])
    self._target_output = OutputAxis(
        self._fig, self._gs_target_output,
        log_filename=self._log_filename, fontsize=10)
    self._initial_image_visualizer = ImageVisualizer(
        self._fig, self._gs_initial_image_visualizer)
    self._target_image_visualizer = ImageVisualizer(
        self._fig, self._gs_target_image_visualizer)
    self._action_output = OutputAxis(self._fig, self._gs_action_output)
    self._image_visualizer = ImageVisualizer(
        self._fig, self._gs_image_visualizer, cropsize=(240, 240),
        rostopic=self._hyperparams['image_topic'], show_overlay_buttons=True)

    # Setup GUI components.
    self.reload_positions()
    self.update_target_text()
    self.set_action_text('Press an action to begin.')
    self.set_action_bgcolor('white')

    self._fig.canvas.draw()
class TargetSetupGUI(object):
    """
    Target setup GUI class.

    Lets the user step through target numbers and actuators, record the
    initial/target pose and image for the current selection, and drive the
    arm back to a recorded pose. Poses and images are stored in
    `self._target_filename` through the *_npz helper functions.
    """

    def __init__(self, hyperparams, agent):
        """
        Build the GUI state, the action table, the figure and its components.

        Args:
            hyperparams: Overrides layered on top of copies of
                `common_config` and `target_setup_config`.
            agent: Agent object used to read the current pose
                (`get_data`) and to command the arm (`reset_arm`,
                `relax_arm`).
        """
        # Later updates win: common_config < target_setup_config < hyperparams.
        self._hyperparams = copy.deepcopy(common_config)
        self._hyperparams.update(copy.deepcopy(target_setup_config))
        self._hyperparams.update(hyperparams)
        self._agent = agent

        self._log_filename = self._hyperparams['log_filename']
        self._target_filename = self._hyperparams['target_filename']

        self._num_targets = self._hyperparams['num_targets']
        self._actuator_types = self._hyperparams['actuator_types']
        self._actuator_names = self._hyperparams['actuator_names']
        self._num_actuators = len(self._actuator_types)

        # Target Setup Status.
        self._target_number = 0
        self._actuator_number = 0
        self._actuator_type = self._actuator_types[self._actuator_number]
        self._actuator_name = self._actuator_names[self._actuator_number]
        self._initial_position = ('unknown', 'unknown', 'unknown')
        self._target_position = ('unknown', 'unknown', 'unknown')
        self._initial_image = None
        self._target_image = None
        self._mannequin_mode = False

        # Actions.
        actions_arr = [
            Action('ptn', 'prev_target_number', self.prev_target_number, axis_pos=0),
            Action('ntn', 'next_target_number', self.next_target_number, axis_pos=1),
            Action('pat', 'prev_actuator_type', self.prev_actuator_type, axis_pos=2),
            Action('nat', 'next_actuator_type', self.next_actuator_type, axis_pos=3),

            Action('sip', 'set_initial_position', self.set_initial_position, axis_pos=4),
            Action('stp', 'set_target_position', self.set_target_position, axis_pos=5),
            Action('sii', 'set_initial_image', self.set_initial_image, axis_pos=6),
            Action('sti', 'set_target_image', self.set_target_image, axis_pos=7),

            Action('mti', 'move_to_initial', self.move_to_initial, axis_pos=8),
            Action('mtt', 'move_to_target', self.move_to_target, axis_pos=9),
            Action('rc', 'relax_controller', self.relax_controller, axis_pos=10),
            Action('mm', 'mannequin_mode', self.mannequin_mode, axis_pos=11),
        ]

        #TODO: Is it possible to merge this code with
        #      GPSTrainingGUI.__init__?
        self._actions = {action._key: action for action in actions_arr}
        keyboard_bindings = self._hyperparams['keyboard_bindings']
        ps3_bindings = self._hyperparams['ps3_bindings']
        # items() rather than iteritems(): identical here on Python 2 and
        # the only spelling that exists on Python 3.
        for key, action in self._actions.items():
            if key in keyboard_bindings:
                action._kb = keyboard_bindings[key]
            if key in ps3_bindings:
                action._pb = ps3_bindings[key]

        # GUI Components.
        plt.ion()
        plt.rcParams['toolbar'] = 'None'
        # Remove 's' keyboard shortcut for saving.
        plt.rcParams['keymap.save'] = ''

        self._fig = plt.figure(figsize=(12, 12))
        self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                                  wspace=0, hspace=0)

        # Assign GUI component locations.
        self._gs = gridspec.GridSpec(4, 4)
        self._gs_action_axis = self._gs[0:1, 0:4]
        self._gs_target_output = self._gs[1:3, 0:2]
        self._gs_initial_image_visualizer = self._gs[3:4, 0:1]
        self._gs_target_image_visualizer = self._gs[3:4, 1:2]
        self._gs_action_output = self._gs[1:2, 2:4]
        self._gs_image_visualizer = self._gs[2:4, 2:4]

        # Create GUI components.
        self._action_axis = ActionAxis(
            self._fig, self._gs_action_axis, 3, 4, self._actions,
            ps3_process_rate=self._hyperparams['ps3_process_rate'],
            ps3_topic=self._hyperparams['ps3_topic'],
            ps3_button=self._hyperparams['ps3_button'],
            inverted_ps3_button=self._hyperparams['inverted_ps3_button'])
        self._target_output = OutputAxis(
            self._fig, self._gs_target_output,
            log_filename=self._log_filename, fontsize=10)
        self._initial_image_visualizer = ImageVisualizer(
            self._fig, self._gs_initial_image_visualizer)
        self._target_image_visualizer = ImageVisualizer(
            self._fig, self._gs_target_image_visualizer)
        self._action_output = OutputAxis(self._fig, self._gs_action_output)
        self._image_visualizer = ImageVisualizer(
            self._fig, self._gs_image_visualizer, cropsize=(240, 240),
            rostopic=self._hyperparams['image_topic'], show_overlay_buttons=True)

        # Setup GUI components.
        self.reload_positions()
        self.update_target_text()
        self.set_action_text('Press an action to begin.')
        self.set_action_bgcolor('white')

        self._fig.canvas.draw()

    # Target Setup Functions.
    # Every action accepts an optional `event` argument, which it ignores.
    def prev_target_number(self, event=None):
        """ Select the previous target number, wrapping around at zero. """
        self._step_target_number('prev_target_number', -1)

    def next_target_number(self, event=None):
        """ Select the next target number, wrapping around at the end. """
        self._step_target_number('next_target_number', 1)

    def prev_actuator_type(self, event=None):
        """ Select the previous actuator, wrapping around at zero. """
        self._step_actuator('prev_actuator_type', -1)

    def next_actuator_type(self, event=None):
        """ Select the next actuator, wrapping around at the end. """
        self._step_actuator('next_actuator_type', 1)

    def set_initial_position(self, event=None):
        """ Record the agent's current pose as the initial position. """
        position = self._read_current_pose('set_initial_position')
        if position is None:
            return
        self._initial_position = position
        save_pose_to_npz(self._target_filename, self._actuator_name,
                         str(self._target_number), 'initial', self._initial_position)
        self.update_target_text()
        self.set_action_status_message('set_initial_position', 'completed',
                message='initial position =\n %s' %
                        self.position_to_str(self._initial_position))

    def set_target_position(self, event=None):
        """ Record the agent's current pose as the target position. """
        position = self._read_current_pose('set_target_position')
        if position is None:
            return
        self._target_position = position
        save_pose_to_npz(self._target_filename, self._actuator_name,
                         str(self._target_number), 'target', self._target_position)
        self.update_target_text()
        self.set_action_status_message('set_target_position', 'completed',
                message='target position =\n %s' %
                        self.position_to_str(self._target_position))

    def set_initial_image(self, event=None):
        """ Record the currently displayed image as the initial image. """
        self.set_action_status_message('set_initial_image', 'requested')
        self._initial_image = self._image_visualizer.get_current_image()
        if self._initial_image is None:
            self.set_action_status_message('set_initial_image', 'failed',
                    message='no image available')
            return
        save_data_to_npz(self._target_filename, self._actuator_name,
                         str(self._target_number), 'initial', 'image',
                         self._initial_image)
        self.update_target_text()
        self.set_action_status_message('set_initial_image', 'completed',
                message='initial image =\n %s' % str(self._initial_image))

    def set_target_image(self, event=None):
        """ Record the currently displayed image as the target image. """
        self.set_action_status_message('set_target_image', 'requested')
        self._target_image = self._image_visualizer.get_current_image()
        if self._target_image is None:
            self.set_action_status_message('set_target_image', 'failed',
                    message='no image available')
            return
        save_data_to_npz(self._target_filename, self._actuator_name,
                         str(self._target_number), 'target', 'image',
                         self._target_image)
        self.update_target_text()
        self.set_action_status_message('set_target_image', 'completed',
                message='target image =\n %s' % str(self._target_image))

    def move_to_initial(self, event=None):
        """ Reset the arm to the joint angles of the initial position. """
        self._move_to('move_to_initial', 'initial', self._initial_position)

    def move_to_target(self, event=None):
        """ Reset the arm to the joint angles of the target position. """
        self._move_to('move_to_target', 'target', self._target_position)

    def relax_controller(self, event=None):
        """ Ask the agent to relax the currently selected actuator. """
        self.set_action_status_message('relax_controller', 'requested')
        self._agent.relax_arm(self._actuator_type)
        self.set_action_status_message('relax_controller', 'completed',
                message='actuator name: %s' % self._actuator_name)

    def mannequin_mode(self, event=None):
        """
        Toggle mannequin mode by running the matching roslaunch command.
        subprocess.call waits for the command to finish.
        """
        self.set_action_status_message('mannequin_mode', 'requested')
        if not self._mannequin_mode:
            subprocess.call(['roslaunch', 'pr2_mannequin_mode', 'pr2_mannequin_mode.launch'])
            self._mannequin_mode = True
            outcome = 'mannequin mode toggled on'
        else:
            subprocess.call(['roslaunch', 'gps_agent_pkg', 'pr2_real.launch'])
            self._mannequin_mode = False
            outcome = 'mannequin mode toggled off'
        self.set_action_status_message('mannequin_mode', 'completed', message=outcome)

    # Shared implementations of the paired actions above.
    def _step_target_number(self, action, step):
        """ Shift the target number by `step` (modulo count) and refresh. """
        self.set_action_status_message(action, 'requested')
        self._target_number = (self._target_number + step) % self._num_targets
        self.reload_positions()
        self.update_target_text()
        self.set_action_text()
        self.set_action_status_message(action, 'completed',
                message='target number = %d' % self._target_number)

    def _step_actuator(self, action, step):
        """ Shift the actuator index by `step` (modulo count) and refresh. """
        self.set_action_status_message(action, 'requested')
        self._actuator_number = (self._actuator_number + step) % self._num_actuators
        self._actuator_type = self._actuator_types[self._actuator_number]
        self._actuator_name = self._actuator_names[self._actuator_number]
        self.reload_positions()
        self.update_target_text()
        self.set_action_text()
        self.set_action_status_message(action, 'completed',
                message='actuator name = %s' % self._actuator_name)

    def _read_current_pose(self, action):
        """
        Read (joint angles, end effector positions, end effector rotations)
        from the agent. Reports `action` as failed and returns None when
        the agent raises TimeoutException.
        """
        self.set_action_status_message(action, 'requested')
        try:
            sample = self._agent.get_data(arm=self._actuator_type)
        except TimeoutException:
            self.set_action_status_message(action, 'failed',
                    message='TimeoutException while retrieving sample')
            return None
        return (sample.get(JOINT_ANGLES),
                sample.get(END_EFFECTOR_POSITIONS),
                sample.get(END_EFFECTOR_ROTATIONS))

    def _move_to(self, action, label, position):
        """ Reset the arm to the joint angles stored in position[0]. """
        ja = position[0]
        self.set_action_status_message(action, 'requested')
        self._agent.reset_arm(self._actuator_type, JOINT_SPACE, ja)
        self.set_action_status_message(action, 'completed',
                message='%s position: %s' % (label, str(ja)))

    # GUI functions.
    def update_target_text(self):
        """ Refresh the target text panel and the three image displays. """
        np.set_printoptions(precision=3, suppress=True)
        text = (
            'target number = %s\n' % str(self._target_number) +
            'actuator name = %s\n' % str(self._actuator_name) +
            '\ninitial position\n%s' % self.position_to_str(self._initial_position) +
            '\ntarget position\n%s' % self.position_to_str(self._target_position) +
            '\ninitial image (left) =\n%s\n' % str(self._initial_image) +
            '\ntarget image (right) =\n%s\n' % str(self._target_image)
        )
        self._target_output.set_text(text)

        self._initial_image_visualizer.update(self._initial_image)
        self._target_image_visualizer.update(self._target_image)
        self._image_visualizer.set_initial_image(self._initial_image, alpha=0.3)
        self._image_visualizer.set_target_image(self._target_image, alpha=0.3)

    def position_to_str(self, position):
        """
        Format a (joint angles, ee positions, ee rotations) triple as text.
        Note that this changes numpy's global print options.
        """
        np.set_printoptions(precision=3, suppress=True)
        ja, ee_pos, ee_rot = position
        return ('joint angles =\n%s\n' % ja +
                'end effector positions =\n%s\n' % ee_pos +
                'end effector rotations =\n%s\n' % ee_rot)

    def set_action_status_message(self, action, status, message=None):
        """
        Show '<action>: <status>' plus an optional message, and color the
        panel by status: requested=yellow, completed=green, failed=red.
        Any other status leaves the color unchanged.
        """
        text = action + ': ' + status
        if message:
            text += '\n\n' + message
        self.set_action_text(text)
        if status == 'requested':
            self.set_action_bgcolor('yellow')
        elif status == 'completed':
            self.set_action_bgcolor('green')
        elif status == 'failed':
            self.set_action_bgcolor('red')

    def set_action_text(self, text=''):
        """ Set the action output text; the default clears it. """
        self._action_output.set_text(text)

    def set_action_bgcolor(self, color, alpha=1.0):
        """ Set the action output background color and opacity. """
        self._action_output.set_bgcolor(color, alpha)

    def reload_positions(self):
        """
        Reload the poses and images stored for the current actuator name
        and target number. Images are loaded with default=None.
        """
        self._initial_position = load_pose_from_npz(self._target_filename,
                self._actuator_name, str(self._target_number), 'initial')
        self._target_position = load_pose_from_npz(self._target_filename,
                self._actuator_name, str(self._target_number), 'target')
        self._initial_image = load_data_from_npz(self._target_filename,
                self._actuator_name, str(self._target_number), 'initial', 'image',
                default=None)
        self._target_image = load_data_from_npz(self._target_filename,
                self._actuator_name, str(self._target_number), 'target', 'image',
                default=None)
textbox = Textbox(fig, gs[1], max_display_size=10, log_filename=None) run_demo(demo_textbox) # Image Visualizer def demo_image_visualizer(): im = np.zeros((5, 5, 3)) while True: i = random.randint(0, im.shape[0] - 1) j = random.randint(0, im.shape[1] - 1) k = random.randint(0, im.shape[2] - 1) im[i, j, k] = (im[i, j, k] + random.randint(0, 255)) % 256 image_visualizer.update(im) time.sleep(5e-3) image_visualizer = ImageVisualizer(fig, gs[2], cropsize=(3, 3)) run_demo(demo_image_visualizer) # Realtime Plotter def demo_realtime_plotter(): i, j = 0, 0 while True: i += random.randint(-10, 10) j += random.randint(-10, 10) data = [i, j, i + j, i - j] mean = np.mean(data) realtime_plotter.update(data + [mean]) time.sleep(5e-3) realtime_plotter = RealtimePlotter(fig, gs[3], labels=['i', 'j', 'i+j', 'i-j', 'mean'],