Exemple #1
0
def task_table_plot(task_data):
    """
    Plot task table, save in png file.

    Each data row is filled with the palette colour of its ``Group`` value;
    the ``Group`` column itself is not drawn.  A gray header row holding the
    column names is placed above the data rows.  The figure is written to
    ``img/task_table`` (no extension is given, so matplotlib appends its
    default output format).

    :param task_data: dataframe from get_task_table().  It must contain a
        ``Group`` column and at most nine other columns, because only nine
        column widths are defined below.
    """
    groups = task_data.Group.values
    task_no_group = task_data.drop('Group', axis=1)
    nrows = task_no_group.shape[0]
    height = 1.0 / nrows

    fig, ax = plt.subplots(figsize=(1, nrows*0.25))
    ax.set_axis_off()
    tbl = Table(ax)
    tbl.auto_set_font_size(False)
    # Columns width for non-auto-width columns
    col_widths = [1, 1, 0.5, 1, 0.7, 0.7, 0.7, 0.7, 0.7]
    palette = get_palette()
    fontcolor = 'w'
    for (i, j), val in np.ndenumerate(task_no_group):
        fc = palette[groups[i]]
        fontsize = 10
        # Bind the font family on every iteration.  It used to be assigned
        # only in the ``j < 2`` branch, so the other columns silently relied
        # on the value left over from a previous cell.
        font_family = None
        if j < 2:
            loc = 'left'
            if j == 0:
                fontsize = 9
        else:
            loc = 'center'
            if j > 3:
                fontsize = 9
        tbl.add_cell(i, j, col_widths[j], height, text=val,
                     loc=loc, facecolor=fc, edgecolor=fontcolor)
        cell = tbl.get_celld()[(i, j)]
        cell.set_linewidth(0.5)
        cell.set_text_props(color=fontcolor, family=font_family, weight='bold', fontsize=fontsize)

    # Column Labels...
    for j, label in enumerate(task_no_group.columns):
        tbl.add_cell(-1, j, col_widths[j], height*0.8, text=label, loc='center',
                     facecolor='gray', edgecolor='w')
        cell = tbl.get_celld()[(-1, j)]
        cell.set_linewidth(0.5)
        cell.set_text_props(color=fontcolor, weight='bold', family='Verdana', fontsize=9)

    # Columns 0 and 1 are sized automatically from their text.
    # NOTE(review): this writes a private Table attribute -- confirm it still
    # works with the matplotlib version in use.
    tbl._autoColumns = [0, 1]
    tbl.scale(1, 1.5)  # scale y to cover blank in the bottom
    ax.add_table(tbl)
    ax.margins(0, 0)
    fig.savefig('img/task_table', bbox_inches='tight', pad_inches=0.1, dpi=200)
Exemple #2
0
def task_table_plot(task_data):
    """
    Render the task table as a coloured matplotlib table and save it.

    Data rows are filled with the palette colour of their ``Group`` value
    (the ``Group`` column itself is not drawn) and a gray header row with
    the column names is added on top.  The result is written to
    ``img/task_table``.

    :param task_data: dataframe from get_task_table(); must have a ``Group``
        column.
    """
    row_groups = task_data.Group.values
    body = task_data.drop('Group', axis=1)
    n_rows, n_cols = body.shape
    # Only the row height is used afterwards.
    _, row_height = 1.0 / n_cols, 1.0 / n_rows

    fig, ax = plt.subplots(figsize=(1, n_rows*0.25))
    ax.set_axis_off()
    tbl = Table(ax)
    tbl.auto_set_font_size(False)
    # Widths used for the columns that are not auto-sized
    col_widths = [1, 1, 0.5, 1, 0.7, 0.7, 0.7, 0.7, 0.7]
    palette = get_palette()
    fontcolor = 'w'

    # Data cells
    for (row, col), value in np.ndenumerate(body):
        fill = palette[row_groups[row]]
        leading_col = col < 2
        # The first two columns are left-aligned in the default font family;
        # the remaining ones are centred in DINPro.
        loc = 'left' if leading_col else 'center'
        font_family = None if leading_col else 'DINPro'
        # Column 0 and the columns after index 3 use the smaller size.
        fontsize = 9 if (col == 0 or col > 3) else 10
        tbl.add_cell(row, col, col_widths[col], row_height, text=value,
                     loc=loc, facecolor=fill, edgecolor=fontcolor)
        data_cell = tbl.get_celld()[(row, col)]
        data_cell.set_linewidth(0.5)
        data_cell.set_text_props(color=fontcolor, family=font_family,
                                 weight='bold', fontsize=fontsize)

    # Header cells (table row -1)
    for col, label in enumerate(body.columns):
        tbl.add_cell(-1, col, col_widths[col], row_height*0.8, text=label,
                     loc='center', facecolor='gray', edgecolor='w')
        header_cell = tbl.get_celld()[(-1, col)]
        header_cell.set_linewidth(0.5)
        header_cell.set_text_props(color=fontcolor, weight='bold',
                                   family='Verdana', fontsize=9)

    tbl._autoColumns = [0, 1]
    tbl.scale(1, 1.5)  # stretch vertically to cover the blank at the bottom
    ax.add_table(tbl)
    ax.margins(0, 0)
    fig.savefig('img/task_table', bbox_inches='tight', pad_inches=0.1, dpi=200)
def draw_q_value_image(dyna_maze, q_value, run, planning_steps, episode):
    """Save an image of the maze showing the greedy action symbol per cell.

    A cell whose action values sum to zero is drawn blank; every other cell
    shows ``dyna_maze.ACTION_SYMBOLS`` at the arg-max of its action values.
    Row and column indices are added as borderless label cells, and the
    image is saved under ``images/`` with a name built from ``run``,
    ``planning_steps`` and ``episode``.
    """
    # Pre-drawing setup: hide the axes and let the table fill them
    fig, axis = plt.subplots()
    axis.set_axis_off()
    table = Table(axis, bbox=[0, 0, 1, 1])

    n_rows, n_cols = dyna_maze.MAZE_HEIGHT, dyna_maze.MAZE_WIDTH
    cell_w, cell_h = 1.0 / n_cols, 1.0 / n_rows

    # One cell per maze position
    for row in range(n_rows):
        for col in range(n_cols):
            cell_q = q_value[row][col]
            if np.sum(cell_q) == 0.0:
                symbol = " "
            else:
                symbol = dyna_maze.ACTION_SYMBOLS[np.argmax(cell_q)]
            table.add_cell(row, col, cell_w, cell_h, text=symbol,
                           loc='center', facecolor='white')

    # Row labels, then column labels
    for row in range(n_rows):
        table.add_cell(row, -1, cell_w, cell_h, text=row, loc='right',
                       edgecolor='none', facecolor='none')
    for col in range(n_cols):
        table.add_cell(-1, col, cell_w, cell_h/2, text=col, loc='center',
                       edgecolor='none', facecolor='none')

    for cell in table.get_celld().values():
        cell.get_text().set_fontsize(20)

    axis.add_table(table)
    plt.savefig('images/maze_action_values_{0}_{1}_{2}.png'.format(run, planning_steps, episode))
    plt.close()
Exemple #4
0
def draw_grid_world_action_values_image(action_values, filename, GRID_HEIGHT,
                                        GRID_WIDTH, NUM_ACTIONS,
                                        ACTION_SYMBOLS):
    """Save an image of a grid whose cells list every action value.

    Each cell contains one line per action, formatted as
    ``"<symbol> (<action index>): <value with two decimals>"``.

    :param action_values: indexed as ``action_values[i, j, action]``.
    :param filename: path passed to ``savefig``.
    :param GRID_HEIGHT: number of grid rows.
    :param GRID_WIDTH: number of grid columns.
    :param NUM_ACTIONS: number of actions listed per cell.
    :param ACTION_SYMBOLS: indexable by action index, gives the symbol shown.
    """
    action_str_values = [
        [
            "\n".join(
                "{0} ({1}): {2:.2f}".format(
                    ACTION_SYMBOLS[action], action,
                    action_values[i, j, action])
                for action in range(NUM_ACTIONS))
            for j in range(GRID_WIDTH)
        ]
        for i in range(GRID_HEIGHT)
    ]

    # Pre-drawing setup: hide the axes and let the table fill them
    fig, ax = plt.subplots()
    ax.set_axis_off()
    table = Table(ax, bbox=[0, 0, 1, 1])

    nrows, ncols = GRID_HEIGHT, GRID_WIDTH
    width, height = 1.0 / ncols, 1.0 / nrows

    # Add the table cells and their text to the image being rendered
    for i in range(nrows):
        for j in range(ncols):
            table.add_cell(i,
                           j,
                           width,
                           height,
                           text=action_str_values[i][j],
                           loc='center',
                           facecolor='white')

    # Row labels: one per grid row
    for i in range(nrows):
        table.add_cell(i,
                       -1,
                       width,
                       height,
                       text=i + 1,
                       loc='right',
                       edgecolor='none',
                       facecolor='none')

    # Column labels: one per grid column.  They were previously generated
    # from the row count, which is only right for square grids.
    for j in range(ncols):
        table.add_cell(-1,
                       j,
                       width,
                       height / 2,
                       text=j + 1,
                       loc='center',
                       edgecolor='none',
                       facecolor='none')

    for cell in table.get_celld().values():
        cell.get_text().set_fontsize(10)

    ax.add_table(table)

    # Save and close this specific figure rather than "the current one"
    fig.savefig(filename)
    plt.close(fig)
Exemple #5
0
def draw_grid_world_image(values, filename, grid_height, grid_width):
    """Save an image of a grid showing one state value per cell.

    :param values: indexable by ``(i, j)`` tuples for every grid position.
    :param filename: path passed to ``savefig``.
    :param grid_height: number of grid rows.
    :param grid_width: number of grid columns.
    """
    # Copy into a dense array so the values can be rounded in one step
    state_values = np.zeros((grid_height, grid_width))
    for i in range(grid_height):
        for j in range(grid_width):
            state_values[(i, j)] = values[(i, j)]

    state_values = np.round(state_values, decimals=2)

    # Pre-drawing setup: hide the axes and let the table fill them
    fig, ax = plt.subplots()
    ax.set_axis_off()
    table = Table(ax, bbox=[0, 0, 1, 1])

    nrows, ncols = state_values.shape
    width, height = 1.0 / ncols, 1.0 / nrows

    # Add the table cells and their values to the image being rendered
    for (i, j), val in np.ndenumerate(state_values):
        table.add_cell(i,
                       j,
                       width,
                       height,
                       text=val,
                       loc='center',
                       facecolor='white')

    # Row labels: one per grid row
    for i in range(nrows):
        table.add_cell(i,
                       -1,
                       width,
                       height,
                       text=i + 1,
                       loc='right',
                       edgecolor='none',
                       facecolor='none')

    # Column labels: one per grid column.  They were previously generated
    # from the row count, which is only right for square grids.
    for j in range(ncols):
        table.add_cell(-1,
                       j,
                       width,
                       height / 2,
                       text=j + 1,
                       loc='center',
                       edgecolor='none',
                       facecolor='none')

    for cell in table.get_celld().values():
        cell.get_text().set_fontsize(20)

    ax.add_table(table)

    # Save and close this specific figure rather than "the current one"
    fig.savefig(filename)
    plt.close(fig)
Exemple #6
0
def plot_table(ax, data, float_format='{:.2f}', columns=None, bbox=None, cell_font_size=6):
    """Draw a DataFrame as a matplotlib table on ``ax``.

    :param ax: axes to draw on; its axis decorations are switched off.
    :param data: DataFrame whose values, index and columns are rendered.
    :param float_format: format string applied to every cell value.  Values
        it cannot format (e.g. strings with a float format) are shown via
        ``str()`` instead.
    :param columns: optional column selection.  A dict selects its keys and
        relabels them to its values; any other value is used directly to
        index ``data``.
    :param bbox: table bounding box, defaults to ``[0, 0, 1, 1]``.
    :param cell_font_size: font size applied to the table cells.
    :return: ``ax``.
    """
    if columns is not None:
        if isinstance(columns, dict):
            # ``rename`` needs the ``columns=`` keyword: a positional mapper
            # is applied to the index, which left the headers unchanged.
            data = data[list(columns)].rename(columns=columns)
        else:
            data = data[columns]

    if bbox is None:
        bbox = [0, 0, 1, 1]

    ax.set_axis_off()

    table = Table(ax, bbox=bbox)

    num_cols = data.shape[1]
    width = 1.0 / num_cols

    # NOTE(review): private Table helper -- confirm it exists in the
    # matplotlib version in use.
    height = table._approx_text_height()

    # Data cells
    for (i, j), val in np.ndenumerate(data):
        try:
            text = float_format.format(val)
        except (ValueError, TypeError):
            # Value is not compatible with the numeric format spec
            text = str(val)
        cell = table.add_cell(i, j, width, height, text=text, loc='center')
        if cell is None:
            # Some matplotlib versions do not return the cell from add_cell
            cell = table.get_celld()[(i, j)]
        cell.set_fontsize(cell_font_size)

    # Row Labels...
    for i, label in enumerate(data.index):
        table.add_cell(i, -1, width, height, text=label, loc='right', edgecolor='none', facecolor='none')

    # Column Labels...
    for j, label in enumerate(data.columns):
        table.add_cell(-1, j, width, height/2, text=label, loc='center', edgecolor='none', facecolor='none')

    ax.add_table(table)

    # Apply the size to the label cells as well
    table.set_fontsize(cell_font_size)

    return ax
Exemple #7
0
def draw_grid_world_state_values_image(state_values, filename, GRID_HEIGHT,
                                       GRID_WIDTH):
    """Save an image of a grid showing one rounded state value per cell.

    :param state_values: indexed as ``state_values[i][j]``.
    :param filename: path passed to ``savefig``.
    :param GRID_HEIGHT: number of grid rows.
    :param GRID_WIDTH: number of grid columns.
    """
    # Pre-drawing setup: hide the axes and let the table fill them
    fig, ax = plt.subplots()
    ax.set_axis_off()
    table = Table(ax, bbox=[0, 0, 1, 1])

    nrows, ncols = GRID_HEIGHT, GRID_WIDTH
    width, height = 1.0 / ncols, 1.0 / nrows

    # Add the table cells and their values to the image being rendered
    for i in range(nrows):
        for j in range(ncols):
            table.add_cell(i,
                           j,
                           width,
                           height,
                           text=np.round(state_values[i][j], decimals=2),
                           loc='center',
                           facecolor='white')

    # Row labels: one per grid row
    for i in range(nrows):
        table.add_cell(i,
                       -1,
                       width,
                       height,
                       text=i + 1,
                       loc='right',
                       edgecolor='none',
                       facecolor='none')

    # Column labels: one per grid column.  They were previously generated
    # from len(state_values), which is only right for square grids.
    for j in range(ncols):
        table.add_cell(-1,
                       j,
                       width,
                       height / 2,
                       text=j + 1,
                       loc='center',
                       edgecolor='none',
                       facecolor='none')

    for cell in table.get_celld().values():
        cell.get_text().set_fontsize(20)

    ax.add_table(table)

    # Save and close this specific figure rather than "the current one"
    fig.savefig(filename)
    plt.close(fig)
Exemple #8
0
def __draw_state_transition__(tran_list,
                              state_list,
                              node_map,
                              node_patches,
                              node_text,
                              cur_ax,
                              time_mult=0.1):
    """
    Draw the state transition

    Replays the events in ``tran_list`` on ``cur_ax``: the patch of the node
    hit by each event is made fully opaque while all patches fade a little
    whenever the event time advances.  A two-row table in the bottom tenth
    of the axes shows the node labels (row 0) and the latest value of each
    node (row 1, initially '-').

    :param tran_list: record array read through the fields 'e_idx' (event
        id), 'e_t' (event time) and 'e_n' (index of the node hit).
    :param state_list: record array read through the fields 'e_idx' (id of
        the event that changes a state), 'n_idx' (key into ``node_map``)
        and 'n_val' (new value).  Assumed to be non-empty, since its first
        entry is read unconditionally.
    :param node_map: mapping of node key to the label shown in the table.
    :param node_patches: patches indexed by node; needs ``values()`` and
        ``len()``.
    :param node_text: text artists indexed by node.
    :param cur_ax: axes to draw on.
    :param time_mult: factor applied to the time gap between events to get
        the real-time pause, in seconds.
    """
    import time
    from numpy import zeros
    from matplotlib.pyplot import draw
    from matplotlib.table import Table

    num_tran = tran_list.shape[0]
    evt_id_list = tran_list['e_idx']
    evt_t_list = tran_list['e_t']
    evt_s_list = tran_list['e_n']

    state_e_list = state_list['e_idx']
    state_n_list = state_list['n_idx']
    state_val_list = state_list['n_val']

    len_states = len(state_e_list)

    # Cell sizes are derived from the current data limits of the axes
    x_lim = cur_ax.get_xlim()
    y_lim = cur_ax.get_ylim()
    x_range = x_lim[1] - x_lim[0]
    y_range = y_lim[1] - y_lim[0]
    col_width = x_range / float(len(node_map))
    row_height = y_range * 0.1 / 2.0
    # The table occupies the full width and the bottom 10% of the axes
    state_table = Table(cur_ax, bbox=(0, 0, 1, 0.1))
    # Maps a node key to the table column that displays it
    cell_pos = dict()
    for idx, key_value in enumerate(node_map.items()):
        _key = key_value[0]
        _val = key_value[1]
        cell_pos[_key] = idx
        # Row 0: node label
        state_table.add_cell(0,
                             idx,
                             col_width,
                             row_height,
                             text=str(_val),
                             loc='center',
                             edgecolor='none',
                             facecolor='none')
        # Row 1: placeholder until the node receives its first value
        state_table.add_cell(1,
                             idx,
                             col_width,
                             row_height,
                             text='-',
                             loc='center',
                             edgecolor='none',
                             facecolor='none')
    cur_ax.add_table(state_table)
    draw()
    state_cells = state_table.get_celld()

    t_delta_last = 0.0
    # Per-node count of the events seen so far
    evt_mask = zeros(len(node_patches), dtype=int)
    patch_list = node_patches.values()
    last_time = -1.0
    # Index and event id of the next pending state change
    new_state_idx = 0
    new_state_e = state_e_list[new_state_idx]
    for idx in range(num_tran):
        evt_id = evt_id_list[idx]
        evt_t = evt_t_list[idx]
        if last_time < evt_t:
            # Time advanced: show the previous events, fade all patches and
            # pause in proportion to the elapsed event time
            draw()
            for pat in patch_list:
                _alpha = pat.get_alpha()
                # Only fade patches that stay above the minimum alpha
                if _alpha > (__disp_defs__.PATCH_MIN_ALPHA +
                             __disp_defs__.PATCH_ALPHA_DECAY):
                    pat.set_alpha(_alpha - __disp_defs__.PATCH_ALPHA_DECAY)
            t_delta = evt_t - t_delta_last
            t_delta_last = evt_t
            time.sleep(t_delta * time_mult)
            last_time = evt_t
        evt_s = evt_s_list[idx]
        current_iter = evt_mask[evt_s]
        current_patch = node_patches[evt_s]
        # Highlight the node hit by this event
        current_patch.set_alpha(1.0)
        current_node_text = node_text[evt_s]
        # The label shows the count before this event is added
        current_node_text.set_text(current_iter)
        # Attach the artists the first time the node is hit.
        # NOTE(review): newer matplotlib exposes this as the ``axes``
        # attribute instead of ``get_axes()`` -- confirm the target version.
        if current_patch.get_axes() is None:
            cur_ax.add_patch(current_patch)
            cur_ax.add_artist(current_node_text)
        current_iter += 1
        evt_mask[evt_s] = current_iter

        if evt_id == new_state_e:
            # This event carries the pending state change: show the new
            # value in the node's table column and log it
            node_idx = state_n_list[new_state_idx]
            node_val = state_val_list[new_state_idx]
            cell_idx = cell_pos[node_idx]
            state_cells[(1, cell_idx)].get_text().set_text(str(node_val))
            print("t=%g, node: %s, val: %s" %
                  (evt_t, node_map[node_idx], node_val))
            new_state_idx += 1
            if new_state_idx < len_states:
                new_state_e = state_e_list[new_state_idx]
            else:
                # All state changes were shown: stop replaying events
                break
Exemple #9
0
class TableAxis(object):
    """Base class that draws a DataFrame as a matplotlib ``Table`` on an axes.

    The constructor creates the table, computes the size of a single cell
    and calls ``_add_cells``, which subclasses must implement to create the
    actual cells.

    :param ax: axes the table is added to.
    :param df: DataFrame whose shape, index and columns define the table.
    :param colors: stored as ``self._colors`` for use by subclasses.
    """

    # Class-level defaults so the ``fontsize`` and ``linewidth`` properties
    # return None, instead of raising AttributeError, before they are set.
    _table_fontsize = None
    _linewidth = None

    def __init__(self, ax, df, colors):
        self._ax = ax
        self._df = df
        self._colors = colors

        # Create matplotlib's Table object
        self._tb = Table(self._ax, bbox=[0, 0, 1, 1])
        self._tb.auto_set_font_size(False)
        self._ax.add_table(self._tb)

        self._calculate_cell_size()
        self._add_cells()
        self._remove_ticks()

    def _calculate_cell_size(self):
        """Store the table dimensions and the size of a single cell."""
        # The number of total rows and columns of the table
        self._nrows, self._ncols = self._df.shape[0], self._df.shape[1]

        # The table spans a 1.0 x 1.0 area, shared equally by the cells
        self._w_cell = 1.0 / self._ncols
        self._h_cell = 1.0 / self._nrows

    def _add_cells(self):
        """Create the table cells; must be implemented by subclasses."""
        raise NotImplementedError()

    def _remove_ticks(self):
        """Remove all ticks from both axes."""
        self._ax.get_xaxis().set_ticks([])
        self._ax.get_yaxis().set_ticks([])

    def add_row_labels(self):
        """Add row labels using y-axis
        """
        ylabels = list(self._df.index)
        self._ax.yaxis.tick_left()

        # One tick at the vertical centre of every row
        yticks = [(k + 0.5) * self._h_cell for k in range(self._nrows)]

        self._ax.set_yticks(yticks)
        self._ax.set_yticklabels(ylabels, minor=False)
        self._ax.tick_params(axis='y', which='major', pad=3)

        # Hide the small bars of ticks.  The tick line artists are used
        # because the ``tick1On``/``tick2On`` attributes are deprecated.
        for tick in self._ax.yaxis.get_major_ticks():
            tick.tick1line.set_visible(False)
            tick.tick2line.set_visible(False)

    def add_column_labels(self):
        """
        Add column labels using x-axis
        """
        xlabels = list(self._df.columns)

        # One tick at the horizontal centre of every column
        xticks = [(k + 0.5) * self._w_cell for k in range(self._ncols)]

        # Draw no tick marks on the x-axis; only the labels are wanted
        self._ax.xaxis.set_ticks_position('none')

        self._ax.set_xticks(xticks)
        self._ax.set_xticklabels(xlabels, rotation=90, minor=False)
        self._ax.tick_params(axis='x', which='major', pad=-2)

    @property
    def fontsize(self):
        """Font size last assigned to the table, or None if never set."""
        return self._table_fontsize

    @fontsize.setter
    def fontsize(self, val):
        """
        Resize text fonts
        """
        self._table_fontsize = val
        self._tb.set_fontsize(val)

    @property
    def linewidth(self):
        """Line width last assigned to the cells, or None if never set."""
        return self._linewidth

    @linewidth.setter
    def linewidth(self, val):
        """Adjust the width of table lines
        """
        self._linewidth = val
        for cell in self._tb.get_celld().values():
            cell.set_linewidth(self._linewidth)
Exemple #10
0
def draw_grid_world_policy_image(policy,
                                 filename,
                                 GRID_HEIGHT,
                                 GRID_WIDTH,
                                 ACTION_SYMBOLS,
                                 TERMINAL_STATES=None):
    """Save an image of a grid showing the policy of every state.

    Each non-terminal cell lists one line per action, formatted as
    ``"<symbol> (<probability rounded to 3 decimals>)"``.  Terminal states
    get no cell.

    :param policy: indexable by ``(i, j)``; each entry unpacks into
        ``(actions, probs)``.  ``probs`` is indexed by the action itself --
        presumably it holds one entry per possible action; verify against
        the caller.
    :param filename: path passed to ``savefig``.
    :param GRID_HEIGHT: number of grid rows.
    :param GRID_WIDTH: number of grid columns.
    :param ACTION_SYMBOLS: indexable by action, gives the symbol shown.
    :param TERMINAL_STATES: optional collection of ``(i, j)`` states to
        leave out.
    """
    # Texts are keyed by grid position.  The per-row lists used before got
    # shorter whenever a terminal state was skipped, so later cells in that
    # row received the wrong text or raised IndexError.
    cell_texts = {}
    for i in range(GRID_HEIGHT):
        for j in range(GRID_WIDTH):
            if TERMINAL_STATES and (i, j) in TERMINAL_STATES:
                continue
            actions, probs = policy[(i, j)]
            cell_texts[(i, j)] = "\n".join(
                "{0} ({1})".format(ACTION_SYMBOLS[action],
                                   np.round(probs[action], decimals=3))
                for action in actions)

    # Pre-drawing setup: hide the axes and let the table fill them
    fig, ax = plt.subplots()
    ax.set_axis_off()
    table = Table(ax, bbox=[0, 0, 1, 1])

    nrows, ncols = GRID_HEIGHT, GRID_WIDTH
    width, height = 1.0 / ncols, 1.0 / nrows

    # Add the table cells and their text to the image being rendered
    for (i, j), text in cell_texts.items():
        table.add_cell(i,
                       j,
                       width,
                       height,
                       text=text,
                       loc='center',
                       facecolor='white')

    # Row labels: one per grid row
    for i in range(nrows):
        table.add_cell(i,
                       -1,
                       width,
                       height,
                       text=i + 1,
                       loc='right',
                       edgecolor='none',
                       facecolor='none')

    # Column labels: one per grid column.  They were previously generated
    # from the row count, which is only right for square grids.
    for j in range(ncols):
        table.add_cell(-1,
                       j,
                       width,
                       height / 2,
                       text=j + 1,
                       loc='center',
                       edgecolor='none',
                       facecolor='none')

    for cell in table.get_celld().values():
        cell.get_text().set_fontsize(10)

    ax.add_table(table)

    # Save and close this specific figure rather than "the current one"
    fig.savefig(filename)
    plt.close(fig)