class GPSTrainingGUI(object):
    """
    Matplotlib GUI shown while GPS training runs.

    Contains an action panel (stop / reset / go / fail requests), action and
    status text boxes, an algorithm-output text box (also logged to
    ``hyperparams['log_filename']``), a mean-cost plot, a ground-truth-cost
    plot, a 3D trajectory visualizer and, when ``config['image_on']``, an image
    visualizer. Layout and appearance settings come from the module-level
    ``config`` dict.
    """

    def __init__(self, hyperparams):
        """
        Build the figure and GUI components and start the background thread
        that animates the "Calculating..." status text.

        Args:
            hyperparams: mapping providing 'log_filename', 'conditions',
                'info' and optionally 'target_filename'.
        """
        self._hyperparams = hyperparams
        self._log_filename = self._hyperparams['log_filename']
        if 'target_filename' in self._hyperparams:
            self._target_filename = self._hyperparams['target_filename']
        else:
            self._target_filename = None

        # GPS Training Status.
        self.mode = config['initial_mode']  # Modes: run, wait, end, request, process.
        self.request = None                 # Requests: stop, reset, go, fail, None.
        self.err_msg = None
        # Background colors of the action output, keyed by mode or request.
        self._colors = {
            'run': 'cyan',
            'wait': 'orange',
            'end': 'red',

            'stop': 'red',
            'reset': 'yellow',
            'go': 'green',
            'fail': 'magenta',
        }
        self._first_update = True

        # Actions.
        actions_arr = [
            Action('stop', 'stop', self.request_stop, axis_pos=0),
            Action('reset', 'reset', self.request_reset, axis_pos=1),
            Action('go', 'go', self.request_go, axis_pos=2),
            Action('fail', 'fail', self.request_fail, axis_pos=3),
        ]

        # Setup figure: interactive, no toolbar, default key bindings disabled.
        plt.ion()
        plt.rcParams['toolbar'] = 'None'
        for key in plt.rcParams:
            if key.startswith('keymap.'):
                plt.rcParams[key] = ''

        self._fig = plt.figure(figsize=config['figsize'])
        self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                                  wspace=0, hspace=0)

        # Assign GUI component locations on a 16x8 grid.
        self._gs = gridspec.GridSpec(16, 8)
        self._gs_action_panel = self._gs[0:1, 0:8]
        self._gs_action_output = self._gs[1:2, 0:4]
        self._gs_status_output = self._gs[2:3, 0:4]
        self._gs_cost_plotter = self._gs[1:3, 4:8]
        self._gs_gt_cost_plotter = self._gs[4:6, 4:8]
        self._gs_algthm_output = self._gs[3:9, 0:4]
        if config['image_on']:
            self._gs_traj_visualizer = self._gs[9:16, 0:4]
            self._gs_image_visualizer = self._gs[9:16, 4:8]
        else:
            self._gs_traj_visualizer = self._gs[9:16, 0:8]

        # Create GUI components.
        self._action_panel = ActionPanel(self._fig, self._gs_action_panel, 1, 4,
                                         actions_arr)
        self._action_output = Textbox(self._fig, self._gs_action_output,
                                      border_on=True)
        self._status_output = Textbox(self._fig, self._gs_status_output,
                                      border_on=False)
        self._algthm_output = Textbox(
            self._fig, self._gs_algthm_output,
            max_display_size=config['algthm_output_max_display_size'],
            log_filename=self._log_filename,
            fontsize=config['algthm_output_fontsize'],
            font_family='monospace')
        self._cost_plotter = MeanPlotter(self._fig, self._gs_cost_plotter,
                                         color='blue', label='mean cost')
        self._gt_cost_plotter = MeanPlotter(self._fig, self._gs_gt_cost_plotter,
                                            color='red', label='ground truth cost')
        self._traj_visualizer = Plotter3D(
            self._fig, self._gs_traj_visualizer,
            num_plots=self._hyperparams['conditions'])
        if config['image_on']:
            self._image_visualizer = ImageVisualizer(
                self._fig, self._gs_image_visualizer,
                cropsize=config['image_size'],
                rostopic=config['image_topic'],
                show_overlay_buttons=True)

        # Setup GUI components.
        self._algthm_output.log_text('\n')
        self.set_output_text(self._hyperparams['info'])
        if config['initial_mode'] == 'run':
            self.run_mode()
        else:
            self.wait_mode()

        # Setup 3D Trajectory Visualizer plot titles and legends.
        for m in range(self._hyperparams['conditions']):
            self._traj_visualizer.set_title(m, 'Condition %d' % (m))
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='green', label='Trajectory Samples')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='blue', label='Policy Samples')
        self._traj_visualizer.add_legend(linestyle='None', marker='x',
                                         color=(0.5, 0, 0), label='LG Controller Means')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='red', label='LG Controller Distributions')

        self._fig.canvas.draw()

        # Display calculating thread: while run_event is set, cycles the status
        # text through 'Calculating.', '..', '...' every `delay` seconds.
        def display_calculating(delay, run_event):
            while True:
                # Event.wait() returns immediately when the event is already set.
                run_event.wait()
                for text in ('Calculating.', 'Calculating..', 'Calculating...'):
                    if run_event.is_set():
                        self.set_status_text(text)
                        time.sleep(delay)

        self._calculating_run = threading.Event()
        self._calculating_thread = threading.Thread(
            target=display_calculating, args=(1, self._calculating_run))
        # Daemon so the animation thread never keeps the process alive.
        self._calculating_thread.daemon = True
        self._calculating_thread.start()

    # GPS Training functions
    def request_stop(self, event=None):
        self.request_mode('stop')

    def request_reset(self, event=None):
        self.request_mode('reset')

    def request_go(self, event=None):
        self.request_mode('go')

    def request_fail(self, event=None):
        self.request_mode('fail')

    def request_mode(self, request):
        """
        Sets the request mode (stop, reset, go, fail). The request is read by
        gps_main before sampling, and the appropriate action is taken.
        """
        self.mode = 'request'
        self.request = request
        self.set_action_text(self.request + ' requested')
        self.set_action_bgcolor(self._colors[self.request], alpha=0.2)

    def process_mode(self):
        """
        Completes the current request, after it is first read by gps_main.
        Displays visual confirmation that the request was processed,
        displays any error messages, and then switches into mode 'run' or 'wait'.
        """
        self.mode = 'process'
        self.set_action_text(self.request + ' processed')
        self.set_action_bgcolor(self._colors[self.request], alpha=1.0)
        if self.err_msg:
            self.set_action_text(self.request + ' processed' + '\nERROR: ' +
                                 self.err_msg)
            self.err_msg = None
            time.sleep(1.0)  # keep the error visible a little longer
        else:
            time.sleep(0.5)
        if self.request in ('stop', 'reset', 'fail'):
            self.wait_mode()
        elif self.request == 'go':
            self.run_mode()
        self.request = None

    def wait_mode(self):
        self.mode = 'wait'
        self.set_action_text('waiting')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def run_mode(self):
        self.mode = 'run'
        self.set_action_text('running')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def end_mode(self):
        self.mode = 'end'
        self.set_action_text('ended')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def estop(self, event=None):
        self.set_action_text('estop: NOT IMPLEMENTED')

    # GUI functions
    def _redraw_cost_ticklabels(self):
        """Redraw the cost plots' tick labels, which can overflow into the text boxes."""
        self._cost_plotter.draw_ticklabels()
        self._gt_cost_plotter.draw_ticklabels()

    def set_action_text(self, text):
        self._action_output.set_text(text)
        self._redraw_cost_ticklabels()

    def set_action_bgcolor(self, color, alpha=1.0):
        self._action_output.set_bgcolor(color, alpha)
        self._redraw_cost_ticklabels()

    def set_status_text(self, text):
        self._status_output.set_text(text)
        self._redraw_cost_ticklabels()

    def set_output_text(self, text):
        self._algthm_output.set_text(text)
        self._redraw_cost_ticklabels()

    def append_output_text(self, text):
        self._algthm_output.append_text(text)
        self._redraw_cost_ticklabels()

    def start_display_calculating(self):
        self._calculating_run.set()

    def stop_display_calculating(self):
        self._calculating_run.clear()

    def set_image_overlays(self, condition):
        """
        Sets up the image visualizer with what images to overlay if
        "overlay_initial_image" or "overlay_target_image" is pressed.
        Does nothing when images are disabled or no target file is configured.
        """
        if not config['image_on'] or not self._target_filename:
            return
        initial_image = load_data_from_npz(self._target_filename,
                                           config['image_overlay_actuator'],
                                           str(condition), 'initial', 'image',
                                           default=None)
        target_image = load_data_from_npz(self._target_filename,
                                          config['image_overlay_actuator'],
                                          str(condition), 'target', 'image',
                                          default=None)
        self._image_visualizer.set_initial_image(
            initial_image, alpha=config['image_overlay_alpha'])
        self._image_visualizer.set_target_image(
            target_image, alpha=config['image_overlay_alpha'])

    # Iteration update functions
    def update(self, itr, algorithm, agent, traj_sample_lists, pol_sample_lists):
        """
        After each iteration, update the iteration data output, the cost plot,
        and the 3D trajectory visualizations (if end effector points exist).

        When ``algorithm._hyperparams['ioc']`` is true, the text output reports
        ground-truth costs and the ground-truth cost plot is updated as well.
        """
        if self._first_update:
            policy_titles = pol_sample_lists is not None
            self._output_column_titles(algorithm, policy_titles)
            self._first_update = False

        costs = [np.mean(np.sum(algorithm.prev[m].cs, axis=1))
                 for m in range(algorithm.M)]
        if algorithm._hyperparams['ioc']:
            gt_costs = [np.mean(np.sum(algorithm.prev[m].cgt, axis=1))
                        for m in range(algorithm.M)]
            self._update_iteration_data(itr, algorithm, gt_costs, pol_sample_lists)
            self._gt_cost_plotter.update(gt_costs, t=itr)
        else:
            self._update_iteration_data(itr, algorithm, costs, pol_sample_lists)
        self._cost_plotter.update(costs, t=itr)
        if END_EFFECTOR_POINTS in agent.x_data_types:
            self._update_trajectory_visualizations(algorithm, agent,
                                                   traj_sample_lists,
                                                   pol_sample_lists)

        self._fig.canvas.draw()
        self._fig.canvas.flush_events()  # Fixes bug in Qt4Agg backend

    def _output_column_titles(self, algorithm, policy_titles=False):
        """
        Setup iteration data column titles: iteration, average cost, and for
        each condition the mean cost over samples, step size, linear Gaussian
        controller entropies, and initial/final KL divergences for BADMM.
        """
        self.set_output_text(self._hyperparams['experiment_name'])
        if policy_titles:
            condition_titles = '%3s | %8s %12s' % ('', '', '')
            itr_data_fields = '%3s | %8s %12s' % ('itr', 'avg_cost', 'avg_pol_cost')
        else:
            condition_titles = '%3s | %8s' % ('', '')
            itr_data_fields = '%3s | %8s' % ('itr', 'avg_cost')
        for m in range(algorithm.M):
            condition_titles += ' | %8s %9s %-7d' % ('', 'condition', m)
            itr_data_fields += ' | %8s %8s %8s' % (' cost ', ' step ', 'entropy ')
            if algorithm.prev[0].pol_info is not None:
                condition_titles += ' %8s %8s' % ('', '')
                itr_data_fields += ' %8s %8s' % ('kl_div_i', 'kl_div_f')
            if algorithm._hyperparams['ioc'] and not algorithm._hyperparams['learning_from_prior']:
                condition_titles += ' %8s' % ('')
                itr_data_fields += ' %8s' % ('kl_div')
            if algorithm._hyperparams['learning_from_prior']:
                condition_titles += ' %8s' % ('')
                itr_data_fields += ' %8s' % ('mean_dist')
            if policy_titles:
                condition_titles += ' %8s %8s %8s' % ('', '', '')
                itr_data_fields += ' %8s %8s %8s' % ('pol_cost', 'kl_div_i', 'kl_div_f')
        self.append_output_text(condition_titles)
        self.append_output_text(itr_data_fields)

    def _update_iteration_data(self, itr, algorithm, costs, pol_sample_lists):
        """
        Update iteration data information: iteration, average cost, and for
        each condition the mean cost over samples, step size, linear Gaussian
        controller entropies, and initial/final KL divergences for BADMM.
        """
        avg_cost = np.mean(costs)
        if pol_sample_lists is not None:
            pol_costs = [np.mean([np.sum(algorithm.cost[m].eval(s)[0])
                                  for s in pol_sample_lists[m]])
                         for m in range(algorithm.M)]
            itr_data = '%3d | %8.2f %12.2f' % (itr, avg_cost, np.mean(pol_costs))
        else:
            itr_data = '%3d | %8.2f' % (itr, avg_cost)
        for m in range(algorithm.M):
            cost = costs[m]
            step = algorithm.prev[m].step_mult * algorithm.base_kl_step
            # 2 * sum(log(diag(chol))) == log-determinant of each covariance, summed.
            entropy = 2 * np.sum(np.log(np.diagonal(
                algorithm.prev[m].traj_distr.chol_pol_covar, axis1=1, axis2=2)))
            itr_data += ' | %8.2f %8.2f %8.2f' % (cost, step, entropy)
            if algorithm.prev[0].pol_info is not None:
                kl_div_i = algorithm.prev[m].pol_info.prev_kl[0]
                kl_div_f = algorithm.prev[m].pol_info.prev_kl[-1]
                itr_data += ' %8.2f %8.2f' % (kl_div_i, kl_div_f)
            if algorithm._hyperparams['ioc'] and not algorithm._hyperparams['learning_from_prior']:
                itr_data += ' %8.2f' % (algorithm.kl_div[itr][m])
            if algorithm._hyperparams['learning_from_prior']:
                itr_data += ' %8.2f' % (algorithm.dists_to_target[itr][m])
            if pol_sample_lists is not None:
                kl_div_i = algorithm.cur[m].pol_info.init_kl.mean()
                kl_div_f = algorithm.cur[m].pol_info.prev_kl.mean()
                itr_data += ' %8.2f %8.2f %8.2f' % (pol_costs[m], kl_div_i, kl_div_f)
        self.append_output_text(itr_data)

    def _update_trajectory_visualizations(self, algorithm, agent,
                                          traj_sample_lists, pol_sample_lists):
        """
        Update 3D trajectory visualizations information: the trajectory samples,
        policy samples, and linear Gaussian controller means and covariances.
        """
        xlim, ylim, zlim = self._calculate_3d_axis_limits(traj_sample_lists,
                                                          pol_sample_lists)
        for m in range(algorithm.M):
            self._traj_visualizer.clear(m)
            self._traj_visualizer.set_lim(i=m, xlim=xlim, ylim=ylim, zlim=zlim)
            self._update_samples_plots(traj_sample_lists, m, 'green',
                                       'Trajectory Samples')
            self._update_linear_gaussian_controller_plots(algorithm, agent, m)
            if pol_sample_lists:
                self._update_samples_plots(pol_sample_lists, m, 'blue',
                                           'Policy Samples')
        self._traj_visualizer.draw()  # this must be called explicitly

    def _calculate_3d_axis_limits(self, traj_sample_lists, pol_sample_lists):
        """
        Calculate the 3D axis limits shared between trajectory plots,
        based on the minimum and maximum xyz values across all trajectory
        samples and, when given, all policy samples.

        Returns:
            (xlim, ylim, zlim) as produced by buffered_axis_limits.
        """
        # Copy instead of `+=` so the caller's list is never mutated.
        sample_lists = list(traj_sample_lists)
        if pol_sample_lists:
            sample_lists.extend(pol_sample_lists)

        points = []
        for sample_list in sample_lists:
            for sample in sample_list.get_samples():
                ee_pt = sample.get(END_EFFECTOR_POINTS)
                # Every consecutive column triple is one xyz point; floor
                # division drops trailing columns that do not form a full triple.
                num_points = ee_pt.shape[1] // 3
                points.append(ee_pt[:, :3 * num_points].reshape(-1, 3))
        # Concatenate once instead of growing an array inside the loop.
        all_eept = np.concatenate(points, axis=0) if points else np.empty((0, 3))

        min_xyz = np.amin(all_eept, axis=0)
        max_xyz = np.amax(all_eept, axis=0)
        xlim = buffered_axis_limits(min_xyz[0], max_xyz[0], buffer_factor=1.25)
        ylim = buffered_axis_limits(min_xyz[1], max_xyz[1], buffer_factor=1.25)
        zlim = buffered_axis_limits(min_xyz[2], max_xyz[2], buffer_factor=1.25)
        return xlim, ylim, zlim

    def _update_linear_gaussian_controller_plots(self, algorithm, agent, m):
        """
        Update the linear Gaussian controller plots with iteration data,
        for the mean and covariances of the end effector points.
        """
        # Calculate mean and covariance for end effector points.
        eept_idx = agent.get_idx_x(END_EFFECTOR_POINTS)
        start, end = eept_idx[0], eept_idx[-1]
        mu, sigma = algorithm.traj_opt.forward(algorithm.prev[m].traj_distr,
                                               algorithm.prev[m].traj_info)
        mu_eept = mu[:, start:end + 1]
        sigma_eept = sigma[:, start:end + 1, start:end + 1]
        num_points = mu_eept.shape[1] // 3

        # Linear Gaussian Controller Distributions (Red)
        for i in range(num_points):
            lo, hi = 3 * i, 3 * i + 3
            self._traj_visualizer.plot_3d_gaussian(
                i=m, mu=mu_eept[:, lo:hi], sigma=sigma_eept[:, lo:hi, lo:hi],
                edges=100, linestyle='-', linewidth=1.0, color='red',
                alpha=0.15, label='LG Controller Distributions')

        # Linear Gaussian Controller Means (Dark Red)
        for i in range(num_points):
            self._traj_visualizer.plot_3d_points(
                i=m, points=mu_eept[:, 3 * i:3 * i + 3], linestyle='None',
                marker='x', markersize=5.0, markeredgewidth=1.0,
                color=(0.5, 0, 0), alpha=1.0, label='LG Controller Means')

    def _update_samples_plots(self, sample_lists, m, color, label):
        """
        Update the samples plots with iteration data, for the trajectory samples
        and the policy samples.
        """
        for sample in sample_lists[m].get_samples():
            ee_pt = sample.get(END_EFFECTOR_POINTS)
            for i in range(ee_pt.shape[1] // 3):
                ee_pt_i = ee_pt[:, 3 * i:3 * i + 3]
                self._traj_visualizer.plot_3d_points(m, ee_pt_i, color=color,
                                                     label=label)

    def save_figure(self, filename):
        self._fig.savefig(filename)
def __init__(self, hyperparams):
    """
    Construct the GPS training GUI: figure, layout, components, legends and
    the daemon thread that animates the "Calculating..." status text.

    Layout/appearance values come from the module-level ``config`` dict;
    ``hyperparams`` supplies 'log_filename', 'conditions', 'info' and an
    optional 'target_filename'.
    """
    self._hyperparams = hyperparams
    self._log_filename = hyperparams['log_filename']
    self._target_filename = (hyperparams['target_filename']
                             if 'target_filename' in hyperparams else None)

    # Training status. Modes: run, wait, end, request, process.
    # Requests: stop, reset, go, fail, None.
    self.mode = config['initial_mode']
    self.request = None
    self.err_msg = None
    self._colors = dict(
        run='cyan', wait='orange', end='red',
        stop='red', reset='yellow', go='green', fail='magenta',
    )
    self._first_update = True

    # One button per request; the label doubles as the action key.
    handlers = (
        ('stop', self.request_stop),
        ('reset', self.request_reset),
        ('go', self.request_go),
        ('fail', self.request_fail),
    )
    actions_arr = [Action(label, label, callback, axis_pos=slot)
                   for slot, (label, callback) in enumerate(handlers)]

    # Interactive figure with no toolbar and no default key bindings.
    plt.ion()
    rc = plt.rcParams
    rc['toolbar'] = 'None'
    for keymap_name in [name for name in rc if name.startswith('keymap.')]:
        rc[keymap_name] = ''
    fig = self._fig = plt.figure(figsize=config['figsize'])
    fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                        wspace=0, hspace=0)

    # 16x8 layout grid.
    gs = self._gs = gridspec.GridSpec(16, 8)
    self._gs_action_panel = gs[0:1, 0:8]
    self._gs_action_output = gs[1:2, 0:4]
    self._gs_status_output = gs[2:3, 0:4]
    self._gs_cost_plotter = gs[1:3, 4:8]
    self._gs_gt_cost_plotter = gs[4:6, 4:8]
    self._gs_algthm_output = gs[3:9, 0:4]
    image_on = config['image_on']
    if image_on:
        self._gs_traj_visualizer = gs[9:16, 0:4]
        self._gs_image_visualizer = gs[9:16, 4:8]
    else:
        self._gs_traj_visualizer = gs[9:16, 0:8]

    # Widgets.
    self._action_panel = ActionPanel(fig, self._gs_action_panel, 1, 4, actions_arr)
    self._action_output = Textbox(fig, self._gs_action_output, border_on=True)
    self._status_output = Textbox(fig, self._gs_status_output, border_on=False)
    self._algthm_output = Textbox(
        fig, self._gs_algthm_output,
        max_display_size=config['algthm_output_max_display_size'],
        log_filename=self._log_filename,
        fontsize=config['algthm_output_fontsize'],
        font_family='monospace')
    self._cost_plotter = MeanPlotter(fig, self._gs_cost_plotter,
                                     color='blue', label='mean cost')
    self._gt_cost_plotter = MeanPlotter(fig, self._gs_gt_cost_plotter,
                                        color='red', label='ground truth cost')
    num_conditions = hyperparams['conditions']
    self._traj_visualizer = Plotter3D(fig, self._gs_traj_visualizer,
                                      num_plots=num_conditions)
    if image_on:
        self._image_visualizer = ImageVisualizer(
            fig, self._gs_image_visualizer,
            cropsize=config['image_size'],
            rostopic=config['image_topic'],
            show_overlay_buttons=True)

    # Initial contents and mode.
    self._algthm_output.log_text('\n')
    self.set_output_text(hyperparams['info'])
    start_mode = self.run_mode if config['initial_mode'] == 'run' else self.wait_mode
    start_mode()

    # Per-condition titles, then the shared legend entries.
    traj_vis = self._traj_visualizer
    for cond in range(num_conditions):
        traj_vis.set_title(cond, 'Condition %d' % cond)
    legend_specs = (
        ('-', 'None', 'green', 'Trajectory Samples'),
        ('-', 'None', 'blue', 'Policy Samples'),
        ('None', 'x', (0.5, 0, 0), 'LG Controller Means'),
        ('-', 'None', 'red', 'LG Controller Distributions'),
    )
    for style, mark, shade, text in legend_specs:
        traj_vis.add_legend(linestyle=style, marker=mark, color=shade, label=text)

    fig.canvas.draw()

    # Status animation: cycles through the three frames while run_event is set.
    def display_calculating(delay, run_event):
        frames = ('Calculating.', 'Calculating..', 'Calculating...')
        while True:
            if not run_event.is_set():
                run_event.wait()
            for frame in frames:
                if run_event.is_set():
                    self.set_status_text(frame)
                    time.sleep(delay)

    self._calculating_run = threading.Event()
    worker = threading.Thread(target=display_calculating,
                              args=(1, self._calculating_run))
    worker.daemon = True
    self._calculating_thread = worker
    worker.start()
class TargetSetupGUI(object):
    """
    Matplotlib GUI for recording initial/target poses and images per target
    number and actuator, persisted to ``hyperparams['target_filename']``.

    Layout, actuator list and number of targets come from the module-level
    ``config`` dict.
    """

    def __init__(self, hyperparams, agent):
        """
        Build the figure and GUI components and load the stored data for
        target 0 of the first actuator.

        Args:
            hyperparams: mapping providing 'log_filename' and 'target_filename'.
            agent: object used for get_data / reset_arm / relax_arm calls.
        """
        self._hyperparams = hyperparams
        self._agent = agent

        self._log_filename = self._hyperparams['log_filename']
        self._target_filename = self._hyperparams['target_filename']

        self._num_targets = config['num_targets']
        self._actuator_types = config['actuator_types']
        self._actuator_names = config['actuator_names']
        self._num_actuators = len(self._actuator_types)

        # Target Setup Status.
        self._target_number = 0
        self._actuator_number = 0
        self._actuator_type = self._actuator_types[self._actuator_number]
        self._actuator_name = self._actuator_names[self._actuator_number]
        self._initial_position = ('unknown', 'unknown', 'unknown')
        self._target_position = ('unknown', 'unknown', 'unknown')
        self._initial_image = None
        self._target_image = None
        self._mannequin_mode = False
        self._mm_process = None

        # Actions.
        actions_arr = [
            Action('ptn', 'prev_target_number', self.prev_target_number, axis_pos=0),
            Action('ntn', 'next_target_number', self.next_target_number, axis_pos=1),
            Action('pat', 'prev_actuator_type', self.prev_actuator_type, axis_pos=2),
            Action('nat', 'next_actuator_type', self.next_actuator_type, axis_pos=3),

            Action('sip', 'set_initial_position', self.set_initial_position, axis_pos=4),
            Action('stp', 'set_target_position', self.set_target_position, axis_pos=5),
            Action('sii', 'set_initial_image', self.set_initial_image, axis_pos=6),
            Action('sti', 'set_target_image', self.set_target_image, axis_pos=7),

            Action('mti', 'move_to_initial', self.move_to_initial, axis_pos=8),
            Action('mtt', 'move_to_target', self.move_to_target, axis_pos=9),
            Action('rc', 'relax_controller', self.relax_controller, axis_pos=10),
            Action('mm', 'mannequin_mode', self.mannequin_mode, axis_pos=11),
        ]

        # Setup figure: interactive, no toolbar, default key bindings disabled.
        plt.ion()
        plt.rcParams['toolbar'] = 'None'
        for key in plt.rcParams:
            if key.startswith('keymap.'):
                plt.rcParams[key] = ''

        self._fig = plt.figure(figsize=config['figsize'])
        self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                                  wspace=0, hspace=0)

        # Assign GUI component locations on a 4x4 grid.
        self._gs = gridspec.GridSpec(4, 4)
        self._gs_action_panel = self._gs[0:1, 0:4]
        if config['image_on']:
            self._gs_target_output = self._gs[1:3, 0:2]
            self._gs_initial_image_visualizer = self._gs[3:4, 0:1]
            self._gs_target_image_visualizer = self._gs[3:4, 1:2]
            self._gs_action_output = self._gs[1:2, 2:4]
            self._gs_image_visualizer = self._gs[2:4, 2:4]
        else:
            self._gs_target_output = self._gs[1:4, 0:2]
            self._gs_action_output = self._gs[1:4, 2:4]

        # Create GUI components.
        self._action_panel = ActionPanel(self._fig, self._gs_action_panel, 3, 4,
                                         actions_arr)
        self._target_output = Textbox(
            self._fig, self._gs_target_output,
            log_filename=self._log_filename,
            fontsize=config['target_output_fontsize'])
        self._action_output = Textbox(self._fig, self._gs_action_output)

        if config['image_on']:
            self._initial_image_visualizer = ImageVisualizer(
                self._fig, self._gs_initial_image_visualizer)
            self._target_image_visualizer = ImageVisualizer(
                self._fig, self._gs_target_image_visualizer)
            self._image_visualizer = ImageVisualizer(
                self._fig, self._gs_image_visualizer,
                cropsize=config['image_size'],
                rostopic=config['image_topic'],
                show_overlay_buttons=True)

        # Setup GUI components.
        self.reload_positions()
        self.update_target_text()
        self.set_action_text('Press an action to begin.')
        self.set_action_bgcolor('white')

        self._fig.canvas.draw()

    # Target Setup Functions.
    def _step_target_number(self, action, delta):
        """Move `delta` targets forward (wrapping) and reload the stored data."""
        self.set_action_status_message(action, 'requested')
        self._target_number = (self._target_number + delta) % self._num_targets
        self.reload_positions()
        self.update_target_text()
        self.set_action_text()
        self.set_action_status_message(action, 'completed',
                message='target number = %d' % self._target_number)

    def _step_actuator(self, action, delta):
        """Move `delta` actuators forward (wrapping) and reload the stored data."""
        self.set_action_status_message(action, 'requested')
        self._actuator_number = (self._actuator_number + delta) % self._num_actuators
        self._actuator_type = self._actuator_types[self._actuator_number]
        self._actuator_name = self._actuator_names[self._actuator_number]
        self.reload_positions()
        self.update_target_text()
        self.set_action_text()
        self.set_action_status_message(action, 'completed',
                message='actuator name = %s' % self._actuator_name)

    def prev_target_number(self, event=None):
        self._step_target_number('prev_target_number', -1)

    def next_target_number(self, event=None):
        self._step_target_number('next_target_number', 1)

    def prev_actuator_type(self, event=None):
        self._step_actuator('prev_actuator_type', -1)

    def next_actuator_type(self, event=None):
        self._step_actuator('next_actuator_type', 1)

    def _sample_pose(self, action):
        """
        Read (joint angles, end effector positions, end effector rotations)
        for the current actuator from the agent.

        Returns None (after reporting `action` as failed) on TimeoutException.
        """
        try:
            sample = self._agent.get_data(arm=self._actuator_type)
        except TimeoutException:
            self.set_action_status_message(action, 'failed',
                    message='TimeoutException while retrieving sample')
            return None
        return (sample.get(JOINT_ANGLES),
                sample.get(END_EFFECTOR_POSITIONS),
                sample.get(END_EFFECTOR_ROTATIONS))

    def set_initial_position(self, event=None):
        self.set_action_status_message('set_initial_position', 'requested')
        pose = self._sample_pose('set_initial_position')
        if pose is None:
            return
        self._initial_position = pose
        save_pose_to_npz(self._target_filename, self._actuator_name,
                str(self._target_number), 'initial', self._initial_position)

        self.update_target_text()
        self.set_action_status_message('set_initial_position', 'completed',
                message='initial position =\n %s' %
                self.position_to_str(self._initial_position))

    def set_target_position(self, event=None):
        self.set_action_status_message('set_target_position', 'requested')
        pose = self._sample_pose('set_target_position')
        if pose is None:
            return
        self._target_position = pose
        save_pose_to_npz(self._target_filename, self._actuator_name,
                str(self._target_number), 'target', self._target_position)

        self.update_target_text()
        self.set_action_status_message('set_target_position', 'completed',
                message='target position =\n %s' %
                self.position_to_str(self._target_position))

    def _capture_image(self, action):
        """
        Return the image visualizer's current image, or None (after reporting
        `action` as failed) if images are disabled or no image is available.
        """
        # The image buttons are always on the panel, but the visualizer only
        # exists when config['image_on'] is true.
        if not config['image_on']:
            self.set_action_status_message(action, 'failed',
                    message='image visualizer disabled (image_on is False)')
            return None
        image = self._image_visualizer.get_current_image()
        if image is None:
            self.set_action_status_message(action, 'failed',
                    message='no image available')
        return image

    def set_initial_image(self, event=None):
        self.set_action_status_message('set_initial_image', 'requested')
        image = self._capture_image('set_initial_image')
        if image is None:
            return  # keep the previously loaded image on failure
        self._initial_image = image
        save_data_to_npz(self._target_filename, self._actuator_name,
                str(self._target_number), 'initial', 'image', self._initial_image)

        self.update_target_text()
        self.set_action_status_message('set_initial_image', 'completed',
                message='initial image =\n %s' % str(self._initial_image))

    def set_target_image(self, event=None):
        self.set_action_status_message('set_target_image', 'requested')
        image = self._capture_image('set_target_image')
        if image is None:
            return  # keep the previously loaded image on failure
        self._target_image = image
        save_data_to_npz(self._target_filename, self._actuator_name,
                str(self._target_number), 'target', 'image', self._target_image)

        self.update_target_text()
        self.set_action_status_message('set_target_image', 'completed',
                message='target image =\n %s' % str(self._target_image))

    def move_to_initial(self, event=None):
        ja = self._initial_position[0]
        self.set_action_status_message('move_to_initial', 'requested')
        self._agent.reset_arm(self._actuator_type, JOINT_SPACE, ja.T)
        self.set_action_status_message('move_to_initial', 'completed',
                message='initial position: %s' % str(ja))

    def move_to_target(self, event=None):
        ja = self._target_position[0]
        self.set_action_status_message('move_to_target', 'requested')
        self._agent.reset_arm(self._actuator_type, JOINT_SPACE, ja.T)
        self.set_action_status_message('move_to_target', 'completed',
                message='target position: %s' % str(ja))

    def relax_controller(self, event=None):
        self.set_action_status_message('relax_controller', 'requested')
        self._agent.relax_arm(self._actuator_type)
        self.set_action_status_message('relax_controller', 'completed',
                message='actuator name: %s' % self._actuator_name)

    def mannequin_mode(self, event=None):
        """
        Toggles PR2 mannequin mode (only works for the PR2 robot).

        Turning it on stops the GPSPR2Plugin controller and launches
        "roslaunch pr2_mannequin_mode pr2_mannequin_mode.launch"; turning it
        off sends SIGINT to that roslaunch process and restarts GPSPR2Plugin.
        """
        self.set_action_status_message('mannequin_mode', 'requested')
        if not self._mannequin_mode:
            subprocess.Popen(['rosrun', 'pr2_controller_manager',
                              'pr2_controller_manager', 'stop', 'GPSPR2Plugin'],
                             stdout=DEVNULL)
            self._mm_process = subprocess.Popen(['roslaunch', 'pr2_mannequin_mode',
                                                 'pr2_mannequin_mode.launch'],
                                                stdout=DEVNULL)
            self._mannequin_mode = True
            self.set_action_status_message('mannequin_mode', 'completed',
                    message='mannequin mode toggled on')
        else:
            self._mm_process.send_signal(signal.SIGINT)
            subprocess.Popen(['rosrun', 'pr2_controller_manager',
                              'pr2_controller_manager', 'start', 'GPSPR2Plugin'],
                             stdout=DEVNULL)
            self._mannequin_mode = False
            self.set_action_status_message('mannequin_mode', 'completed',
                    message='mannequin mode toggled off')

    # GUI functions.
    def update_target_text(self):
        np.set_printoptions(precision=3, suppress=True)
        text = (
            'target number = %s\n' % str(self._target_number) +
            'actuator name = %s\n' % str(self._actuator_name) +
            '\ninitial position\n%s' % self.position_to_str(self._initial_position) +
            '\ntarget position\n%s' % self.position_to_str(self._target_position) +
            '\ninitial image (left) =\n%s\n' % str(self._initial_image) +
            '\ntarget image (right) =\n%s\n' % str(self._target_image))

        self._target_output.set_text(text)

        if config['image_on']:
            self._initial_image_visualizer.update(self._initial_image)
            self._target_image_visualizer.update(self._target_image)
            self._image_visualizer.set_initial_image(self._initial_image, alpha=0.3)
            self._image_visualizer.set_target_image(self._target_image, alpha=0.3)

    def position_to_str(self, position):
        """Format a (joint angles, ee positions, ee rotations) triple for display."""
        np.set_printoptions(precision=3, suppress=True)
        ja, ee_pos, ee_rot = position
        return ('joint angles =\n%s\n' % ja +
                'end effector positions =\n%s\n' % ee_pos +
                'end effector rotations =\n%s\n' % ee_rot)

    def set_action_status_message(self, action, status, message=None):
        """Show '<action>: <status>' (plus optional message) and color it by status."""
        text = action + ': ' + status
        if message:
            text += '\n\n' + message
        self.set_action_text(text)
        bgcolor = {'requested': 'yellow',
                   'completed': 'green',
                   'failed': 'red'}.get(status)
        if bgcolor is not None:
            self.set_action_bgcolor(bgcolor)

    def set_action_text(self, text=''):
        self._action_output.set_text(text)

    def set_action_bgcolor(self, color, alpha=1.0):
        self._action_output.set_bgcolor(color, alpha)

    def reload_positions(self):
        """
        Reloads the initial/target positions and images.
        This is called after the target number or actuator type has changed.
        """
        self._initial_position = load_pose_from_npz(self._target_filename,
                self._actuator_name, str(self._target_number), 'initial')
        self._target_position = load_pose_from_npz(self._target_filename,
                self._actuator_name, str(self._target_number), 'target')
        self._initial_image = load_data_from_npz(self._target_filename,
                self._actuator_name, str(self._target_number), 'initial',
                'image', default=None)
        self._target_image = load_data_from_npz(self._target_filename,
                self._actuator_name, str(self._target_number), 'target',
                'image', default=None)
def __init__(self, hyperparams, agent):
    """Build the target-setup figure, its widgets, and the initial display.

    Args:
        hyperparams: dict-like; this constructor reads 'log_filename' and
            'target_filename'.
        agent: stored as ``self._agent``.
    """
    self._hyperparams = hyperparams
    self._agent = agent
    self._log_filename = self._hyperparams['log_filename']
    self._target_filename = self._hyperparams['target_filename']
    self._num_targets = config['num_targets']
    self._actuator_types = config['actuator_types']
    self._actuator_names = config['actuator_names']
    self._num_actuators = len(self._actuator_types)

    # Target setup status: start on target 0 with the first actuator.
    self._target_number = 0
    self._actuator_number = 0
    self._actuator_type = self._actuator_types[self._actuator_number]
    self._actuator_name = self._actuator_names[self._actuator_number]
    self._initial_position = ('unknown', 'unknown', 'unknown')
    self._target_position = ('unknown', 'unknown', 'unknown')
    self._initial_image = None
    self._target_image = None
    self._mannequin_mode = False
    self._mm_process = None

    # Action buttons as (key, name, handler); each button's axis position
    # is its index in this table.
    action_table = (
        ('ptn', 'prev_target_number', self.prev_target_number),
        ('ntn', 'next_target_number', self.next_target_number),
        ('pat', 'prev_actuator_type', self.prev_actuator_type),
        ('nat', 'next_actuator_type', self.next_actuator_type),
        ('sip', 'set_initial_position', self.set_initial_position),
        ('stp', 'set_target_position', self.set_target_position),
        ('sii', 'set_initial_image', self.set_initial_image),
        ('sti', 'set_target_image', self.set_target_image),
        ('mti', 'move_to_initial', self.move_to_initial),
        ('mtt', 'move_to_target', self.move_to_target),
        ('rc', 'relax_controller', self.relax_controller),
        ('mm', 'mannequin_mode', self.mannequin_mode),
    )
    actions_arr = [Action(key, name, handler, axis_pos=pos)
                   for pos, (key, name, handler) in enumerate(action_table)]

    # Interactive figure with no toolbar and all keymap.* shortcuts cleared.
    plt.ion()
    plt.rcParams['toolbar'] = 'None'
    keymap_keys = [k for k in plt.rcParams if k.startswith('keymap.')]
    for k in keymap_keys:
        plt.rcParams[k] = ''
    self._fig = plt.figure(figsize=config['figsize'])
    self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                              wspace=0, hspace=0)

    # Layout on a 4x4 grid: the action panel is the top row; the rest
    # depends on whether images are shown.
    grid = gridspec.GridSpec(4, 4)
    self._gs = grid
    self._gs_action_panel = grid[0:1, 0:4]
    if not config['image_on']:
        self._gs_target_output = grid[1:4, 0:2]
        self._gs_action_output = grid[1:4, 2:4]
    else:
        self._gs_target_output = grid[1:3, 0:2]
        self._gs_initial_image_visualizer = grid[3:4, 0:1]
        self._gs_target_image_visualizer = grid[3:4, 1:2]
        self._gs_action_output = grid[1:2, 2:4]
        self._gs_image_visualizer = grid[2:4, 2:4]

    # Widgets.
    fig = self._fig
    self._action_panel = ActionPanel(fig, self._gs_action_panel, 3, 4, actions_arr)
    self._target_output = Textbox(
        fig, self._gs_target_output,
        log_filename=self._log_filename,
        fontsize=config['target_output_fontsize'])
    self._action_output = Textbox(fig, self._gs_action_output)
    if config['image_on']:
        self._initial_image_visualizer = ImageVisualizer(
            fig, self._gs_initial_image_visualizer)
        self._target_image_visualizer = ImageVisualizer(
            fig, self._gs_target_image_visualizer)
        self._image_visualizer = ImageVisualizer(
            fig, self._gs_image_visualizer,
            cropsize=config['image_size'],
            rostopic=config['image_topic'],
            show_overlay_buttons=True)

    # Populate the initial display.
    self.reload_positions()
    self.update_target_text()
    self.set_action_text('Press an action to begin.')
    self.set_action_bgcolor('white')
    self._fig.canvas.draw()
def __init__(self, hyperparams):
    """Build the GPS training GUI figure and start the status-text thread.

    Args:
        hyperparams: dict-like; this constructor reads 'log_filename',
            'conditions', 'info' and, when present, 'target_filename'.
    """
    self._hyperparams = hyperparams
    self._log_filename = self._hyperparams['log_filename']
    self._target_filename = (self._hyperparams['target_filename']
                             if 'target_filename' in self._hyperparams
                             else None)

    # Training status.
    self.mode = config['initial_mode']  # one of: run, wait, end, request, process
    self.request = None  # one of: stop, reset, go, fail, or None
    self.err_msg = None
    self._colors = dict(run='cyan', wait='orange', end='red',
                        stop='red', reset='yellow', go='green', fail='magenta')
    self._first_update = True

    # One button per request; the axis position is the index in this table.
    request_buttons = (
        ('stop', self.request_stop),
        ('reset', self.request_reset),
        ('go', self.request_go),
        ('fail', self.request_fail),
    )
    actions_arr = [Action(name, name, handler, axis_pos=pos)
                   for pos, (name, handler) in enumerate(request_buttons)]

    # Interactive figure with no toolbar and all keymap.* shortcuts cleared.
    plt.ion()
    plt.rcParams['toolbar'] = 'None'
    keymap_keys = [k for k in plt.rcParams if k.startswith('keymap.')]
    for k in keymap_keys:
        plt.rcParams[k] = ''
    self._fig = plt.figure(figsize=config['figsize'])
    self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                              wspace=0, hspace=0)

    # Layout on a 16x8 grid.
    grid = gridspec.GridSpec(16, 8)
    self._gs = grid
    self._gs_action_panel = grid[0:1, 0:8]
    self._gs_action_output = grid[1:2, 0:4]
    self._gs_status_output = grid[2:3, 0:4]
    self._gs_cost_plotter = grid[1:3, 4:8]
    self._gs_gt_cost_plotter = grid[4:6, 4:8]
    self._gs_algthm_output = grid[3:9, 0:4]
    if not config['image_on']:
        self._gs_traj_visualizer = grid[9:16, 0:8]
    else:
        self._gs_traj_visualizer = grid[9:16, 0:4]
        self._gs_image_visualizer = grid[9:16, 4:8]

    # Widgets.
    fig = self._fig
    self._action_panel = ActionPanel(fig, self._gs_action_panel, 1, 4, actions_arr)
    self._action_output = Textbox(fig, self._gs_action_output, border_on=True)
    self._status_output = Textbox(fig, self._gs_status_output, border_on=False)
    self._algthm_output = Textbox(
        fig, self._gs_algthm_output,
        max_display_size=config['algthm_output_max_display_size'],
        log_filename=self._log_filename,
        fontsize=config['algthm_output_fontsize'],
        font_family='monospace')
    self._cost_plotter = MeanPlotter(fig, self._gs_cost_plotter,
                                     color='blue', label='mean cost')
    self._gt_cost_plotter = MeanPlotter(fig, self._gs_gt_cost_plotter,
                                        color='red', label='ground truth cost')
    self._traj_visualizer = Plotter3D(fig, self._gs_traj_visualizer,
                                      num_plots=self._hyperparams['conditions'])
    if config['image_on']:
        self._image_visualizer = ImageVisualizer(
            fig, self._gs_image_visualizer,
            cropsize=config['image_size'],
            rostopic=config['image_topic'],
            show_overlay_buttons=True)

    # Initial text and mode.
    self._algthm_output.log_text('\n')
    self.set_output_text(self._hyperparams['info'])
    if config['initial_mode'] != 'run':
        self.wait_mode()
    else:
        self.run_mode()

    # 3D trajectory plots: one titled subplot per condition plus a shared legend.
    for cond in range(self._hyperparams['conditions']):
        self._traj_visualizer.set_title(cond, 'Condition %d' % (cond))
    legend_specs = (
        ('-', 'None', 'green', 'Trajectory Samples'),
        ('-', 'None', 'blue', 'Policy Samples'),
        ('None', 'x', (0.5, 0, 0), 'LG Controller Means'),
        ('-', 'None', 'red', 'LG Controller Distributions'),
    )
    for line_style, marker_style, colour, text in legend_specs:
        self._traj_visualizer.add_legend(linestyle=line_style, marker=marker_style,
                                         color=colour, label=text)
    self._fig.canvas.draw()

    # Background thread that cycles the status text while the event is set.
    frames = ('Calculating.', 'Calculating..', 'Calculating...')

    def display_calculating(delay, run_event):
        while True:
            if not run_event.is_set():
                run_event.wait()
            for frame in frames:
                if run_event.is_set():
                    self.set_status_text(frame)
                    time.sleep(delay)

    self._calculating_run = threading.Event()
    worker = threading.Thread(target=display_calculating,
                              args=(1, self._calculating_run))
    self._calculating_thread = worker
    worker.daemon = True  # do not keep the process alive on exit
    worker.start()
class GPSTrainingGUI(object):
    """Matplotlib GUI for monitoring and steering a GPS training run.

    Shows:
    - an action panel (stop / reset / go / fail);
    - action and status text boxes;
    - the algorithm's per-iteration text output;
    - mean-cost and ground-truth-cost plots;
    - per-condition 3D end-effector trajectory plots;
    - an image visualizer, when ``config['image_on']`` is set.
    """

    def __init__(self, hyperparams):
        """Build the figure and all GUI components, then start the status thread.

        Args:
            hyperparams: dict-like; this constructor reads 'log_filename',
                'conditions', 'info' and, when present, 'target_filename'.
        """
        self._hyperparams = hyperparams
        self._log_filename = self._hyperparams['log_filename']
        if 'target_filename' in self._hyperparams:
            self._target_filename = self._hyperparams['target_filename']
        else:
            self._target_filename = None

        # GPS Training Status.
        self.mode = config['initial_mode']  # Modes: run, wait, end, request, process.
        self.request = None  # Requests: stop, reset, go, fail, None.
        self.err_msg = None
        self._colors = {
            'run': 'cyan',
            'wait': 'orange',
            'end': 'red',

            'stop': 'red',
            'reset': 'yellow',
            'go': 'green',
            'fail': 'magenta',
        }
        self._first_update = True

        # Actions.
        actions_arr = [
            Action('stop', 'stop', self.request_stop, axis_pos=0),
            Action('reset', 'reset', self.request_reset, axis_pos=1),
            Action('go', 'go', self.request_go, axis_pos=2),
            Action('fail', 'fail', self.request_fail, axis_pos=3),
        ]

        # Setup figure: interactive mode, no toolbar, all keymap.* shortcuts cleared.
        plt.ion()
        plt.rcParams['toolbar'] = 'None'
        for key in plt.rcParams:
            if key.startswith('keymap.'):
                plt.rcParams[key] = ''

        self._fig = plt.figure(figsize=config['figsize'])
        self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                                  wspace=0, hspace=0)

        # Assign GUI component locations on a 16x8 grid.
        self._gs = gridspec.GridSpec(16, 8)
        self._gs_action_panel = self._gs[0:1, 0:8]
        self._gs_action_output = self._gs[1:2, 0:4]
        self._gs_status_output = self._gs[2:3, 0:4]
        self._gs_cost_plotter = self._gs[1:3, 4:8]
        self._gs_gt_cost_plotter = self._gs[4:6, 4:8]
        self._gs_algthm_output = self._gs[3:9, 0:4]
        if config['image_on']:
            self._gs_traj_visualizer = self._gs[9:16, 0:4]
            self._gs_image_visualizer = self._gs[9:16, 4:8]
        else:
            self._gs_traj_visualizer = self._gs[9:16, 0:8]

        # Create GUI components.
        self._action_panel = ActionPanel(self._fig, self._gs_action_panel, 1, 4, actions_arr)
        self._action_output = Textbox(self._fig, self._gs_action_output, border_on=True)
        self._status_output = Textbox(self._fig, self._gs_status_output, border_on=False)
        self._algthm_output = Textbox(
            self._fig, self._gs_algthm_output,
            max_display_size=config['algthm_output_max_display_size'],
            log_filename=self._log_filename,
            fontsize=config['algthm_output_fontsize'],
            font_family='monospace')
        self._cost_plotter = MeanPlotter(self._fig, self._gs_cost_plotter,
                                         color='blue', label='mean cost')
        self._gt_cost_plotter = MeanPlotter(self._fig, self._gs_gt_cost_plotter,
                                            color='red', label='ground truth cost')
        self._traj_visualizer = Plotter3D(self._fig, self._gs_traj_visualizer,
                                          num_plots=self._hyperparams['conditions'])
        if config['image_on']:
            self._image_visualizer = ImageVisualizer(
                self._fig, self._gs_image_visualizer,
                cropsize=config['image_size'],
                rostopic=config['image_topic'],
                show_overlay_buttons=True)

        # Setup GUI components.
        self._algthm_output.log_text('\n')
        self.set_output_text(self._hyperparams['info'])
        if config['initial_mode'] == 'run':
            self.run_mode()
        else:
            self.wait_mode()

        # Setup 3D Trajectory Visualizer plot titles and legends.
        for m in range(self._hyperparams['conditions']):
            self._traj_visualizer.set_title(m, 'Condition %d' % (m))
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='green', label='Trajectory Samples')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='blue', label='Policy Samples')
        self._traj_visualizer.add_legend(linestyle='None', marker='x',
                                         color=(0.5, 0, 0), label='LG Controller Means')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='red', label='LG Controller Distributions')

        self._fig.canvas.draw()

        # Display calculating thread: while the event is set, cycle the
        # status text through 'Calculating.', '..', '...' every `delay` seconds.
        def display_calculating(delay, run_event):
            while True:
                if not run_event.is_set():
                    run_event.wait()
                if run_event.is_set():
                    self.set_status_text('Calculating.')
                    time.sleep(delay)
                if run_event.is_set():
                    self.set_status_text('Calculating..')
                    time.sleep(delay)
                if run_event.is_set():
                    self.set_status_text('Calculating...')
                    time.sleep(delay)

        self._calculating_run = threading.Event()
        self._calculating_thread = threading.Thread(target=display_calculating,
                                                    args=(1, self._calculating_run))
        self._calculating_thread.daemon = True  # don't block interpreter exit
        self._calculating_thread.start()

    # GPS Training functions
    def request_stop(self, event=None):
        self.request_mode('stop')

    def request_reset(self, event=None):
        self.request_mode('reset')

    def request_go(self, event=None):
        self.request_mode('go')

    def request_fail(self, event=None):
        self.request_mode('fail')

    def request_mode(self, request):
        """
        Sets the request mode (stop, reset, go, fail). The request is read by
        gps_main before sampling, and the appropriate action is taken.
        """
        self.mode = 'request'
        self.request = request
        self.set_action_text(self.request + ' requested')
        self.set_action_bgcolor(self._colors[self.request], alpha=0.2)

    def process_mode(self):
        """
        Completes the current request, after it is first read by gps_main.
        Displays visual confirmation that the request was processed,
        displays any error messages, and then switches into mode 'run' or 'wait'.
        """
        self.mode = 'process'
        self.set_action_text(self.request + ' processed')
        self.set_action_bgcolor(self._colors[self.request], alpha=1.0)
        if self.err_msg:
            self.set_action_text(self.request + ' processed' + '\nERROR: ' +
                                 self.err_msg)
            self.err_msg = None
            time.sleep(1.0)
        else:
            time.sleep(0.5)
        if self.request in ('stop', 'reset', 'fail'):
            self.wait_mode()
        elif self.request == 'go':
            self.run_mode()
        self.request = None

    def wait_mode(self):
        self.mode = 'wait'
        self.set_action_text('waiting')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def run_mode(self):
        self.mode = 'run'
        self.set_action_text('running')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def end_mode(self):
        self.mode = 'end'
        self.set_action_text('ended')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def estop(self, event=None):
        self.set_action_text('estop: NOT IMPLEMENTED')

    # GUI functions
    def _redraw_cost_ticklabels(self):
        """Redraw the tick labels of both cost plotters.

        Called after every text-box update. The original note reads
        "redraw overflow ticklabels"; presumably the labels overflow into
        the neighbouring text boxes and get painted over.
        """
        self._cost_plotter.draw_ticklabels()
        self._gt_cost_plotter.draw_ticklabels()

    def set_action_text(self, text):
        self._action_output.set_text(text)
        self._redraw_cost_ticklabels()

    def set_action_bgcolor(self, color, alpha=1.0):
        self._action_output.set_bgcolor(color, alpha)
        self._redraw_cost_ticklabels()

    def set_status_text(self, text):
        self._status_output.set_text(text)
        self._redraw_cost_ticklabels()

    def set_output_text(self, text):
        self._algthm_output.set_text(text)
        self._redraw_cost_ticklabels()

    def append_output_text(self, text):
        self._algthm_output.append_text(text)
        self._redraw_cost_ticklabels()

    def start_display_calculating(self):
        self._calculating_run.set()

    def stop_display_calculating(self):
        self._calculating_run.clear()

    def set_image_overlays(self, condition):
        """
        Sets up the image visualizer with what images to overlay if
        "overlay_initial_image" or "overlay_target_image" is pressed.

        Does nothing when images are disabled or no target file is configured.
        """
        if not config['image_on'] or not self._target_filename:
            return
        initial_image = load_data_from_npz(self._target_filename,
                                           config['image_overlay_actuator'], str(condition),
                                           'initial', 'image', default=None)
        target_image = load_data_from_npz(self._target_filename,
                                          config['image_overlay_actuator'], str(condition),
                                          'target', 'image', default=None)
        self._image_visualizer.set_initial_image(initial_image,
                                                 alpha=config['image_overlay_alpha'])
        self._image_visualizer.set_target_image(target_image,
                                                alpha=config['image_overlay_alpha'])

    # Iteration update functions
    def update(self, itr, algorithm, agent, traj_sample_lists, pol_sample_lists):
        """
        After each iteration, update the iteration data output, the cost plot,
        and the 3D trajectory visualizations (if end effector points exist).

        When ``algorithm._hyperparams['ioc']`` is set, the text output and the
        ground-truth plot use the per-condition ``cgt`` costs; the mean-cost
        plot always uses ``cs``.
        """
        if self._first_update:
            policy_titles = pol_sample_lists is not None
            self._output_column_titles(algorithm, policy_titles)
            self._first_update = False

        costs = [np.mean(np.sum(algorithm.prev[m].cs, axis=1))
                 for m in range(algorithm.M)]
        if algorithm._hyperparams['ioc']:
            gt_costs = [np.mean(np.sum(algorithm.prev[m].cgt, axis=1))
                        for m in range(algorithm.M)]
            self._update_iteration_data(itr, algorithm, gt_costs, pol_sample_lists)
            self._gt_cost_plotter.update(gt_costs, t=itr)
        else:
            self._update_iteration_data(itr, algorithm, costs, pol_sample_lists)
        self._cost_plotter.update(costs, t=itr)
        if END_EFFECTOR_POINTS in agent.x_data_types:
            self._update_trajectory_visualizations(algorithm, agent,
                                                   traj_sample_lists, pol_sample_lists)

        self._fig.canvas.draw()
        self._fig.canvas.flush_events()  # Fixes bug in Qt4Agg backend

    def _output_column_titles(self, algorithm, policy_titles=False):
        """
        Setup iteration data column titles: iteration, average cost, and for
        each condition the mean cost over samples, step size, linear Gaussian
        controller entropies, and initial/final KL divergences for BADMM.
        """
        self.set_output_text(self._hyperparams['experiment_name'])
        if policy_titles:
            condition_titles = '%3s | %8s %12s' % ('', '', '')
            itr_data_fields = '%3s | %8s %12s' % ('itr', 'avg_cost', 'avg_pol_cost')
        else:
            condition_titles = '%3s | %8s' % ('', '')
            itr_data_fields = '%3s | %8s' % ('itr', 'avg_cost')
        for m in range(algorithm.M):
            condition_titles += ' | %8s %9s %-7d' % ('', 'condition', m)
            itr_data_fields += ' | %8s %8s %8s' % (' cost ', ' step ', 'entropy ')
            if algorithm.prev[0].pol_info is not None:
                condition_titles += ' %8s %8s' % ('', '')
                itr_data_fields += ' %8s %8s' % ('kl_div_i', 'kl_div_f')
            if (algorithm._hyperparams['ioc'] and
                    not algorithm._hyperparams['learning_from_prior']):
                condition_titles += ' %8s' % ('')
                itr_data_fields += ' %8s' % ('kl_div')
            if algorithm._hyperparams['learning_from_prior']:
                condition_titles += ' %8s' % ('')
                itr_data_fields += ' %8s' % ('mean_dist')
            if policy_titles:
                condition_titles += ' %8s %8s %8s' % ('', '', '')
                itr_data_fields += ' %8s %8s %8s' % ('pol_cost', 'kl_div_i', 'kl_div_f')
        self.append_output_text(condition_titles)
        self.append_output_text(itr_data_fields)

    def _update_iteration_data(self, itr, algorithm, costs, pol_sample_lists):
        """
        Update iteration data information: iteration, average cost, and for
        each condition the mean cost over samples, step size, linear Gaussian
        controller entropies, and initial/final KL divergences for BADMM.
        """
        avg_cost = np.mean(costs)
        if pol_sample_lists is not None:
            pol_costs = [np.mean([np.sum(algorithm.cost[m].eval(s)[0])
                                  for s in pol_sample_lists[m]])
                         for m in range(algorithm.M)]
            itr_data = '%3d | %8.2f %12.2f' % (itr, avg_cost, np.mean(pol_costs))
        else:
            itr_data = '%3d | %8.2f' % (itr, avg_cost)
        for m in range(algorithm.M):
            cost = costs[m]
            step = algorithm.prev[m].step_mult * algorithm.base_kl_step
            # Twice the sum of log-diagonal Cholesky entries, i.e. log|covar|.
            entropy = 2 * np.sum(np.log(np.diagonal(
                algorithm.prev[m].traj_distr.chol_pol_covar, axis1=1, axis2=2)))
            itr_data += ' | %8.2f %8.2f %8.2f' % (cost, step, entropy)
            if algorithm.prev[0].pol_info is not None:
                kl_div_i = algorithm.prev[m].pol_info.prev_kl[0]
                kl_div_f = algorithm.prev[m].pol_info.prev_kl[-1]
                itr_data += ' %8.2f %8.2f' % (kl_div_i, kl_div_f)
            if (algorithm._hyperparams['ioc'] and
                    not algorithm._hyperparams['learning_from_prior']):
                itr_data += ' %8.2f' % (algorithm.kl_div[itr][m])
            if algorithm._hyperparams['learning_from_prior']:
                itr_data += ' %8.2f' % (algorithm.dists_to_target[itr][m])
            if pol_sample_lists is not None:
                kl_div_i = algorithm.cur[m].pol_info.init_kl.mean()
                kl_div_f = algorithm.cur[m].pol_info.prev_kl.mean()
                itr_data += ' %8.2f %8.2f %8.2f' % (pol_costs[m], kl_div_i, kl_div_f)
        self.append_output_text(itr_data)

    def _update_trajectory_visualizations(self, algorithm, agent,
                                          traj_sample_lists, pol_sample_lists):
        """
        Update 3D trajectory visualizations information: the trajectory samples,
        policy samples, and linear Gaussian controller means and covariances.
        """
        xlim, ylim, zlim = self._calculate_3d_axis_limits(traj_sample_lists,
                                                          pol_sample_lists)
        for m in range(algorithm.M):
            self._traj_visualizer.clear(m)
            self._traj_visualizer.set_lim(i=m, xlim=xlim, ylim=ylim, zlim=zlim)
            self._update_samples_plots(traj_sample_lists, m, 'green', 'Trajectory Samples')
            self._update_linear_gaussian_controller_plots(algorithm, agent, m)
            if pol_sample_lists:
                self._update_samples_plots(pol_sample_lists, m, 'blue', 'Policy Samples')
        self._traj_visualizer.draw()  # this must be called explicitly

    def _calculate_3d_axis_limits(self, traj_sample_lists, pol_sample_lists):
        """
        Calculate the 3D axis limits shared between trajectory plots, based on
        the minimum and maximum xyz values across all trajectory and policy
        samples.

        Each sample's END_EFFECTOR_POINTS array is read as consecutive xyz
        column triples. The input lists are not modified.
        """
        # Copy, so the caller's traj_sample_lists is never extended in place.
        sample_lists = list(traj_sample_lists)
        if pol_sample_lists:
            sample_lists.extend(pol_sample_lists)

        # Collect the (T, 3) pieces and concatenate once, rather than
        # re-allocating the whole array for every piece.
        pieces = [np.empty((0, 3))]
        for sample_list in sample_lists:
            for sample in sample_list.get_samples():
                ee_pt = sample.get(END_EFFECTOR_POINTS)
                for i in range(ee_pt.shape[1] // 3):
                    pieces.append(ee_pt[:, 3*i:3*i+3])
        all_eept = np.concatenate(pieces, axis=0)

        min_xyz = np.amin(all_eept, axis=0)
        max_xyz = np.amax(all_eept, axis=0)
        xlim = buffered_axis_limits(min_xyz[0], max_xyz[0], buffer_factor=1.25)
        ylim = buffered_axis_limits(min_xyz[1], max_xyz[1], buffer_factor=1.25)
        zlim = buffered_axis_limits(min_xyz[2], max_xyz[2], buffer_factor=1.25)
        return xlim, ylim, zlim

    def _update_linear_gaussian_controller_plots(self, algorithm, agent, m):
        """
        Update the linear Gaussian controller plots with iteration data,
        for the mean and covariances of the end effector points.
        """
        # Calculate mean and covariance for end effector points
        eept_idx = agent.get_idx_x(END_EFFECTOR_POINTS)
        start, end = eept_idx[0], eept_idx[-1]
        mu, sigma = algorithm.traj_opt.forward(algorithm.prev[m].traj_distr,
                                               algorithm.prev[m].traj_info)
        mu_eept = mu[:, start:end+1]
        sigma_eept = sigma[:, start:end+1, start:end+1]
        num_points = mu_eept.shape[1] // 3

        # Linear Gaussian Controller Distributions (Red)
        for i in range(num_points):
            mu_i = mu_eept[:, 3*i:3*i+3]
            sigma_i = sigma_eept[:, 3*i:3*i+3, 3*i:3*i+3]
            self._traj_visualizer.plot_3d_gaussian(
                i=m, mu=mu_i, sigma=sigma_i, edges=100, linestyle='-',
                linewidth=1.0, color='red', alpha=0.15,
                label='LG Controller Distributions')

        # Linear Gaussian Controller Means (Dark Red)
        for i in range(num_points):
            mu_i = mu_eept[:, 3*i:3*i+3]
            self._traj_visualizer.plot_3d_points(
                i=m, points=mu_i, linestyle='None', marker='x', markersize=5.0,
                markeredgewidth=1.0, color=(0.5, 0, 0), alpha=1.0,
                label='LG Controller Means')

    def _update_samples_plots(self, sample_lists, m, color, label):
        """
        Update the samples plots with iteration data, for the trajectory samples
        and the policy samples.
        """
        samples = sample_lists[m].get_samples()
        for sample in samples:
            ee_pt = sample.get(END_EFFECTOR_POINTS)
            for i in range(ee_pt.shape[1] // 3):
                ee_pt_i = ee_pt[:, 3*i:3*i+3]
                self._traj_visualizer.plot_3d_points(m, ee_pt_i, color=color, label=label)

    def save_figure(self, filename):
        self._fig.savefig(filename)
def __init__(self, hyperparams, agent):
    """Build the transfer-learning GUI figure and start the status-text thread.

    Args:
        hyperparams: dict-like; this constructor reads 'log_filename',
            'conditions', 'info' and, when present, 'target_filename'.
        agent: stored as ``self._agent``.
    """
    self._agent = agent
    self._hyperparams = hyperparams
    self._log_filename = self._hyperparams['log_filename']
    self._target_filename = (self._hyperparams['target_filename']
                             if 'target_filename' in self._hyperparams
                             else None)

    # Training status.
    self.mode = config['initial_mode']  # one of: run, wait, end, request, process
    self.request = None  # one of: stop, reset, go, fail, or None
    self.err_msg = None
    # NOTE(review): unlike GPSTrainingGUI's table, this one has no entries
    # for request names (stop/reset/go/...). If this class's request
    # handling looks colors up by request name it would raise KeyError;
    # confirm against the request_* methods.
    self._colors = dict(run='cyan', wait='orange', end='red')
    self._actuator_types = config['actuator_types']
    self._actuator_names = config['actuator_names']
    self._first_update = True
    self._actuator_number = 0
    self._actuator_type = self._actuator_types[self._actuator_number]
    self._initial_position = ('unknown', 'unknown', 'unknown')
    self._target_position = ('unknown', 'unknown', 'unknown')

    # Action buttons as (label, name, handler); axis position = table index.
    action_table = (
        ('stop', 'stop', self.request_stop),
        ('reset', 'reset', self.request_reset),
        ('GCM go', 'go', self.request_go),
        ('transfer learning', 'transfer_learning', self.request_tl),
        ('set initial state', 'initstate', self.request_init_state),
        ('set goal state', 'goalstate', self.request_goal_state),
        ('test transfer learning', 'test_tl', self.request_test_tl),
        ('generalize', 'generalize', self.request_generalize),
        ('mti', 'move_to_initial', self.move_to_initial),
        ('mtt', 'move_to_target', self.move_to_target),
        ('rc', 'relax_controller', self.relax_controller),
    )
    actions_arr = [Action(label, name, handler, axis_pos=pos)
                   for pos, (label, name, handler) in enumerate(action_table)]

    # Interactive figure with no toolbar and all keymap.* shortcuts cleared.
    plt.ion()
    plt.rcParams['toolbar'] = 'None'
    keymap_keys = [k for k in plt.rcParams if k.startswith('keymap.')]
    for k in keymap_keys:
        plt.rcParams[k] = ''
    self._fig = plt.figure(figsize=config['figsize'])
    self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                              wspace=0, hspace=0)

    # Layout on an 18x8 grid.
    grid = gridspec.GridSpec(18, 8)
    self._gs = grid
    self._gs_action_panel = grid[0:4, 0:8]
    self._gs_action_output = grid[4:5, 0:4]
    self._gs_status_output = grid[5:6, 0:4]
    self._gs_cost_plotter = grid[4:10, 4:8]
    self._gs_algthm_output = grid[6:10, 0:4]
    self._gs_traj_visualizer = grid[10:18, 0:8]

    # Widgets.
    fig = self._fig
    self._action_panel = ActionPanel(fig, self._gs_action_panel, 3, 4, actions_arr)
    self._action_output = Textbox(fig, self._gs_action_output, border_on=True)
    self._status_output = Textbox(fig, self._gs_status_output, border_on=False)
    self._algthm_output = Textbox(
        fig, self._gs_algthm_output,
        max_display_size=config['algthm_output_max_display_size'],
        log_filename=self._log_filename,
        fontsize=config['algthm_output_fontsize'],
        font_family='monospace')
    self._cost_plotter = MeanPlotter(fig, self._gs_cost_plotter,
                                     color='blue', label='mean cost')
    self._traj_visualizer = Plotter3D(fig, self._gs_traj_visualizer,
                                      num_plots=self._hyperparams['conditions'])

    # Initial text and mode.
    self._algthm_output.log_text('\n')
    self.set_output_text(self._hyperparams['info'])
    if config['initial_mode'] != 'run':
        self.wait_mode()
    else:
        self.run_mode()

    # 3D trajectory plots: one titled subplot per condition plus a shared legend.
    for cond in range(self._hyperparams['conditions']):
        self._traj_visualizer.set_title(cond, 'Condition %d' % (cond))
    legend_specs = (
        ('-', 'None', 'green', 'Trajectory Samples'),
        ('-', 'None', 'blue', 'Policy Samples'),
        ('None', 'x', (0.5, 0, 0), 'LG Controller Means'),
        ('-', 'None', 'red', 'LG Controller Distributions'),
    )
    for line_style, marker_style, colour, text in legend_specs:
        self._traj_visualizer.add_legend(linestyle=line_style, marker=marker_style,
                                         color=colour, label=text)
    self._fig.canvas.draw()

    # Background thread that cycles the status text while the event is set.
    frames = ('Calculating.', 'Calculating..', 'Calculating...')

    def display_calculating(delay, run_event):
        while True:
            if not run_event.is_set():
                run_event.wait()
            for frame in frames:
                if run_event.is_set():
                    self.set_status_text(frame)
                    time.sleep(delay)

    self._calculating_run = threading.Event()
    worker = threading.Thread(target=display_calculating,
                              args=(1, self._calculating_run))
    self._calculating_thread = worker
    worker.daemon = True  # do not keep the process alive on exit
    worker.start()
class TransferLearningGUI(object):
    """Matplotlib GUI for GPS training extended with transfer-learning actions.

    On top of the stop/reset/go controls it offers buttons to record initial
    and goal arm states, move the arm to them, relax the controller, and
    request transfer-learning / generalization runs. Requests are published
    through ``self.mode`` and ``self.request`` (see ``request_mode``).
    """

    def __init__(self, hyperparams, agent):
        """Build the figure, its widgets and the 'Calculating...' thread.

        Args:
            hyperparams: dict; reads 'log_filename', 'info', 'conditions' and,
                if present, 'target_filename'.
            agent: agent used by the arm actions (sample, reset, reset_arm,
                relax_arm, get_data).
        """
        self._agent = agent
        self._hyperparams = hyperparams
        self._log_filename = self._hyperparams['log_filename']
        if 'target_filename' in self._hyperparams:
            self._target_filename = self._hyperparams['target_filename']
        else:
            self._target_filename = None

        # GPS Training Status.
        self.mode = config['initial_mode']  # Modes: run, wait, end, request, process.
        self.request = None  # Requests: stop, reset, go, fail, TL actions, None.
        self.err_msg = None
        # Background colours for the modes and for the requests that
        # process_mode can highlight (same palette as GPSTrainingGUI).
        self._colors = {
            'run': 'cyan',
            'wait': 'orange',
            'end': 'red',
            'stop': 'red',
            'reset': 'yellow',
            'go': 'green',
            'fail': 'magenta',
        }
        self._actuator_types = config['actuator_types']
        self._actuator_names = config['actuator_names']
        self._first_update = True
        self._actuator_number = 0
        self._actuator_type = self._actuator_types[self._actuator_number]
        # (joint angles, ee positions, ee rotations); the 'unknown' strings are
        # placeholders until request_init_state / request_goal_state run.
        self._initial_position = ('unknown', 'unknown', 'unknown')
        self._target_position = ('unknown', 'unknown', 'unknown')

        # Actions.
        actions_arr = [
            Action('stop', 'stop', self.request_stop, axis_pos=0),
            Action('reset', 'reset', self.request_reset, axis_pos=1),
            Action('GCM go', 'go', self.request_go, axis_pos=2),
            Action('transfer learning', 'transfer_learning', self.request_tl, axis_pos=3),
            Action('set initial state', 'initstate', self.request_init_state, axis_pos=4),
            Action('set goal state', 'goalstate', self.request_goal_state, axis_pos=5),
            Action('test transfer learning', 'test_tl', self.request_test_tl, axis_pos=6),
            Action('generalize', 'generalize', self.request_generalize, axis_pos=7),
            Action('mti', 'move_to_initial', self.move_to_initial, axis_pos=8),
            Action('mtt', 'move_to_target', self.move_to_target, axis_pos=9),
            Action('rc', 'relax_controller', self.relax_controller, axis_pos=10),
        ]

        # Setup figure.
        plt.ion()
        plt.rcParams['toolbar'] = 'None'
        # Disable matplotlib's default key bindings so they do not clash with
        # the GUI's own bindings.
        for key in plt.rcParams:
            if key.startswith('keymap.'):
                plt.rcParams[key] = ''

        self._fig = plt.figure(figsize=config['figsize'])
        self._fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99,
                                  wspace=0, hspace=0)

        # Assign GUI component locations.
        self._gs = gridspec.GridSpec(18, 8)
        self._gs_action_panel = self._gs[0:4, 0:8]
        self._gs_action_output = self._gs[4:5, 0:4]
        self._gs_status_output = self._gs[5:6, 0:4]
        self._gs_cost_plotter = self._gs[4:10, 4:8]
        self._gs_algthm_output = self._gs[6:10, 0:4]
        self._gs_traj_visualizer = self._gs[10:18, 0:8]

        # Create GUI components.
        self._action_panel = ActionPanel(self._fig, self._gs_action_panel, 3, 4, actions_arr)
        self._action_output = Textbox(self._fig, self._gs_action_output, border_on=True)
        self._status_output = Textbox(self._fig, self._gs_status_output, border_on=False)
        self._algthm_output = Textbox(
            self._fig, self._gs_algthm_output,
            max_display_size=config['algthm_output_max_display_size'],
            log_filename=self._log_filename,
            fontsize=config['algthm_output_fontsize'],
            font_family='monospace')
        self._cost_plotter = MeanPlotter(self._fig, self._gs_cost_plotter,
                                         color='blue', label='mean cost')
        self._traj_visualizer = Plotter3D(self._fig, self._gs_traj_visualizer,
                                          num_plots=self._hyperparams['conditions'])

        # Setup GUI components.
        self._algthm_output.log_text('\n')
        self.set_output_text(self._hyperparams['info'])
        if config['initial_mode'] == 'run':
            self.run_mode()
        else:
            self.wait_mode()

        # Setup 3D Trajectory Visualizer plot titles and legends
        for m in range(self._hyperparams['conditions']):
            self._traj_visualizer.set_title(m, 'Condition %d' % (m))
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='green', label='Trajectory Samples')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='blue', label='Policy Samples')
        self._traj_visualizer.add_legend(linestyle='None', marker='x',
                                         color=(0.5, 0, 0), label='LG Controller Means')
        self._traj_visualizer.add_legend(linestyle='-', marker='None',
                                         color='red', label='LG Controller Distributions')

        self._fig.canvas.draw()

        # Display calculating thread: while the event is set, cycle the status
        # text through 'Calculating.', '..', '...' every `delay` seconds.
        def display_calculating(delay, run_event):
            while True:
                if not run_event.is_set():
                    run_event.wait()
                for dots in ('.', '..', '...'):
                    if run_event.is_set():
                        self.set_status_text('Calculating' + dots)
                    time.sleep(delay)

        self._calculating_run = threading.Event()
        self._calculating_thread = threading.Thread(target=display_calculating,
                                                    args=(1, self._calculating_run))
        # Daemon so the animation thread never keeps the process alive.
        self._calculating_thread.daemon = True
        self._calculating_thread.start()

    # GPS Training functions
    def request_stop(self, event=None):
        """Request 'stop', then run one agent.sample call with the current policy.

        The sample call uses timeout=0, reset=False and save=False.
        NOTE(review): presumably this makes the controller hold/stop the arm;
        confirm against the agent implementation.
        """
        self.request_mode('stop')
        self._agent.sample(
            self._agent.policy, self._agent.condition,
            verbose=None, save=False, noisy=False, use_TfController=True,
            timeout=0, reset=False)

    def request_reset(self, event=None):
        """Request 'reset' and reset the agent to its current condition."""
        self.request_mode('reset')
        self._agent.reset(self._agent.condition)

    def request_go(self, event=None):
        """Request 'go'."""
        self.request_mode('go')

    def request_tl(self, event=None):
        """Request a transfer-learning run."""
        self.request_mode('transfer_learning')

    def request_test_tl(self, event=None):
        """Request a transfer-learning test run."""
        self.request_mode('test_tl')

    def request_generalize(self, event=None):
        """Request a generalization run."""
        self.request_mode('generalize')

    def _move_to_position(self, action, position, label):
        """Send the active actuator to the joint angles stored in `position`.

        Reports 'failed' (instead of crashing) when the position has not been
        recorded yet and still holds the 'unknown' placeholder.
        """
        ja = position[0]
        if isinstance(ja, str):
            self.set_action_status_message(action, 'failed',
                                           message='%s has not been set' % label)
            return
        self.set_action_status_message(action, 'requested')
        self._agent.reset_arm(self._actuator_type, JOINT_SPACE, ja.T)
        self.set_action_status_message(action, 'completed',
                                       message='%s: %s' % (label, str(ja)))

    def move_to_initial(self, event=None):
        """Move the arm to the state recorded by request_init_state."""
        self._move_to_position('move_to_initial', self._initial_position,
                               'initial position')

    def move_to_target(self, event=None):
        """Move the arm to the state recorded by request_goal_state."""
        self._move_to_position('move_to_target', self._target_position,
                               'target position')

    def relax_controller(self, event=None):
        """Relax the controller of the currently selected actuator."""
        self.set_action_status_message('relax_controller', 'requested')
        self._agent.relax_arm(self._actuator_type)
        actuator_name = self._actuator_names[self._actuator_number]
        self.set_action_status_message('relax_controller', 'completed',
                                       message='actuator name: %s' % actuator_name)

    def request_goal_state(self, event=None):
        """Record the arm's current state as the goal state.

        Stores (joint angles, ee positions, ee rotations) in _target_position,
        writes the flattened end-effector target points into the agent's
        'ee_points_tgt' hyperparameter, and sets agent._target_ja.
        """
        self.request_mode('goalstate')
        sample = self._agent.get_data(arm=self._actuator_type)
        ja = sample.get(JOINT_ANGLES)
        ee_pos = sample.get(END_EFFECTOR_POSITIONS)
        ee_rot = sample.get(END_EFFECTOR_ROTATIONS)
        ee_tgt = np.ndarray.flatten(
            get_ee_points(self._agent._hyperparams['ee_points'], ee_pos, ee_rot).T
        )
        self._agent._hyperparams['ee_points_tgt'] = [ee_tgt]
        self._target_position = (ja, ee_pos, ee_rot)
        self._agent._target_ja = [ja]

    def request_init_state(self, event=None):
        """Record the arm's current state as the initial state.

        Stores (joint angles, ee positions, ee rotations) in _initial_position
        and sets agent._initial_ja.
        """
        self.request_mode('initstate')
        sample = self._agent.get_data(arm=self._actuator_type)
        ja = sample.get(JOINT_ANGLES)
        ee_pos = sample.get(END_EFFECTOR_POSITIONS)
        ee_rot = sample.get(END_EFFECTOR_ROTATIONS)
        self._initial_position = (ja, ee_pos, ee_rot)
        self._agent._initial_ja = [ja]

    def request_mode(self, request):
        """
        Sets the request mode (stop, reset, go, fail, or one of the
        transfer-learning requests). The request is read by gps_main before
        sampling, and the appropriate action is taken.
        """
        self.mode = 'request'
        self.request = request
        self.set_action_text(self.request + ' requested')

    def process_mode(self):
        """
        Completes the current request, after it is first read by gps_main.
        Displays visual confirmation that the request was processed,
        displays any error messages, and then switches into mode 'run' or 'wait'
        for stop/reset/fail/go requests.
        """
        self.mode = 'process'
        self.set_action_text(self.request + ' processed')
        # Not every request has a colour (e.g. 'transfer_learning'); those keep
        # the current background instead of raising KeyError.
        color = self._colors.get(self.request)
        if color is not None:
            self.set_action_bgcolor(color, alpha=1.0)
        if self.err_msg:
            self.set_action_text(self.request + ' processed' + '\nERROR: ' +
                                 self.err_msg)
            self.err_msg = None
            time.sleep(1.0)
        else:
            time.sleep(0.5)
        if self.request in ('stop', 'reset', 'fail'):
            self.wait_mode()
        elif self.request == 'go':
            self.run_mode()
        # NOTE(review): other requests leave self.mode == 'process'; confirm
        # that the caller handles that state.
        self.request = None

    def wait_mode(self):
        self.mode = 'wait'
        self.set_action_text('waiting')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def run_mode(self):
        self.mode = 'run'
        self.set_action_text('running')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def end_mode(self):
        self.mode = 'end'
        self.set_action_text('ended')
        self.set_action_bgcolor(self._colors[self.mode], alpha=1.0)

    def estop(self, event=None):
        self.set_action_text('estop: NOT IMPLEMENTED')

    # GUI functions
    def set_action_text(self, text):
        self._action_output.set_text(text)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels

    def set_action_bgcolor(self, color, alpha=1.0):
        self._action_output.set_bgcolor(color, alpha)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels

    def set_status_text(self, text):
        self._status_output.set_text(text)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels

    def set_output_text(self, text):
        self._algthm_output.set_text(text)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels

    def append_output_text(self, text):
        self._algthm_output.append_text(text)
        self._cost_plotter.draw_ticklabels()  # redraw overflow ticklabels

    def start_display_calculating(self):
        """Start the 'Calculating...' status animation."""
        self._calculating_run.set()

    def stop_display_calculating(self):
        """Pause the 'Calculating...' status animation."""
        self._calculating_run.clear()

    def set_action_status_message(self, action, status, message=None):
        """Show '<action>: <status>' (plus an optional message) in the action box.

        The background becomes yellow/green/red for the statuses
        'requested'/'completed'/'failed'; other statuses leave it unchanged.
        """
        text = action + ': ' + status
        if message:
            text += '\n\n' + message
        self.set_action_text(text)
        if status == 'requested':
            self.set_action_bgcolor('yellow')
        elif status == 'completed':
            self.set_action_bgcolor('green')
        elif status == 'failed':
            self.set_action_bgcolor('red')

    def set_image_overlays(self, condition):
        """
        Sets up the image visualizer with what images to overlay if
        "overlay_initial_image" or "overlay_target_image" is pressed.
        Does nothing when images are disabled, no target file is configured,
        or no image visualizer exists (this GUI's __init__ does not build one).
        """
        if not config['image_on'] or not self._target_filename:
            return
        image_visualizer = getattr(self, '_image_visualizer', None)
        if image_visualizer is None:
            return
        initial_image = load_data_from_npz(self._target_filename, config['image_overlay_actuator'],
                                           str(condition), 'initial', 'image', default=None)
        target_image = load_data_from_npz(self._target_filename, config['image_overlay_actuator'],
                                          str(condition), 'target', 'image', default=None)
        image_visualizer.set_initial_image(initial_image, alpha=config['image_overlay_alpha'])
        image_visualizer.set_target_image(target_image, alpha=config['image_overlay_alpha'])

    # Iteration update functions
    def update(self, itr, algorithm, agent, traj_sample_lists, pol_sample_lists):
        """
        After each iteration, update the iteration data output, the cost plot,
        and the 3D trajectory visualizations (if end effector points exist).
        """
        if self._first_update:
            # Include the avg_pol_cost header when policy samples are present,
            # matching the extra column written by _update_iteration_data.
            self._output_column_titles(algorithm,
                                       policy_titles=pol_sample_lists is not None)
            self._first_update = False

        costs = [np.mean(np.sum(algorithm.prev[m].cs, axis=1)) for m in range(algorithm.M)]
        self._update_iteration_data(itr, algorithm, costs, pol_sample_lists)
        self._cost_plotter.update(costs, t=itr)
        if END_EFFECTOR_POINTS in agent.x_data_types:
            self._update_trajectory_visualizations(algorithm, agent,
                                                   traj_sample_lists, pol_sample_lists)

        self._fig.canvas.draw()
        self._fig.canvas.flush_events()  # Fixes bug in Qt4Agg backend

    def _output_column_titles(self, algorithm, policy_titles=False):
        """
        Setup iteration data column titles: iteration, average cost (and the
        average policy cost when `policy_titles` is true), and for each
        condition the mean cost over samples, step size and linear Gaussian
        controller entropy.
        """
        self.set_output_text(self._hyperparams['experiment_name'])
        if policy_titles:
            condition_titles = '%3s | %8s %12s' % ('', '', '')
            itr_data_fields = '%3s | %8s %12s' % ('itr', 'avg_cost', 'avg_pol_cost')
        else:
            condition_titles = '%3s | %8s' % ('', '')
            itr_data_fields = '%3s | %8s' % ('itr', 'avg_cost')
        for m in range(algorithm.M):
            condition_titles += ' | %8s %9s %-7d' % ('', 'condition', m)
            itr_data_fields += ' | %8s %8s %8s' % ('  cost  ', '  step  ', 'entropy ')
        self.append_output_text(condition_titles)
        self.append_output_text(itr_data_fields)

    def _update_iteration_data(self, itr, algorithm, costs, pol_sample_lists):
        """
        Update iteration data information: iteration, average cost (and the
        average policy cost over the test conditions when policy samples are
        given), and for each condition the mean cost over samples, step size
        and linear Gaussian controller entropy.
        """
        avg_cost = np.mean(costs)
        if pol_sample_lists is not None:
            test_idx = algorithm._hyperparams['test_conditions']
            # pol_sample_lists is a list of singletons
            samples = [sl[0] for sl in pol_sample_lists]
            pol_costs = [np.sum(algorithm.cost[idx].eval(s)[0])
                         for s, idx in zip(samples, test_idx)]
            itr_data = '%3d | %8.2f %12.2f' % (itr, avg_cost, np.mean(pol_costs))
        else:
            itr_data = '%3d | %8.2f' % (itr, avg_cost)
        for m in range(algorithm.M):
            cost = costs[m]
            step = algorithm.prev[m].step_mult * algorithm.base_kl_step
            # 2 * sum(log(diag(L))) is log|Sigma| if chol_pol_covar holds
            # Cholesky factors (assumed from the name), i.e. the Gaussian
            # entropy up to constants.
            entropy = 2 * np.sum(np.log(np.diagonal(algorithm.prev[m].traj_distr.chol_pol_covar,
                                                    axis1=1, axis2=2)))
            itr_data += ' | %8.2f %8.2f %8.2f' % (cost, step, entropy)
        self.append_output_text(itr_data)

    def _update_trajectory_visualizations(self, algorithm, agent,
                                          traj_sample_lists, pol_sample_lists):
        """
        Update 3D trajectory visualizations information: the trajectory samples,
        policy samples, and linear Gaussian controller means and covariances.
        """
        xlim, ylim, zlim = self._calculate_3d_axis_limits(traj_sample_lists, pol_sample_lists)
        for m in range(algorithm.M):
            self._traj_visualizer.clear(m)
            self._traj_visualizer.set_lim(i=m, xlim=xlim, ylim=ylim, zlim=zlim)
            self._update_linear_gaussian_controller_plots(algorithm, agent, m)
            self._update_samples_plots(traj_sample_lists, m, 'green', 'Trajectory Samples')
            if pol_sample_lists:
                self._update_samples_plots(pol_sample_lists, m, 'blue', 'Policy Samples')
        self._traj_visualizer.draw()  # this must be called explicitly

    def _calculate_3d_axis_limits(self, traj_sample_lists, pol_sample_lists):
        """
        Calculate the 3D axis limits shared between trajectory plots,
        based on the minimum and maximum xyz values across all trajectory and
        policy samples. The input lists are not modified.
        """
        # Copy so the caller's trajectory list is never extended in place.
        sample_lists = list(traj_sample_lists)
        if pol_sample_lists:
            sample_lists += list(pol_sample_lists)

        # End-effector points are packed as consecutive xyz triples per row.
        points = []
        for sample_list in sample_lists:
            for sample in sample_list.get_samples():
                ee_pt = sample.get(END_EFFECTOR_POINTS)
                for i in range(ee_pt.shape[1] // 3):
                    points.append(ee_pt[:, 3*i:3*i+3])
        all_eept = np.vstack(points) if points else np.empty((0, 3))

        min_xyz = np.amin(all_eept, axis=0)
        max_xyz = np.amax(all_eept, axis=0)
        xlim = buffered_axis_limits(min_xyz[0], max_xyz[0], buffer_factor=1.25)
        ylim = buffered_axis_limits(min_xyz[1], max_xyz[1], buffer_factor=1.25)
        zlim = buffered_axis_limits(min_xyz[2], max_xyz[2], buffer_factor=1.25)
        return xlim, ylim, zlim

    def _update_linear_gaussian_controller_plots(self, algorithm, agent, m):
        """
        Update the linear Gaussian controller plots with iteration data,
        for the mean and covariances of the end effector points.
        """
        # Calculate mean and covariance for end effector points
        eept_idx = agent.get_idx_x(END_EFFECTOR_POINTS)
        start, end = eept_idx[0], eept_idx[-1]
        mu, sigma = algorithm.forward(algorithm.prev[m].traj_distr, algorithm.prev[m].traj_info)
        mu_eept = mu[:, start:end+1]
        sigma_eept = sigma[:, start:end+1, start:end+1]
        num_points = mu_eept.shape[1] // 3

        # Linear Gaussian Controller Distributions (Red)
        for i in range(num_points):
            xyz = slice(3*i, 3*i+3)
            self._traj_visualizer.plot_3d_gaussian(i=m, mu=mu_eept[:, xyz],
                                                   sigma=sigma_eept[:, xyz, xyz],
                                                   edges=100, linestyle='-', linewidth=1.0,
                                                   color='red', alpha=0.15,
                                                   label='LG Controller Distributions')

        # Linear Gaussian Controller Means (Dark Red)
        for i in range(num_points):
            self._traj_visualizer.plot_3d_points(i=m, points=mu_eept[:, 3*i:3*i+3],
                                                 linestyle='None', marker='x',
                                                 markersize=5.0, markeredgewidth=1.0,
                                                 color=(0.5, 0, 0), alpha=1.0,
                                                 label='LG Controller Means')

    def _update_samples_plots(self, sample_lists, m, color, label):
        """
        Update the samples plots with iteration data, for the trajectory samples
        and the policy samples.
        """
        samples = sample_lists[m].get_samples()
        for sample in samples:
            ee_pt = sample.get(END_EFFECTOR_POINTS)
            for i in range(ee_pt.shape[1] // 3):
                ee_pt_i = ee_pt[:, 3*i:3*i+3]
                self._traj_visualizer.plot_3d_points(m, ee_pt_i, color=color, label=label)

    def save_figure(self, filename):
        """Save the whole GUI figure to `filename`."""
        self._fig.savefig(filename)
Action('plus1', 'plus1', plus_1, axis_pos=0, keyboard_binding='1', ps3_binding=None), Action('plus2', 'plus2', plus_2, axis_pos=1, keyboard_binding='2', ps3_binding=None), Action('print', 'print', mult_4, axis_pos=2, keyboard_binding='4', ps3_binding=None), ] action_panel = ActionPanel(fig, gs[0], 3, 1, actions_arr) # Textbox def demo_textbox(): max_i = 20 for i in range(max_i): textbox.append_text(str(i)) c = 0.5 + 0.5*i/max_i textbox.set_bgcolor((c, c, c)) time.sleep(1) textbox = Textbox(fig, gs[1], max_display_size=10, log_filename=None) run_demo(demo_textbox) # Image Visualizer def demo_image_visualizer(): im = np.zeros((5, 5, 3)) while True: i = random.randint(0, im.shape[0] - 1) j = random.randint(0, im.shape[1] - 1) k = random.randint(0, im.shape[2] - 1) im[i, j, k] = (im[i, j, k] + random.randint(0, 255)) % 256 image_visualizer.update(im) time.sleep(5e-3) image_visualizer = ImageVisualizer(fig, gs[2], cropsize=(3, 3)) run_demo(demo_image_visualizer)