def draw_policy(dic, num_rows=5, num_cols=5, fontsize=28):
    """Render a policy mapping as a grid table of comma-separated descriptions.

    NOTE(review): ``draw_policy`` is defined again later in this file; that
    later definition shadows this one at import time.

    :param dic: mapping ``(row, col) -> value``; each value is passed to
        ``describe`` (presumably returning a sequence of strings — confirm),
        whose items are joined with ``','``.
    :param num_rows: number of grid rows (default 5, as previously hard-coded).
    :param num_cols: number of grid columns (default 5).
    :param fontsize: font size applied to every cell (default 28).
    """
    _, ax = plt.subplots(figsize=(10, 10))
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])
    per_height, per_width = 1.0 / num_rows, 1.0 / num_cols
    for (i, j), val in dic.items():
        # join() avoids the trailing separator the old `+= item + ','` loop left.
        text = ','.join(describe(val))
        tb.add_cell(i, j, text=text, width=per_width, height=per_height,
                    loc='center')
    tb.set_fontsize(fontsize)
    ax.add_table(tb)
def draw_image(image):
    """Render a 2-D array as a table of white cells with 1-based labels.

    Row labels go in column -1 and column labels in a half-height row -1.
    NOTE(review): ``draw_image`` is defined more than once in this file; a
    later definition shadows this one.

    :param image: 2-D array (its ``shape`` is unpacked into rows, cols).
    """
    _, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])
    nrows, ncols = image.shape
    width, height = 1.0 / ncols, 1.0 / nrows
    for (i, j), val in np.ndenumerate(image):
        tb.add_cell(i, j, width, height, text=val, loc='center',
                    facecolor='white')
    # Row labels: one per row.
    for i in range(nrows):
        tb.add_cell(i, -1, width, height / 2, text=i + 1, loc='center',
                    edgecolor='none', facecolor='none')
    # Column labels: one per *column* (previously looped over the row count).
    for j in range(ncols):
        tb.add_cell(-1, j, width, height / 2, text=j + 1, loc='center',
                    edgecolor='none', facecolor='none')
    ax.add_table(tb)
def draw_image(image):
    """Render a 2-D array as a table with 0-based labels; column labels below.

    Row labels (right-aligned) go in column -1; column labels go in the row
    just below the last data row. NOTE(review): ``draw_image`` is defined more
    than once in this file; a later definition shadows this one.

    :param image: 2-D array (its ``shape`` is unpacked into rows, cols).
    """
    _, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])
    nrows, ncols = image.shape
    width, height = 1.0 / ncols, 1.0 / nrows
    # Data cells.
    for (i, j), val in np.ndenumerate(image):
        tb.add_cell(i, j, width, height, text=val, loc='center',
                    facecolor='white')
    # Row labels.
    for i in range(nrows):
        tb.add_cell(i, -1, width, height, text=i, loc='right',
                    edgecolor='none', facecolor='none')
    # Column labels: row index `nrows` is derived from the image itself
    # (previously an unrelated global `grid_world_h`, which could overwrite
    # data cells when smaller than the image height).
    for j in range(ncols):
        tb.add_cell(nrows, j, width, height / 2, text=j, loc='center',
                    edgecolor='none', facecolor='none')
    ax.add_table(tb)
def draw_image(image):
    """Render a 2-D array as a table of white cells with 1-based labels.

    Row labels (right-aligned) go in column -1; column labels go in a
    half-height row -1. NOTE(review): ``draw_image`` is defined more than once
    in this file; a later definition shadows this one.

    :param image: 2-D array (its ``shape`` is unpacked into rows, cols).
    """
    _, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])
    nrows, ncols = image.shape
    width, height = 1.0 / ncols, 1.0 / nrows
    # Add cells
    for (i, j), val in np.ndenumerate(image):
        tb.add_cell(i, j, width, height, text=val, loc='center',
                    facecolor='white')
    # Row and column labels are counted separately so non-square images work.
    for i in range(nrows):
        tb.add_cell(i, -1, width, height, text=i + 1, loc='right',
                    edgecolor='none', facecolor='none')
    for j in range(ncols):
        tb.add_cell(-1, j, width, height / 2, text=j + 1, loc='center',
                    edgecolor='none', facecolor='none')
    ax.add_table(tb)
def display(self, agent):
    """Draw the grid world as a table, marking start, end and the agent.

    Start is labelled "Start" (blue), end "End" (green), and every other cell
    is coloured via ``self.dict_world_to_color[value]``. The agent's cell
    (row ``agent.pos_y``, column ``agent.pos_x``) is painted ``agent.color``.
    Previously the agent cell was added and then immediately replaced by the
    terrain/start/end cell at the same position, so the agent never showed.

    :param agent: object with ``pos_y``, ``pos_x`` and ``color`` attributes.
    """
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_axis_off()
    tb = Table(ax)
    _, ncols = self.world.shape
    # NOTE(review): height is also 1/ncols (the row count is unused), which
    # presumably keeps cells square — confirm this is intended.
    width, height = 1.0 / ncols, 1.0 / ncols
    agent_pos = (agent.pos_y, agent.pos_x)
    for (i, j), val in np.ndenumerate(self.world):
        if (i, j) == self.starting_point:
            cell_kwargs = {"text": "Start", "facecolor": "b"}
        elif (i, j) == self.ending_point:
            cell_kwargs = {"text": "End", "facecolor": "g"}
        else:
            cell_kwargs = {"facecolor": self.dict_world_to_color[val],
                           "loc": "center"}
        if (i, j) == agent_pos:
            # Keep any Start/End label but show the agent's colour.
            cell_kwargs["facecolor"] = agent.color
        tb.add_cell(row=i, col=j, width=width, height=height, **cell_kwargs)
    ax.add_table(tb)
    plt.title("The cliff gridworld")
    plt.show()
def draw_lat_matrix(fname, data, title="", lat_type=None, lat_types=None,
                    columns=(), rows=()):
    """Render a 2-D matrix as a checkerboard table and save it as a PDF.

    Each truthy cell value is passed through ``ls.exec_fn`` and formatted
    with 3 decimals; falsy values (including 0) are shown as ``"-"``.

    :param fname: output base name; written to ``../figs/<fname>.pdf``.
    :param data: 2-D array whose shape must match ``len(rows)`` x
        ``len(columns)``.
    :param title: optional title, wrapped with ``textwrap.wrap``.
    :param lat_type: argument passed to ``ls.exec_fn`` for every cell; if
        falsy, ``lat_types[row]`` is used per row instead.
    :param lat_types: per-row alternative to ``lat_type``.
    :param columns: column labels (immutable default; was ``[]``).
    :param rows: row labels (immutable default; was ``[]``).
    :raises AssertionError: on label/shape mismatch, or when neither
        ``lat_type`` nor ``lat_types`` is given for a non-empty cell.
    """
    bkg_colors = ['#CCFFFF', 'white']
    fmt = '{:.3f}'
    _, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])
    nrows, ncols = data.shape
    assert nrows == len(rows)
    assert ncols == len(columns)
    width, height = 1.0 / ncols * 2, 1.0 / nrows * 2
    for (i, j), val in np.ndenumerate(data):
        color = bkg_colors[(i + j) % 2]  # alternating checkerboard
        if val:
            if lat_type:
                txt = fmt.format(ls.exec_fn(val, lat_type))
            else:
                assert lat_types
                txt = fmt.format(ls.exec_fn(val, lat_types[i]))
        else:
            txt = "-"
        tb.add_cell(i, j, width, height, text=txt, loc='center',
                    facecolor=color)
    # Row Labels...
    for i, label in enumerate(rows):
        tb.add_cell(i, -1, width, height, text=label, loc='right',
                    edgecolor='none', facecolor='none')
    # Column Labels...
    for j, label in enumerate(columns):
        tb.add_cell(-1, j, width, height / 2, text=label, loc='center',
                    edgecolor='none', facecolor='none')
    ax.add_table(tb)
    if title:
        ax.set_title("\n".join(wrap(title)), y=1.08)
    savefig('../figs/' + fname + '.pdf', bbox_inches='tight')
    plt.close()
def draw_policy(optimal_values):
    """Render, for each state, the arrow glyph(s) of its best action(s).

    For every cell the successor state under each action in ``ACTIONS`` is
    obtained from ``step``. The action(s) whose successor has the highest
    value in ``optimal_values`` are chosen; the reward returned by ``step``
    is ignored. Their ``ACTIONS_FIGS`` glyphs are concatenated, and the
    ``A_POS`` / ``A_PRIME_POS`` cells get " (A)" / " (A')" suffixes.

    :param optimal_values: 2-D array of state values.
    """
    _, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])
    nrows, ncols = optimal_values.shape
    width, height = 1.0 / ncols, 1.0 / nrows
    # Add cells
    for i, j in np.ndindex(optimal_values.shape):
        next_vals = []
        for action in ACTIONS:
            next_state, _ = step([i, j], action)
            next_vals.append(optimal_values[next_state[0], next_state[1]])
        next_vals = np.asarray(next_vals)
        best_actions = np.flatnonzero(next_vals == next_vals.max())
        label = ''.join(ACTIONS_FIGS[ba] for ba in best_actions)
        # add state labels
        if [i, j] == A_POS:
            label += " (A)"
        if [i, j] == A_PRIME_POS:
            label += " (A')"
        tb.add_cell(i, j, width, height, text=label, loc='center',
                    facecolor='white')
    # Row and column labels are counted separately so non-square grids work.
    for i in range(nrows):
        tb.add_cell(i, -1, width, height, text=i + 1, loc='right',
                    edgecolor='none', facecolor='none')
    for j in range(ncols):
        tb.add_cell(-1, j, width, height / 2, text=j + 1, loc='center',
                    edgecolor='none', facecolor='none')
    ax.add_table(tb)
def plot_value(self, ax: plt.Axes = None):
    """Draw ``self.value`` (rounded to 2 decimals) as a labelled table.

    :param ax: axes to draw on; a new figure/axes is created when None.
    """
    if ax is None:
        _, ax = plt.subplots()
    rounded_value = np.round(self.value, decimals=2)
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])
    nrows, ncols = rounded_value.shape
    width, height = 1.0 / ncols, 1.0 / nrows
    for (i, j), val in np.ndenumerate(rounded_value):
        tb.add_cell(i, j, width, height, text=val, loc="center",
                    facecolor="white")
    # Row labels (1-based), one per row.
    for i in range(nrows):
        tb.add_cell(
            i, -1, width, height, text=i + 1, loc="right",
            edgecolor="none", facecolor="none",
        )
    # Column labels (1-based), one per column — previously the row count was
    # used, which mislabels non-square value arrays.
    for j in range(ncols):
        tb.add_cell(
            -1, j, width, height / 2, text=j + 1, loc="center",
            edgecolor="none", facecolor="none",
        )
    ax.add_table(tb)
def __make_table(self):
    '''
    Build a matplotlib Table for the filter model (adapted from the code in
    matplotlib.table).

    When a dedicated filter axes exists the table is centred in it; otherwise
    it is placed on the magnitude model's ``limits.axes_1`` inside the bbox
    given by the x0/y0/x1/y1 controls.
    :return: the table, or None when there are no filters or no magnitude model.
    '''
    if len(self.__filter_model) <= 0 or self.__magnitude_model is None:
        return None

    limits_axes = self.__magnitude_model.limits.axes_1
    face_colour = limits_axes.get_facecolor()
    if self.__filter_axes is None:
        table_axes = limits_axes
        table_loc = {
            'bbox': (self.x0.value(),
                     self.y0.value(),
                     self.x1.value() - self.x0.value(),
                     self.y1.value() - self.y0.value())
        }
    else:
        table_axes = self.__filter_axes
        table_loc = {'loc': 'center'}

    # Row height = font size (points -> inches -> pixels) as a fraction of the
    # axes height, scaled by a user-tunable fudge factor (works around how the
    # matplotlib table sizes rows).
    multiplier = self.filterRowHeightMultiplier.value()
    self.__first_create = False
    row_height = (self.tableFontSize.value() / 72.0
                  * self.preview.canvas.figure.dpi
                  / table_axes.bbox.height
                  * multiplier)

    table = Table(table_axes, **table_loc)
    table.set_zorder(1000)
    self.__add_filters_to_table(table, row_height, {'facecolor': face_colour})
    table.auto_set_font_size(False)
    table.set_fontsize(self.tableFontSize.value())
    return table
def draw_grid_world_policy_image(policy, filename, GRID_HEIGHT, GRID_WIDTH, ACTION_SYMBOLS, TERMINAL_STATES=None): action_str_values = [] for i in range(GRID_HEIGHT): action_str_values.append([]) for j in range(GRID_WIDTH): if TERMINAL_STATES and (i, j) in TERMINAL_STATES: continue str_values = [] for action in policy[(i, j)]: str_values.append("{0} ({1})".format(ACTION_SYMBOLS[action], action)) action_str_values[i].append("\n".join(str_values)) # 축 표시 제거, 크기 조절 등 이미지 그리기 이전 설정 작업 fig, ax = plt.subplots() ax.set_axis_off() table = Table(ax, bbox=[0, 0, 1, 1]) nrows, ncols = GRID_HEIGHT, GRID_WIDTH width, height = 1.0 / ncols, 1.0 / nrows # 렌더링 할 이미지에 표 셀과 해당 값 추가 for i in range(GRID_HEIGHT): for j in range(GRID_WIDTH): if TERMINAL_STATES and (i, j) in TERMINAL_STATES: continue table.add_cell(i, j, width, height, text=action_str_values[i][j], loc='center', facecolor='white') # 행, 열 라벨 추가 for i in range(len(action_str_values)): table.add_cell(i, -1, width, height, text=i + 1, loc='right', edgecolor='none', facecolor='none') table.add_cell(-1, i, width, height / 2, text=i + 1, loc='center', edgecolor='none', facecolor='none') for key, cell in table.get_celld().items(): cell.get_text().set_fontsize(10) ax.add_table(table) plt.savefig(filename) plt.close()
def draw_image(policy, name): fig, ax = plt.subplots() ax.set_axis_off() tb = Table(ax, bbox=[0, 0, 1, 1]) ncols = WORLD_X nrows = WORLD_Y width, height = 1.0 / ncols, 1.0 / nrows # Add cells for i in range(WORLD_Y): for j in range(WORLD_X): color = 'white' #bank if (j, i) in bank_pos: color = '#9bb4db' if policy[i][j][1][2] == -1: text = "Police" elif policy[i][j][1][2] == 0: text = "Left" elif policy[i][j][1][2] == 1: text = "Up" elif policy[i][j][1][2] == 2: text = "Right" elif policy[i][j][1][2] == 3: text = "Down" elif policy[i][j][1][2] == 4: text = "Stay" tb.add_cell(i, j, width, height, text=text, loc='center', edgecolor='#63b1f2', facecolor=color) # Row Labels... for i, label in enumerate(range(WORLD_Y)): tb.add_cell(i, -1, width, height, text=label + 1, loc='right', edgecolor='none', facecolor='none') # Column Labels... for j, label in enumerate(range(WORLD_X)): tb.add_cell(-1, j, width, height / 2, text=label + 1, loc='center', edgecolor='none', facecolor='none') ax.add_table(tb) #Limits borders = [[(0, 0), (1, 0)], [(0, 0), (0, 1)], [(1, 1), (1, 0)], [(1, 1), (0, 1)]] lc = mc.LineCollection(borders, colors='k', linewidths=4) ax.add_collection(lc) plt.savefig('lambda' + name + '.png') plt.close()
def checkerboard_plot(ary,
                      cell_colors=('white', 'black'),
                      font_colors=('black', 'white'),
                      fmt='%.1f',
                      figsize=None,
                      row_labels=None,
                      col_labels=None,
                      fontsize=None):
    """
    Plot a checkerboard table / heatmap via matplotlib.

    Parameters
    -----------
    ary : array-like, shape = [n, m]
        A 2D NumPy array, or anything ``numpy.asarray`` converts to one.

    cell_colors : tuple or list (default: ('white', 'black'))
        Tuple or list containing the two colors of the
        checkerboard pattern.

    font_colors : tuple or list (default: ('black', 'white'))
        Font colors corresponding to the cell colors.

    fmt : str (default: '%.1f')
        Python string formatter for cell values.
        The default '%.1f' results in floats with 1 digit after
        the decimal point. Use '%d' to show numbers as integers.

    figsize : tuple (default: None)
        Width and height of the figure in inches. Uses matplotlib's
        default figure size if None.

    row_labels : list (default: None)
        List of the row labels. Uses the array row
        indices 0 to n-1 by default.

    col_labels : list (default: None)
        List of the column labels. Uses the array column
        indices 0 to m-1 by default.

    fontsize : int (default: None)
        Specifies the font size of the checkerboard table.
        Uses matplotlib's default if None.

    Returns
    -----------
    fig : matplotlib Figure object.

    Examples
    -----------
    For usage examples, please see
    http://rasbt.github.io/mlxtend/user_guide/plotting/checkerboard_plot/

    """
    ary = np.asarray(ary)
    fig, ax = subplots(figsize=figsize)
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])

    n_rows, n_cols = ary.shape

    if row_labels is None:
        row_labels = np.arange(n_rows)
    if col_labels is None:
        col_labels = np.arange(n_cols)

    width, height = 1.0 / n_cols, 1.0 / n_rows

    for (row_idx, col_idx), cell_val in np.ndenumerate(ary):
        idx = (col_idx + row_idx) % 2
        tb.add_cell(row_idx, col_idx, width, height,
                    text=fmt % cell_val,
                    loc='center',
                    facecolor=cell_colors[idx])
        # Public API instead of the private ``_cells`` / ``_text`` attributes.
        tb.get_celld()[(row_idx, col_idx)].get_text().set_color(
            font_colors[idx])

    for row_idx, label in enumerate(row_labels):
        tb.add_cell(row_idx, -1, width, height, text=label, loc='right',
                    edgecolor='none', facecolor='none')

    for col_idx, label in enumerate(col_labels):
        tb.add_cell(-1, col_idx, width, height / 2., text=label, loc='center',
                    edgecolor='none', facecolor='none')

    ax.add_table(tb)
    tb.set_fontsize(fontsize)

    return fig
def create_report(**kwargs):
    """Write ``report.pdf`` into ``kwargs['exp_home']``.

    The report has one page per colorized region (image plus area), then a
    page holding a 5-entry class-colour legend and a statistics table.
    Regions are numbered from 1 in every case. Before, a single region was
    titled "Region 1" but multiple regions started at "Region 0".

    Expected keys (as read below): ``classimg``, ``exp_home``, ``colors``
    (indexed ``colors[k, :]``, values scaled by 1/255 — presumably 0-255 RGB),
    ``colorized``, ``filename``, ``time_elapsed``, ``stats``, ``process_map``.
    """
    draw_class_images(kwargs['classimg'], kwargs['exp_home'])
    reportfile = os.path.join(kwargs['exp_home'], 'report.pdf')
    colors = kwargs['colors']
    regions, areas = imgs_to_plot(kwargs['colorized'])

    with PdfPages(reportfile) as pdf:
        # One page per region.
        for k, (reg, area) in enumerate(zip(regions, areas), start=1):
            plt.figure()
            plt.imshow(reg)
            plt.title('Region {}, Area = {}'.format(k, area))
            pdf.savefig(dpi=300)
            plt.close()

        # NOTE(review): this sets a global rcParam that persists after the
        # call returns.
        plt.rc('text', usetex=True)
        _, ax = plt.subplots(1, 1)

        # Class-colour legend, rows 1..5 in colour order.
        legend_labels = ['BG', 'Low Grade', 'High Grade', 'BN', 'ST']
        tb = Table(ax, bbox=[0.1, 0.1, 0.2, 0.8])
        for idx, label in enumerate(legend_labels):
            tb.add_cell(idx + 1, 1, 0.1, 0.25, text=label, loc='center',
                        facecolor=colors[idx, :] / 255.0)
        ax.add_table(tb)

        # Statistics table: header in column -1, value in column 1.
        header, data = build_stat_string(
            filename=kwargs['filename'],
            time_elapsed=kwargs['time_elapsed'],
            stats=kwargs['stats'],
            process_map=kwargs['process_map'])
        tb = Table(ax, bbox=[0.5, 0.1, 0.4, 0.8])
        nrow = float(len(data))
        for k, (h, d) in enumerate(zip(header, data)):
            tb.add_cell(k, -1, 0.1, 1 / nrow, text=h)
            tb.add_cell(k, 1, 0.3, 1 / nrow, text=d)
        ax.add_table(tb)
        pdf.savefig()
        plt.close()
def plot_cliff_world(ax):
    """Draw the cliff-walking grid-world diagram on *ax*.

    Two tables share the bbox ``[0, H, 1, 1 - H]`` (the upper half of the
    axes for H = 0.5). The first holds the grey "The Cliff" cell and a few
    background cells; the second overlays the grid lines plus the "S" and "G"
    cells. Below the grid, curved arrows, dots and an ``R=-100`` label are
    drawn. Straight grey arrows mark the "safe path" and "optimal path",
    whose labels sit left of the grid.

    NOTE(review): column index 11 and the cliff width ``Δw * 10`` are
    hard-coded, so this presumably assumes ``CliffWorld.W == 12`` — confirm.
    """
    # Hide the axes frame and ticks; only the tables and annotations show.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    H = 0.5  # bottom edge of the grid tables' bbox
    row, col = CliffWorld.H, CliffWorld.W
    Δh, Δw = 1 / row, 1 / col  # one cell's height/width as a fraction of the bbox
    # Background table: left-column cells, the goal-column cell in the last
    # row, and one 10-column-wide grey "The Cliff" cell starting at column 1.
    tb = Table(ax, bbox=[0, H, 1, 1 - H])
    for i in range(row - 1):
        tb.add_cell(i, 0, Δw, Δh, text='').set_lw(.5)
    tb.add_cell(row - 1, 11, Δw, Δh, text='').set_lw(.5)
    c = tb.add_cell(row - 1, 1, Δw * 10, Δh, text='The Cliff', loc='center')
    c.set_lw(.5)
    c.set_color('#aaaaaa')
    # Overlay table in the same bbox: full grid for all but the last row,
    # plus the start ("S") and goal ("G") cells of the last row.
    tb2 = Table(ax, bbox=[0, H, 1, 1 - H])
    for i in range(row - 1):
        for j in range(col):
            tb2.add_cell(i, j, Δw, Δh, text='').set_lw(.5)
    i = row - 1
    tb2.add_cell(i, 0, Δw, Δh, text='S', loc='center').set_lw(.5)
    tb2.add_cell(i, 11, Δw, Δh, text='G', loc='center').set_lw(.5)
    ax.add_table(tb)
    ax.add_table(tb2)
    ax.set_title(r'$R=-1$')
    # Path labels to the left of the grid; `(1 - H) / row * k` is k rows up
    # from the grid's bottom edge.
    ax.text(-0.01, H + (1 - H) / row * 3.5, 'safe path', fontsize=12,
            horizontalalignment='right', verticalalignment='center')
    ax.text(-0.01, H + (1 - H) / row * .5, 'optimal path', fontsize=12,
            horizontalalignment='right', verticalalignment='center')
    # x-centre of column p on the grid's bottom edge.
    p2p = lambda p: (Δw * (p + .5), H)
    # Curved arrows from cliff columns 1, 2, 5, 10 to points near the left
    # end of the bottom edge (presumably depicting the reset to the start).
    for i, p in enumerate([1, 2, 5, 10]):
        arrow = patches.FancyArrowPatch(
            p2p(p), (Δw / 5 * (4 - i), H),
            connectionstyle="angle3, angleA=45,angleB=-70",
            arrowstyle="Simple,tail_width=0.5,head_width=4,head_length=6",
            color="k")
        ax.add_artist(arrow)
    ax.text(0.3, H - 0.05, '• • •', fontsize=13,
            horizontalalignment='center', verticalalignment='center')
    ax.text(0.64, H - 0.05, '• • •', fontsize=13,
            horizontalalignment='center', verticalalignment='center')
    ax.text(0.5, H - 0.2, r'R=-100', fontsize=13,
            horizontalalignment='center', verticalalignment='center')
    # Straight path arrows as (start, end, head_length): up the left column,
    # across at 3.5 rows, down the right column, and across at 1.5 rows.
    ps = [[(Δw * .5, H + (1 - H) / row * 0.8), (Δw * .5, H + (1 - H) / row * 3.5), .04],
          [(Δw * .6, H + (1 - H) / row * 3.5), (1 - Δw * .6, H + (1 - H) / row * 3.5), .03],
          [(1 - Δw * .5, H + (1 - H) / row * 3.5), (1 - Δw * .5, H + (1 - H) / row * 0.8), .04],
          [(Δw * .6, H + (1 - H) / row * 1.5), (1 - Δw * .6, H + (1 - H) / row * 1.5), .03]]
    for p1, p2, d in ps:
        ax.arrow(*p1, p2[0] - p1[0], p2[1] - p1[1], head_width=.012, head_length=d,
                 length_includes_head=True, color='#9e9e9e')
def __draw_state_transition__(tran_list, state_list, node_map, node_patches,
                              node_text, cur_ax, time_mult=0.1):
    """
    Animate node-firing events on *cur_ax* and show node-state changes in a
    two-row table along the bottom of the axes.

    For each event in ``tran_list`` (fields ``e_idx``, ``e_t``, ``e_n``) the
    fired node's patch is made opaque and its text shows the node's firing
    count. When simulated time advances, every patch fades by
    ``__disp_defs__.PATCH_ALPHA_DECAY`` and the call sleeps
    ``Δt * time_mult`` seconds. When an event id matches the next entry of
    ``state_list`` (fields ``e_idx``, ``n_idx``, ``n_val``), that node's value
    cell is updated and printed. The loop stops once every state entry has
    been shown.

    NOTE(review): ``evt_mask`` is sized ``len(node_patches)`` and indexed by
    ``e_n``, so node ids are presumably 0..N-1 — confirm.
    """
    import time
    from numpy import zeros
    from matplotlib.pyplot import draw
    from matplotlib.table import Table

    num_tran = tran_list.shape[0]
    evt_id_list = tran_list['e_idx']
    evt_t_list = tran_list['e_t']
    evt_s_list = tran_list['e_n']

    state_e_list = state_list['e_idx']
    state_n_list = state_list['n_idx']
    state_val_list = state_list['n_val']
    len_states = len(state_e_list)

    x_lim = cur_ax.get_xlim()
    y_lim = cur_ax.get_ylim()
    x_range = x_lim[1] - x_lim[0]
    y_range = y_lim[1] - y_lim[0]
    col_width = x_range / float(len(node_map))
    row_height = y_range * 0.1 / 2.0

    # Row 0: node names; row 1: latest node value ('-' until first update).
    state_table = Table(cur_ax, bbox=(0, 0, 1, 0.1))
    cell_pos = dict()
    for idx, (_key, _val) in enumerate(node_map.items()):
        cell_pos[_key] = idx
        state_table.add_cell(0, idx, col_width, row_height, text=str(_val),
                             loc='center', edgecolor='none', facecolor='none')
        state_table.add_cell(1, idx, col_width, row_height, text='-',
                             loc='center', edgecolor='none', facecolor='none')
    cur_ax.add_table(state_table)
    draw()
    state_cells = state_table.get_celld()

    t_delta_last = 0.0
    evt_mask = zeros(len(node_patches), dtype=int)  # per-node firing count
    patch_list = node_patches.values()
    last_time = -1.0
    new_state_idx = 0
    # An empty state_list used to raise IndexError here; with no state
    # entries nothing can ever match.
    new_state_e = state_e_list[new_state_idx] if len_states > 0 else None

    for idx in range(num_tran):
        evt_id = evt_id_list[idx]
        evt_t = evt_t_list[idx]
        if last_time < evt_t:
            # Time advanced: redraw, fade every patch, then pace the animation.
            draw()
            for pat in patch_list:
                _alpha = pat.get_alpha()
                # get_alpha() returns None when no alpha was set; comparing
                # None with a float raises TypeError in Python 3.
                if _alpha is not None and _alpha > (
                        __disp_defs__.PATCH_MIN_ALPHA
                        + __disp_defs__.PATCH_ALPHA_DECAY):
                    pat.set_alpha(_alpha - __disp_defs__.PATCH_ALPHA_DECAY)
            t_delta = evt_t - t_delta_last
            t_delta_last = evt_t
            time.sleep(t_delta * time_mult)
            last_time = evt_t

        evt_s = evt_s_list[idx]
        current_iter = evt_mask[evt_s]
        current_patch = node_patches[evt_s]
        current_patch.set_alpha(1.0)
        current_node_text = node_text[evt_s]
        current_node_text.set_text(current_iter)
        # Artist.get_axes() was removed from matplotlib; `.axes` is the
        # supported attribute (None until the artist is added to an Axes).
        if current_patch.axes is None:
            cur_ax.add_patch(current_patch)
            cur_ax.add_artist(current_node_text)
        current_iter += 1
        evt_mask[evt_s] = current_iter

        if evt_id == new_state_e:
            node_idx = state_n_list[new_state_idx]
            node_val = state_val_list[new_state_idx]
            cell_idx = cell_pos[node_idx]
            state_cells[(1, cell_idx)].get_text().set_text(str(node_val))
            print("t=%g, node: %s, val: %s" % (evt_t, node_map[node_idx], node_val))
            new_state_idx += 1
            if new_state_idx < len_states:
                new_state_e = state_e_list[new_state_idx]
            else:
                break
def plot_gridpolicy(self, ax):
    """Draw the greedy policy of ``self.Q_opt`` on a ``self.w`` x ``self.w`` grid.

    For each cell the actions attaining the max Q value form a 0/1 pattern
    ordered (up, down, left, right). Recognised patterns are drawn as arrows
    (all four → a small cross, one action → a long arrow, two perpendicular
    actions → an L-shaped pair). Any other pattern leaves the cell empty.
    """
    ax.set_axis_off()
    table = Table(ax, bbox=[0, 0, 1, 1])
    n = self.w
    d = 0.03
    cell = 1.0 / n
    half = cell / 2

    def arrow_layouts(x, y):
        # Ordered (pattern, [(x0, y0, dx, dy), ...]); first match wins.
        return (
            ((1, 1, 1, 1), [(x, y, 0, d), (x, y, 0, -d),
                            (x, y, d, 0), (x, y, -d, 0)]),
            ((1, 0, 0, 0), [(x, y - half + d, 0, half)]),    # up
            ((0, 1, 0, 0), [(x, y + half - d, 0, -half)]),   # down
            ((0, 0, 1, 0), [(x + half - d, y, -half, 0)]),   # <-
            ((0, 0, 0, 1), [(x - half + d, y, half, 0)]),    # ->
            ((1, 0, 0, 1), [(x - half + 2 * d, y - half + 2 * d, 0, half - d),
                            (x - half + 2 * d, y - half + 2 * d, half - d, 0)]),
            ((1, 0, 1, 0), [(x + half - 2 * d, y - half + 2 * d, 0, half - d),
                            (x + half - 2 * d, y - half + 2 * d, -half + d, 0)]),
            ((0, 1, 1, 0), [(x + half - 2 * d, y + half - 2 * d, -half + d, 0),
                            (x + half - 2 * d, y + half - 2 * d, 0, -half + d)]),
            ((0, 1, 0, 1), [(x - half + 2 * d, y + half - 2 * d, half - d, 0),
                            (x - half + 2 * d, y + half - 2 * d, 0, -half + d)]),
        )

    for row in range(n):
        for col in range(n):
            table.add_cell(row, col, cell, cell, text='')
            q = self.Q_opt[(row, col)]
            best = q == q.max()
            # Cell centre in axes-fraction coordinates (row 0 at the top).
            x, y = col * cell + half, 1 - row * cell - half
            for pattern, arrows in arrow_layouts(x, y):
                if all(best == list(pattern)):
                    for x0, y0, dx, dy in arrows:
                        ax.arrow(x0, y0, dx, dy, head_width=d, fc='k')
                    break
    ax.add_table(table)
print(name, ' score: ', pilotRow) pilotRows.append(pilotRow) pilotDict[name] = pilotRow maxLength = 0 for i in pilotRows: if len(i) > maxLength: maxLength = len(i) fig = plt.figure(dpi=150) ax = fig.add_subplot(1, 1, 1) frame1 = plt.gca() frame1.axes.get_xaxis().set_ticks([]) frame1.axes.get_yaxis().set_ticks([]) tb = Table(ax, bbox=[0, 0, 1, 1]) tb.auto_set_font_size(False) n_cols = maxLength + 2 n_rows = len(pilots) + 1 width, height = 100 / n_cols, 100.0 / n_rows anchor = '⚓' #unicorn='✈️' blankcell = '#1A392A' colors = ['red', 'orange', 'orange', 'yellow', 'lightgreen'] minDate = data[-1]['ServerDate'] maxDate = data[0]['ServerDate'] textcolor = '#FFFFF0' edgecolor = '#708090'
def run(self):
    """Render a summary table (sweep count, extents, peak) on ``self.axes``.

    Shows empty strings when ``self.data`` is None. Clears the parent's plots
    first, then asks it to redraw and resets ``self.parent.threadPlot``.
    """
    self.parent.clear_plots()

    if self.data is None:
        length, tMin, tMax, fMin, fMax, lMin, lMax, peakF, peakL, peakT = (
            '', ) * 10
    else:
        length = len(self.data)
        tMin = format_time(self.extent.tMin, True)
        tMax = format_time(self.extent.tMax, True)
        fMin = format_precision(self.settings, freq=self.extent.fMin,
                                fancyUnits=True)
        fMax = format_precision(self.settings, freq=self.extent.fMax,
                                fancyUnits=True)
        lMin = format_precision(self.settings, level=self.extent.lMin,
                                fancyUnits=True)
        lMax = format_precision(self.settings, level=self.extent.lMax,
                                fancyUnits=True)
        peak = self.extent.get_peak_flt()
        peakF = format_precision(self.settings, freq=peak[0],
                                 fancyUnits=True)
        peakL = format_precision(self.settings, level=peak[1],
                                 fancyUnits=True)
        peakT = format_time(peak[2], True)

    text = [
        ['Sweeps', '', length],
        ['Extents', '', ''],
        ['', 'Start', tMin],
        ['', 'End', tMax],
        ['', 'Min frequency', fMin],
        ['', 'Max frequency', fMax],
        ['', 'Min level', lMin],
        ['', 'Max level', lMax],
        ['Peak', '', ''],
        ['', 'Level', peakL],
        ['', 'Frequency', peakF],
        ['', 'Time', peakT],
    ]

    table = Table(self.axes, loc='center')
    table.set_gid('table')

    rows = len(text)
    cols = len(text[0])
    # `range`, not the Python-2-only `xrange` (NameError on Python 3).
    for row in range(rows):
        for col in range(cols):
            table.add_cell(row, col,
                           text=text[row][col],
                           width=1.0 / cols, height=1.0 / rows)

    colour = 'LightGray' if self.settings.grid else 'w'
    set_table_colour(table, colour)

    for col in range(cols):
        table.auto_set_column_width(col)

    self.axes.add_table(table)
    self.parent.redraw_plot()

    self.parent.threadPlot = None
def plot(self, experiment, plot_name = None, **kwargs):
    """Plot a table of ``self.statistic``, faceted into rows and columns.

    Index levels with a single value are dropped, with a warning. Every
    remaining index level must be used exactly once as ``row_facet``,
    ``subrow_facet``, ``column_facet`` or ``subcolumn_facet``.

    :param experiment: experiment that holds the statistic and conditions.
    :param plot_name: accepted for interface compatibility; not used here.
    :param kwargs: forwarded to the :class:`matplotlib.table.Table`
        constructor.
    :raises util.CytoflowViewError: on any invalid configuration.
    """
    if experiment is None:
        raise util.CytoflowViewError('experiment', "No experiment specified")

    if self.statistic not in experiment.statistics:
        raise util.CytoflowViewError('statistic',
                                     "Can't find the statistic {} in the experiment"
                                     .format(self.statistic))
    stat = experiment.statistics[self.statistic]

    data = pd.DataFrame(index = stat.index)
    data[stat.name] = stat

    if self.subset:
        try:
            data = data.query(self.subset)
        except Exception as e:
            raise util.CytoflowViewError('subset',
                                         "Subset string '{0}' isn't valid"
                                         .format(self.subset)) from e

        if len(data) == 0:
            raise util.CytoflowViewError('subset',
                                         "Subset string '{0}' returned no values"
                                         .format(self.subset))

    # Drop index levels that carry a single value; they cannot be faceted.
    for name in list(data.index.names):
        unique_values = data.index.get_level_values(name).unique()
        if len(unique_values) == 1:
            warn("Only one value for level {}; dropping it.".format(name),
                 util.CytoflowViewWarning)
            try:
                data.index = data.index.droplevel(name)
            except AttributeError as e:
                raise util.CytoflowViewError(None,
                                             "Must have more than one "
                                             "value to plot.") from e

    if not (self.row_facet or self.column_facet):
        raise util.CytoflowViewError('row_facet',
                                     "Must set at least one of row_facet "
                                     "or column_facet")

    if self.subrow_facet and not self.row_facet:
        raise util.CytoflowViewError('subrow_facet',
                                     "Must set row_facet before using "
                                     "subrow_facet")

    if self.subcolumn_facet and not self.column_facet:
        raise util.CytoflowViewError('subcolumn_facet',
                                     "Must set column_facet before using "
                                     "subcolumn_facet")

    # Each set facet must be an experiment condition and a statistic index
    # level (checked in the same order, with the same messages, as before).
    for trait_name, label in (('row_facet', "Row facet"),
                              ('subrow_facet', "Subrow facet"),
                              ('column_facet', "Column facet"),
                              ('subcolumn_facet', "Subcolumn facet")):
        facet = getattr(self, trait_name)
        if not facet:
            continue
        if facet not in experiment.conditions:
            raise util.CytoflowViewError(trait_name,
                                         "{} {} not in the experiment, "
                                         "must be one of {}"
                                         .format(label, facet, experiment.conditions))
        if facet not in data.index.names:
            raise util.CytoflowViewError(trait_name,
                                         "{} {} not a statistic index; "
                                         "must be one of {}"
                                         .format(label, facet, data.index.names))

    facets = [x for x in [self.row_facet, self.subrow_facet,
                          self.column_facet, self.subcolumn_facet] if x]
    if len(facets) != len(set(facets)):
        raise util.CytoflowViewError(None, "Can't reuse facets")

    if set(facets) != set(data.index.names):
        raise util.CytoflowViewError(None,
                                     "Must use all the statistic indices as variables or facets: {}"
                                     .format(data.index.names))

    def groups_for(facet):
        # Unique level values, or a single placeholder group when unset.
        return data.index.get_level_values(facet).unique() if facet else [None]

    def facet_text(facet, value):
        # Compact numeric formatting when possible, plain str() otherwise.
        try:
            return "{0} = {1:g}".format(facet, value)
        except ValueError:
            return "{0} = {1}".format(facet, value)

    row_groups = groups_for(self.row_facet)
    subrow_groups = groups_for(self.subrow_facet)
    col_groups = groups_for(self.column_facet)
    subcol_groups = groups_for(self.subcolumn_facet)

    # Header rows/columns occupied by the (sub)column / (sub)row labels.
    row_offset = (self.column_facet != "") + (self.subcolumn_facet != "")
    col_offset = (self.row_facet != "") + (self.subrow_facet != "")

    num_cols = len(col_groups) * len(subcol_groups) + col_offset

    fig = plt.figure()
    ax = fig.add_subplot(111)

    # hide the plot axes that matplotlib tries to make
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    for sp in ax.spines.values():
        sp.set_color('w')
        sp.set_zorder(0)

    loc = 'upper left'
    bbox = None

    t = Table(ax, loc, bbox, **kwargs)
    t.auto_set_font_size(False)
    for c in range(num_cols):
        t.auto_set_column_width(c)

    width = [0.2] * num_cols
    height = t._approx_text_height() * 1.8

    # make the main table
    for ri, r in enumerate(row_groups):
        for rri, rr in enumerate(subrow_groups):
            for ci, c in enumerate(col_groups):
                for cci, cc in enumerate(subcol_groups):
                    row_idx = ri * len(subrow_groups) + rri + row_offset
                    col_idx = ci * len(subcol_groups) + cci + col_offset

                    # Build the index key in the statistic's level order.
                    value_for = {self.row_facet: r,
                                 self.subrow_facet: rr,
                                 self.column_facet: c,
                                 self.subcolumn_facet: cc}
                    agg_idx = tuple(value_for[level]
                                    for level in data.index.names)
                    if len(agg_idx) == 1:
                        agg_idx = agg_idx[0]

                    value = data.loc[agg_idx][stat.name]
                    try:
                        text = "{:g}".format(value)
                    except ValueError:
                        text = value

                    t.add_cell(row_idx, col_idx,
                               width = width[col_idx],
                               height = height,
                               text = text)

    # row headers
    if self.row_facet:
        for ri, r in enumerate(row_groups):
            row_idx = ri * len(subrow_groups) + row_offset
            t.add_cell(row_idx, 0, width = width[0], height = height,
                       text = facet_text(self.row_facet, r))

    # subrow headers
    if self.subrow_facet:
        for ri in range(len(row_groups)):
            for rri, rr in enumerate(subrow_groups):
                row_idx = ri * len(subrow_groups) + rri + row_offset
                t.add_cell(row_idx, 1, width = width[1], height = height,
                           text = facet_text(self.subrow_facet, rr))

    # column headers
    if self.column_facet:
        for ci, c in enumerate(col_groups):
            col_idx = ci * len(subcol_groups) + col_offset
            t.add_cell(0, col_idx, width = width[col_idx], height = height,
                       text = facet_text(self.column_facet, c))

    # subcolumn headers
    if self.subcolumn_facet:
        for ci in range(len(col_groups)):
            for cci, cc in enumerate(subcol_groups):
                col_idx = ci * len(subcol_groups) + cci + col_offset
                t.add_cell(1, col_idx, width = width[col_idx], height = height,
                           text = facet_text(self.subcolumn_facet, cc))

    ax.add_table(t)
def plot_queens(queens_configuration: list, title: str, file_path_and_name: str = None) -> None: """ Plots a chess board of colors yellow and white, with the specified queen pieces marked with a 'Q' Can save the plot figure if a path is specified, otherwise it's plotted on a window :param queens_configuration: A list of length 2*<the amount of queens>, with their coordinates :param title: Title of the plot :param file_path_and_name: Path for saving the file """ queens_2 = list(queens_configuration) board_size = int(len(queens_configuration) / 2) _, ax = plt.subplots() ax.set_axis_off() table = Table(ax, bbox=[0, 0, 1, 1]) width = 1.0 / board_size height = 1.0 / board_size bkg_colors = ['yellow', 'white'] for i in range(board_size): for j in range(board_size): is_queen = False for k in range(int(len(queens_2) / 2)): x = queens_2[2 * k] y = queens_2[2 * k + 1] if i == x and j == y: is_queen = True queens_2.pop(2 * k) queens_2.pop(2 * k) break idx = [j % 2, (j + 1) % 2][i % 2] color = bkg_colors[idx] if is_queen: table.add_cell(i, j, width, height, text="Q", loc='center', facecolor=color) else: table.add_cell(i, j, width, height, loc='center', facecolor=color) for i in range(board_size): # Row Labels... table.add_cell(i, -1, width, height, text=i, loc='right', edgecolor='none', facecolor='none') # Column Labels... table.add_cell(-1, i, width, height / 2, text=i, loc='center', edgecolor='none', facecolor='none') ax.add_table(table) plt.title(title, y=1.08) # Save or plot if file_path_and_name is not None: plt.savefig(file_path_and_name, bbox_inches='tight') else: plt.show() plt.close() return
# Set grid to use minor tick locations. plt.grid(which='minor') savefig('figname.png', facecolor=fig.get_facecolor(), transparent=True) import matplotlib.pyplot as plt import numpy as np import pandas from matplotlib.table import Table data = pandas.DataFrame( grid, columns=['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']) data plt.show tb = Table(ax) def main(): data = pandas.DataFrame( grid, columns=['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']) checkerboard_table(data) plt.show() def checkerboard_table(data, fmt='{:.1f}', bkg_colors=['white', 'white']): fig, ax = plt.subplots() ax.set_axis_off() tb = Table(ax, bbox=[0, 0, 1, 1]) nrows, ncols = data.shape
def shmoo_plotting(self):
    """Write the pixel- and global-register shmoo tables to ``shmoo.pdf``.

    Two pages are produced. Each is a grid with one row per entry of
    ``self.voltages`` and one column per entry of ``self.bitfiles``:

    1. Pixel register errors (``self.shmoo_errors``): green for 0 errors,
       yellow for 1-10, red for more than 10.
    2. Global register errors (``self.shmoo_global_errors``): green for 0,
       red for any errors.

    Each flat error list is reshaped column-major (``order='F'``) into
    ``(len(self.voltages), -1)``, so consecutive entries run down the
    voltage axis first.
    """
    nrows = len(self.voltages)
    ncols = len(self.bitfiles)
    width, height = 1.0 / ncols, 1.0 / nrows
    # Column label: characters [-9:-7] of each bitfile name, underscores removed.
    col_labels = [str(bitfile[-9:-7]).replace("_", "") for bitfile in self.bitfiles]

    def pixel_color(val):
        # Explicit comparisons: the former ``val > 0 & val < 10`` parsed as
        # ``val > (0 & val) < 10`` and raised TypeError for float counts.
        if val == 0:
            return 'green'
        if val > 10:
            return 'red'
        if val > 0:
            return 'yellow'
        return ''

    def global_color(val):
        if val == 0:
            return 'green'
        if val > 0:
            return 'red'
        return ''

    def draw_page(errors, title, color_for):
        # Build one shmoo figure (cells, voltage row labels, clock column labels).
        data = np.array(errors).reshape(nrows, -1, order='F')
        fig, ax = plt.subplots()
        plt.title(title)
        ax.set_axis_off()
        fig.text(0.70, 0.05, 'SPI clock (MHz)', fontsize=14)
        fig.text(0.02, 0.90, 'Supply voltage (V)', fontsize=14, rotation=90)
        tb = Table(ax, bbox=[0.01, 0.01, 0.99, 0.99])
        # Add cells
        for (i, j), val in np.ndenumerate(data):
            tb.add_cell(i, j, width, height, text=str(val), loc='center',
                        facecolor=color_for(val))
        # Row Labels...
        for i, voltage in enumerate(self.voltages):
            tb.add_cell(i, -1, width, height, text=str(voltage), loc='right',
                        edgecolor='none', facecolor='none')
        # Column Labels... (placed below the grid, skipping row nrows)
        for j, label in enumerate(col_labels):
            tb.add_cell(nrows + 1, j, width, height / 2, text=label, loc='center',
                        edgecolor='none', facecolor='none')
        ax.add_table(tb)
        return fig

    # The context manager closes the PDF even if a page fails to build.
    with PdfPages('shmoo.pdf') as shmoopdf:
        # pixel register shmoo plot
        shmoopdf.savefig(draw_page(self.shmoo_errors, 'Pixel registers errors', pixel_color))
        # global register shmoo plot
        shmoopdf.savefig(draw_page(self.shmoo_global_errors, 'Global registers errors', global_color))
# Iterative policy evaluation: for every sweep count k in ``k_values`` (defined
# elsewhere), start again from all-zero state values, run k sweeps of
# ``update_value_out`` and draw the resulting values as a table in its own subplot.
fig, axs = plt.subplots(len(k_values), figsize=(4, 16))
for i, k in enumerate(k_values):
    values = {state: 0 for state in states}
    # Presumably the terminal corners of a 4x4 grid (cells below are 1/4 wide);
    # these also add the keys if ``states`` omits them - confirm.
    values[(0, 0)] = 0
    values[(3, 3)] = 0
    for _ in range(k):
        # Two-array version of the iterative policy evaluation thanks to ShangtongZhang's repo
        # https://github.com/ShangtongZhang/reinforcement-learning-an-introduction
        old_values = deepcopy(values)
        for state in states:
            values[state] = update_value_out(old_values, state, discount=1)
    tb = Table(axs[i], bbox=[0, 0, 1, 1])
    for (x, y), val in values.items():
        tb.add_cell(x, y, 1 / 4, 1 / 4, text=round(val, 1), loc='center', facecolor='white')
    axs[i].title.set_text(f'k = {k}')
    axs[i].add_table(tb)
    axs[i].set_axis_off()
plt.show()
def sort_fellows_table_nproj_pdf(names, jobs, companies, subjects, unis, degs, name_plot):
    """Render the fellows' details as a word-wrapped table and save it as a PDF.

    The six column arguments become, in order, the columns Name, Job, Company,
    Subject, University and Degree of a DataFrame. They must therefore have
    equal lengths, and their entries are passed to ``textwrap.wrap``
    (40-character lines), so they are expected to be strings.

    Side effect: sets the global ``matplotlib.rcParams['font.size']`` to 20.

    :param name_plot: output path without extension; ``.pdf`` is appended
    :return: an empty string
    """
    table = OrderedDict((
        ('Name', names), ('Job', jobs), ('Company', companies),
        ('Subject', subjects), ('University', unis), ('Degree', degs)))
    data = pandas.DataFrame(table)

    fig, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])
    nrows, ncols = data.shape
    width, height = 1.0 / ncols, 1.0 / nrows
    matplotlib.rcParams.update({'font.size': 20})

    # NOTE(review): the last DataFrame row is never drawn (range(nrows - 1));
    # behaviour kept as-is - confirm whether that is intentional.
    shown_rows = range(0, nrows - 1)

    # Add cells, column by column, wrapping long text.
    for col, header in enumerate(data.columns):
        for row in shown_rows:
            txt = "\n".join(wrap(data[header].iloc[row], 40))
            tb.add_cell(row, col, width, height, text=txt, loc='center', facecolor="white")
    # Row Labels...
    for row in shown_rows:
        tb.add_cell(row, -1, width, height, text=str(row), loc='right',
                    edgecolor='black', facecolor='lightblue')
    # Column Labels...
    for col, label in enumerate(data.columns):
        tb.add_cell(-1, col, width, height / 2.0, text=label, loc='center',
                    edgecolor='black', facecolor='lightcoral')
    ax.add_table(tb)
    plt.savefig(name_plot + '.pdf', bbox_inches="tight")
    # The figure is only needed for the saved file; release it.
    plt.close(fig)
    return ''
def plot_models_ranges(allFRdata, legend, models=np.arange(0, 15, 1), filename=None):
    """Draw a colour-coded table of each model's test results, then save or show it.

    Each item of ``allFRdata[model]`` is a comma-separated string whose first
    and last fields are ignored. Every other field has the form
    ``<label>=<result>``. A result of ``"OK"`` gives a green cell and anything
    else a red one; a final white cell holds the number of ``"OK"`` results.
    Every string becomes its own table row, labelled with its model.
    Column labels come from the last processed string. Labels that are neither
    in the global ``NUCLEI`` nor ``"Model"`` are broken onto several lines:
    at ``_``, or before the ``/`` of the ``Score/<n>`` column.

    :param allFRdata: container indexed by the entries of *models*
    :param legend: forwarded to ``plot_param_legend`` with the subplot grid size
    :param models: model identifiers to tabulate; ``len(models)`` sets the row height
    :param filename: if given, the figure is saved there; otherwise it is shown
    """
    fig = plt.figure(1)
    plot_size = (8, 10)
    ax = plt.subplot2grid(plot_size, (0, 0), colspan=5, rowspan=8)
    plt.title("Models'results for each range")
    ax.set_axis_off()
    tb = Table(ax, bbox=[0.1, 0, 1.7, 1.])
    width = 0.25
    height = 1.0 / len(models)

    modLbl = []
    labels = ["Model"]  # defined even when no result lines exist
    for model in models:
        for frline in allFRdata[model]:
            # Running row index: rows of a model with several result lines
            # no longer collide with the next model's rows.
            row = len(modLbl)
            modLbl.append(model)
            labels = ["Model"]
            score = 0
            score_max = 0
            col = 0
            # for each Nucleus, getting the results
            for nres in frline.split(',')[1:-1]:
                score_max += 1
                nres = nres.strip().split('=')
                labels.append(nres[0])
                if nres[1] == "OK":
                    score += 1
                    color = '#BFE8B7'  # green
                else:
                    color = '#E8B7B7'  # red
                tb.add_cell(row, col, width, height, text=nres[1], loc='center', facecolor=color)
                col += 1
            tb.add_cell(row, col, width, height, text=str(score), loc='center', facecolor='white')
            labels.append("Score/" + str(score_max))

    # Row Labels...
    for i, label in enumerate(modLbl):
        tb.add_cell(i, -1, width, height, text=label, loc='right', edgecolor='none', facecolor='none')

    # Column Labels... (j - 1 puts "Model" over the row-label column)
    for j, label in enumerate(labels):
        if not (label in NUCLEI or label == "Model"):
            parts = label.split("_")
            if len(parts) == 1:  # pour le score: "Score/N" -> "Score\n/N"
                pieces = parts[0].split("/")
                label = pieces[0] + "\n" + "/" + pieces[1]
            else:
                # Previously the list itself was passed and rendered as "['a', 'b']".
                label = "\n".join(parts)
        tb.add_cell(len(modLbl), j - 1, 0.2, height * 2, text=label, loc='center',
                    edgecolor='none', facecolor='none')

    ax.add_table(tb)
    plot_param_legend(legend, plot_size, (1, 8))
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6.
    fig.canvas.manager.set_window_title("Passing LG14's tests")
    fig.tight_layout()
    fig.set_size_inches(w=11, h=7)
    if filename is not None:
        plt.savefig(filename)
    else:
        plt.show()
def sort_fellows_table_pdf(names, jobs, companies, projects, subjects, unis, degs, name_plot):
    """Render the fellows' details, including projects, as a wrapped-text table PDF.

    The seven column arguments become, in order, the columns Name, Job,
    Company, Project, Subject, University and Degree of a DataFrame. They must
    therefore have equal lengths, and their entries are passed to
    ``textwrap.wrap`` (40-character lines), so they are expected to be strings.
    Body and row-label cells are twice as tall as the header cells.

    Side effect: sets the global ``matplotlib.rcParams['font.size']`` to 50.

    :param name_plot: output path without extension; ``.pdf`` is appended
    :return: an empty string
    """
    table = OrderedDict((
        ('Name', names), ('Job', jobs), ('Company', companies), ('Project', projects),
        ('Subject', subjects), ('University', unis), ('Degree', degs)))
    data = pandas.DataFrame(table)

    fig, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])
    nrows, ncols = data.shape
    width, height = 1.0 / (ncols * 1.0), 1.0 / (nrows * 2.0)
    matplotlib.rcParams.update({'font.size': 50})

    # NOTE(review): the last DataFrame row is never drawn (range(nrows - 1));
    # behaviour kept as-is - confirm whether that is intentional.
    shown_rows = range(0, nrows - 1)

    # Add cells, column by column, wrapping long text.
    for col, header in enumerate(data.columns):
        for row in shown_rows:
            txt = "\n".join(wrap(data[header].iloc[row], 40))
            tb.add_cell(row, col, width, height * 2.0, text=txt, loc='center', facecolor="white")
    # Row Labels...
    for row in shown_rows:
        tb.add_cell(row, -1, width, height * 2.0, text=str(row), loc='right',
                    edgecolor='black', facecolor='lightblue')
    # Column Labels...
    for col, label in enumerate(data.columns):
        tb.add_cell(-1, col, width, height, text=label, loc='center',
                    edgecolor='black', facecolor='lightcoral')
    ax.add_table(tb)
    plt.savefig(name_plot + '.pdf', bbox_inches="tight")
    # The figure is only needed for the saved file; release it.
    plt.close(fig)
    return ''
def draw_policy_image(iteration, policy_image, env):
    """Show the policy of one policy-improvement iteration as a symbol grid.

    :param iteration: iteration number shown in the title (formatted with ``{:d}``)
    :param policy_image: array indexed as ``[row, col, action]``; every action
        with a non-zero entry is drawn in that cell via ``ACTION_SYMBOLS``
    :param env: provides ``is_terminal([i, j])`` (cell left blank) and
        ``is_on_obstacle([i, j])`` (cell drawn as '╳')

    Rows and columns are labelled from 1; column labels sit below the grid.
    """
    _, ax = plt.subplots()
    plt.suptitle('Policy Improvement: Iteration:{:d}'.format(iteration))
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])
    nrows, ncols, _ = policy_image.shape
    width, height = 1.0 / ncols, 1.0 / nrows

    # Add cells. Table.add_cell forwards positional args to Cell(xy, width, height),
    # so width must come before height (they were swapped for non-square grids).
    for i in range(nrows):
        for j in range(ncols):
            if env.is_terminal([i, j]):
                text = ' '
            elif env.is_on_obstacle([i, j]):
                text = '╳'
            else:
                actions = (np.where(policy_image[i, j, :] != 0)[0]).tolist()
                text = ''.join(ACTION_SYMBOLS[x] for x in actions)
            tb.add_cell(i, j, width, height, text=text, loc='center', facecolor='white')

    # Row and column labels...
    for i in range(nrows):
        tb.add_cell(i, -1, width, height, text=i + 1, loc='right', edgecolor='none', facecolor='none')
    for j in range(ncols):
        tb.add_cell(nrows, j, width, height / 2, text=j + 1, loc='center', edgecolor='none', facecolor='none')
    ax.add_table(tb)
    plt.show()
def label_manual(self, display_elems, figsize=(15, 7.5), title="Sleep Stages"):
    """
    Displays an interactive dialog for manual sleep-stage labelling.

    One axes per entry of ``display_elems`` is stacked above a step plot of
    the current stage labels. Clicking inside one of the data axes edits the
    stage boundaries at the clicked time, using the stage selected in the
    radio buttons. A small table shows the time spent in each stage and the
    total. Buttons: "Reset" (a single 'MASK OFF' stage from time 0), "Reload"
    (``self.loaded_stage_*``; shown only when ``self.loaded_stage_labels`` is
    not None) and "Save & Quit".

    Stage labels are indices into
    ``['NREM3', 'NREM2', 'REM', 'NREM1', 'WAKE', 'MASK OFF', '???']``.

    args:
        display_elems: sequence of ``(data_struct, params)`` pairs. ``data_struct``
            must provide ``plot(**params)`` and ``index_to_time(x)`` (such as
            EEGSpectralData); ``params`` is a dict of keyword arguments for
            ``plot``. Its ``'axes'`` entry is overwritten in place.
        figsize: size of the figure
        title: title of the top-most data axes

    Side effects:
        Sets ``self.saving`` (True only if "Save & Quit" was pressed) and
        ``self.stage_label``. After ``plt.show()`` returns,
        ``self.stage_times`` is terminated by ``self.sleep_length`` and
        ``self.stage_labels`` by stage 6.
    """
    self.saving = False
    sleep_stage_labels = [
        'NREM3', 'NREM2', 'REM', 'NREM1', 'WAKE', 'MASK OFF', '???'
    ]
    height_ratios = np.ones(len(display_elems)) * 3
    height_ratios = np.append(height_ratios, 1)
    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(len(display_elems) + 1, 1, height_ratios=height_ratios)
    ax_transforms = {}  # data axes -> data_struct used to convert click x to time

    def format_time_period(val):
        m, s = divmod(int(round(val)), 60)
        h, m = divmod(m, 60)
        return "%d:%02d:%02d" % (h, m, s)

    def reset_labels(event=None):
        self.stage_times = [0]
        self.stage_labels = [5]
        if event is not None:
            redraw_labels(event)

    def reload_labels(event=None):
        self.stage_times = np.append([], self.loaded_stage_times)
        self.stage_labels = np.append([], self.loaded_stage_labels)
        if self.stage_times[0] is None or self.stage_labels[0] is None \
                or self.stage_times.size != self.stage_labels.size:
            print("data is bad, resetting labels")
            reset_labels(event)
        else:
            print("stage_times size is " + str(self.stage_times.size))
            print("stage_labels size is " + str(self.stage_labels.size))
            if event is not None:
                redraw_labels(event)

    def redraw_labels(event=None):
        # Update the step plot and the per-stage time table.
        line1.set_xdata(np.concatenate((self.stage_times, [self.sleep_length])))
        line1.set_ydata(np.concatenate((self.stage_labels, [self.stage_labels[-1]])))
        # Table rows list the stages in reverse order, plus a final total row.
        data = [0, 0, 0, 0, 0, 0, 0, 0]
        last = len(self.stage_times) - 1
        for x in range(0, len(self.stage_times)):
            data_offset = int(len(sleep_stage_labels) - (self.stage_labels[x] + 1))
            thisval = self.stage_times[x]
            # ``==``, not ``is``: identity of ints only holds for small cached values.
            nextval = self.sleep_length if x == last else self.stage_times[x + 1]
            diffval = nextval - thisval
            data[data_offset] = data[data_offset] + diffval
        data[len(data) - 1] = self.sleep_length
        for x in range(0, len(data)):
            table.get_celld()[x, 1].get_text().set_text(format_time_period(data[x]))
        fig.canvas.draw()

    def on_pick(figure, mouseevent):
        # Picker for the figure: returns (hit, props) as matplotlib expects.
        if mouseevent.inaxes not in ax_transforms:
            return False, {}
        eegdata = ax_transforms[mouseevent.inaxes]
        xmouse = eegdata.index_to_time(mouseevent.xdata)
        larger = [idx for idx, t in enumerate(self.stage_times) if t > xmouse]
        if len(larger) > 0:
            idx = larger[0]
            if self.stage_labels[idx - 1] != self.stage_label:
                self.stage_times[idx] = xmouse
                self.stage_labels[idx] = self.stage_label
            else:
                self.stage_times = np.delete(self.stage_times, idx)
                self.stage_labels = np.delete(self.stage_labels, idx)
        else:
            if self.stage_labels[-1] != self.stage_label:
                self.stage_times = np.append(self.stage_times, xmouse)
                self.stage_labels = np.append(self.stage_labels, self.stage_label)
        # Merge consecutive identical labels. The former fixed-range for-loop
        # skipped the element shifted into place by a deletion and could
        # index past the shrunken array (IndexError).
        i = 1
        while i < len(self.stage_labels):
            if self.stage_labels[i] == self.stage_labels[i - 1]:
                self.stage_labels = np.delete(self.stage_labels, i)
                self.stage_times = np.delete(self.stage_times, i)
            else:
                i += 1
        redraw_labels()
        return True, {}

    for did, delem in enumerate(display_elems):
        ax = plt.subplot(gs[did])
        ax_transforms[ax] = delem[0]
        params = delem[1]
        params['axes'] = ax
        delem[0].plot(**params)
        if did == 0:
            plt.title(title)
        #TODO
        ax.set_picker(True)
        # NOTE(review): pick_event callbacks receive a single event argument,
        # but on_pick takes (figure, mouseevent); confirm these registrations
        # (here and below) are wanted - the labelling itself runs via
        # fig.set_picker.
        fig.canvas.mpl_connect('pick_event', on_pick)
    fig.set_picker(on_pick)

    # Widen the x tick spacing until at most 20 ticks remain (or 3600 s is reached).
    xtickspacing = 300
    for wider in (600, 1200, 1800, 3600):
        if len(np.arange(0, self.sleep_length, xtickspacing)) <= 20:
            break
        xtickspacing = wider
    xticks = np.arange(0, self.sleep_length, xtickspacing)
    xticklabels = [str(int(i / 60)) for i in xticks]

    reload_labels()
    ax1 = plt.subplot(gs[-1])
    line1, = ax1.plot(np.concatenate((self.stage_times, [self.sleep_length])),
                      np.concatenate((self.stage_labels, [self.stage_labels[-1]])),
                      drawstyle="steps-post")
    ax1.set_xlabel("Time (min)")
    ax1.set_ylabel("Sleep Stage")
    ax1.set_xlim(0, self.sleep_length)
    ax1.set_yticks(np.arange(7))
    ax1.set_yticklabels(sleep_stage_labels)
    ax1.set_xticks(xticks)
    ax1.set_xticklabels(xticklabels)

    self.stage_label = 6
    rax = plt.axes([0.0, 0.0, 0.2, 0.16], facecolor='lightgoldenrodyellow')
    radio = RadioButtons(rax, sleep_stage_labels[::-1], active=0)

    def stagepicker(label):
        self.stage_label = sleep_stage_labels.index(label)

    def done(event):
        self.saving = True
        plt.close()

    if self.loaded_stage_labels is not None:
        axreload = plt.axes([0.7, 0.0, 0.1, 0.075])
        breload = Button(axreload, 'Reload')
        breload.on_clicked(reload_labels)
    axreset = plt.axes([0.8, 0.0, 0.1, 0.075])
    breset = Button(axreset, 'Reset')
    breset.on_clicked(reset_labels)
    axdone = plt.axes([0.9, 0.0, 0.1, 0.075])
    bdone = Button(axdone, 'Save &\n Quit')
    bdone.on_clicked(done)
    radio.on_clicked(stagepicker)

    # Per-stage duration table; column 1 is filled in by redraw_labels.
    tableax = plt.axes([0.2, 0.0, 0.25, 0.16], facecolor='lightblue')
    tableax.get_yaxis().set_visible(False)
    table = Table(tableax, bbox=[0, 0, 1, 1])
    height = table._approx_text_height()
    lidx = 0
    for label in sleep_stage_labels[::-1]:
        table.add_cell(lidx, 0, width=0.6, height=height, text=label)
        table.add_cell(lidx, 1, width=0.4, height=height, text='')
        lidx = lidx + 1
    table.add_cell(lidx, 0, width=0.6, height=height, text='Total Sleep Time')
    table.add_cell(lidx, 1, width=0.4, height=height, text='')
    tableax.add_table(table)

    fig.canvas.callbacks.connect('pick_event', on_pick)
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6.
    fig.canvas.manager.set_window_title('EEG Spectrogram Analysis')
    plt.subplots_adjust(left=0.15 if figsize[0] < 10 else 0.075, bottom=0.2, right=0.99, top=0.97)
    redraw_labels()
    plt.show()

    self.stage_times = np.array(self.stage_times)
    self.stage_times = np.concatenate((self.stage_times, [self.sleep_length]))
    self.stage_labels = np.concatenate((self.stage_labels, [6]))
def visualise(self):
    """
    Save a three-panel picture of the analysed Gridworld, then close it.

    Left ("Optimum Value"): every state's entry of ``self.v``.
    Middle ("Reward Map"): for each ``((row, col), action) -> (reward, new_coord)``
    entry of ``self.reward_dict``, an arrow to ``new_coord`` (or a curved
    self-loop when the state does not change) labelled with the reward
    rounded to 2 decimals.
    Right ("Optimum Path"): an arrow for every action that ``self.pi`` marks
    as part of the final policy.
    The PNG is written next to this source file as
    ``gridworld_<self.name>_<self.solve_method>_results.png``.
    """
    out_dir = os.path.dirname(os.path.abspath(__file__))

    fig, axes = plt.subplots(1, 3, figsize=(15, 5), constrained_layout=True)
    value_ax, reward_ax, path_ax = axes
    method_name = ' '.join(part.title() for part in self.solve_method.split('_'))
    fig.suptitle('Gridworld %s Results: %i Iterations' % (method_name, self.current_iter),
                 fontsize=30)

    for ax in axes:
        ax.set_axis_off()
    tables = [Table(ax, bbox=[0, 0, 1, 1]) for ax in axes]
    value_tb, reward_tb, path_tb = tables

    n_rows, n_cols = self.size[0], self.size[1]
    height = 1.0 / n_rows
    width = 1.0 / n_cols

    def label_reward(x, y, reward):
        # Red bold reward annotation centred on (x, y) in the reward panel.
        reward_ax.text(x, y, str(round(reward, 2)),
                       horizontalalignment='center', verticalalignment='center',
                       fontsize=9, color='red', fontweight='bold')

    # Grid cells for every state, plus the policy arrows in the right panel.
    for (row, col), value in np.ndenumerate(self.v):
        value_tb.add_cell(row, col, width, height, text=value, loc='center', facecolor='none')
        reward_tb.add_cell(row, col, width, height, loc='center', facecolor='none')
        path_tb.add_cell(row, col, width, height, loc='center', facecolor='none')
        for (a, ), in_policy in np.ndenumerate(self.pi[row, col]):
            if not in_policy:
                continue
            step = self.actions_list[a]
            path_ax.arrow(width * (col + 0.5), 1 - height * (row + 0.5),
                          0.25 * step[1] * width, -0.25 * step[0] * height,
                          head_width=0.1 * width, head_length=0.1 * height,
                          color='black')

    # Reward map: one arrow (or self-loop arc) per transition.
    for ((row, col), a), (reward, new_coord) in self.reward_dict.items():
        step = self.actions_list[a]
        if new_coord == (row, col):
            # Arc centre is nudged 0.2 cells in the action's direction; the arc
            # itself runs perpendicular to it.
            cx = width * (col + 0.5 + step[1] * 0.2)
            cy = 1 - height * (row + 0.5 + step[0] * 0.2)
            start = (cx - width * step[0] * 0.175, cy - height * step[1] * 0.175)
            end = (cx + width * step[0] * 0.175, cy + height * step[1] * 0.175)
            reward_ax.add_patch(
                mpatches.FancyArrowPatch(
                    start, end,
                    edgecolor='black', facecolor='black',
                    arrowstyle=mpatches.ArrowStyle.Fancy(head_length=4, head_width=4),
                    connectionstyle="arc3,rad=0.9"))
            label_reward(cx, cy, reward)
        else:
            x0 = width * (col + 0.5)
            y0 = 1 - height * (row + 0.5)
            dx = width * (new_coord[1] - col)
            dy = height * (row - new_coord[0])
            reward_ax.arrow(x0, y0, dx, dy,
                            head_width=0.1 * width, head_length=0.1 * height,
                            color='black')
            # Reward text sits beside the arrow's midpoint.
            label_reward(
                x0 + dx / 2. + width * (step[0] * 0.175 + step[1] * 0.175),
                y0 + dy / 2. + height * (-step[0] * 0.175 + step[1] * 0.175),
                reward)

    # Row and column indexes for every panel.
    for tb in tables:
        for row in range(n_rows):
            tb.add_cell(row, -1, width / 2, height, text=row, loc='right',
                        edgecolor='none', facecolor='none')
        for col in range(n_cols):
            tb.add_cell(-1, col, width, height / 4, text=col, loc='center',
                        edgecolor='none', facecolor='none')

    panel_titles = ('Optimum Value', 'Reward Map', 'Optimum Path')
    for ax, tb, panel_title in zip(axes, tables, panel_titles):
        ax.add_table(tb)
        ax.set_title(panel_title, y=1.05, size=15)

    plt.savefig(os.path.join(
        out_dir, 'gridworld_%s_%s_results.png' % (self.name, self.solve_method)))
    plt.close()
def mytable(ax, csize=0.04, cellText=None, cellColours=None, cellLoc='right',
            colWidths=None, rowLabels=None, rowColours=None, rowLoc='left',
            colLabels=None, colColours=None, colLoc='center', loc='bottom',
            bbox=None):
    """Add a grid of square body cells to *ax* and return the new Table.

    This appears to be a cut-down variant of ``matplotlib.table.table``. Every
    body cell is ``csize`` wide and tall, and only body cells are created; no
    row- or column-label cells are drawn. When ``colLabels`` or ``colColours``
    is given, the body starts at table row 1, leaving row 0 free for a header.
    ``colWidths``, ``rowColours``, ``rowLoc`` and ``colLoc`` are accepted for
    signature compatibility but unused.

    :param ax: axes the table is added to
    :param csize: width and height of every body cell
    :param cellText: rows x cols nested sequence of cell texts; if None, the
        shape comes from ``cellColours`` and the cells are empty
    :param cellColours: rows x cols colours; defaults to 'w' everywhere
    :param cellLoc: text alignment inside the body cells
    :param rowLabels: optional sequence that must have one entry per row
    :param loc: forwarded to ``Table`` together with *bbox*
    :raises ValueError: if neither cellText nor cellColours is given, or
        cellText/cellColours/rowLabels disagree on the shape
    :return: the Table that was added to *ax*
    """
    if cellText is None:
        if cellColours is None:
            raise ValueError('mytable needs cellText or cellColours')
        # Colours only: derive the shape from them and leave the cells empty.
        # (Previously built as cols x rows, which failed on non-square input.)
        rows = len(cellColours)
        cols = len(cellColours[0])
        cellText = [[''] * cols for _ in range(rows)]

    rows = len(cellText)
    cols = len(cellText[0])
    if any(len(row) != cols for row in cellText):
        raise ValueError('every cellText row must have %d entries' % cols)

    if cellColours is None:
        cellColours = ['w' * cols] * rows
    elif len(cellColours) != rows or any(len(row) != cols for row in cellColours):
        raise ValueError('cellColours must be %d x %d' % (rows, cols))

    if rowLabels is not None and len(rowLabels) != rows:
        raise ValueError('rowLabels must have %d entries' % rows)

    # Column labels or colours reserve row 0 (matching matplotlib.table.table);
    # the old code forgot the shift when both were given.
    offset = 0 if colLabels is None and colColours is None else 1

    table = Table(ax, loc, bbox)
    for row in range(rows):
        for col in range(cols):
            table.add_cell(row + offset, col, width=csize, height=csize,
                           text=cellText[row][col], facecolor=cellColours[row][col],
                           loc=cellLoc)
    ax.add_table(table)
    return table