Beispiel #1
0
 def draw_iqa(img, q, target_a, pred_a):
     """Render an image titled with its question and labelled with answers.

     Args:
         img: image accepted by ``Axes.imshow``.
         q: encoded question, turned into the title by ``question2str``.
         target_a: ground-truth answer, rendered with ``answer2str``.
         pred_a: predicted answer, rendered with ``answer2str(..., 'Predicted')``.

     Returns:
         The 6x6-inch figure created by ``tfplot.subplots``.
     """
     fig, axis = tfplot.subplots(figsize=(6, 6))
     axis.imshow(img)
     title = question2str(q)
     axis.set_title(title)
     answers = answer2str(target_a) + answer2str(pred_a, 'Predicted')
     axis.set_xlabel(answers)
     return fig
Beispiel #2
0
def draw_trajectory_multiple(gt, *pred_list):
    """Scatter-plot a ground-truth trajectory next to any number of predictions.

    Each trajectory gets its own subplot in a near-square grid. The first
    subplot is the ground truth (grey background, titled "GT"). The others are
    titled "pred 1", "pred 2", ... in argument order. Points are coloured by
    time step with the 'jet' colormap, and both axes span [0, 1]. Grid cells
    left over when the grid is not full are hidden.

    Args:
        gt: 2-D array of shape (time_step, coord_size).
        *pred_list: 2-D arrays indexed the same way as ``gt``.

    Returns:
        The matplotlib figure created by ``tfplot.subplots``.
    """
    time_step, _ = gt.shape
    trajectories = [gt] + list(pred_list)
    try:
        # Columns 2/3 are read as (y, x) when every array has them.
        ay_ax_list = [(a[:, 2], a[:, 3]) for a in trajectories]
    except IndexError:
        # Some array has fewer than four columns: fall back to columns 0/1.
        ay_ax_list = [(a[:, 0], a[:, 1]) for a in trajectories]

    N = len(ay_ax_list)
    # W = ceil(sqrt(N)) columns and just enough rows to hold all N panels.
    # Growing only one column past floor(sqrt(N)) is too small for N = 3, 7, ...
    W = int(N ** 0.5 + 1e-9)
    if W * W < N:
        W += 1
    H = -(-N // W)  # ceiling division

    fig, axes = tfplot.subplots(H, W, figsize=(12, 12), squeeze=False)
    for h in range(H):
        for w in range(W):
            k = h * W + w  # row-major flat index: the row stride is W (columns)
            ax = axes[h, w]
            if k >= N:
                ax.axis('off')  # unused padding cell
                continue
            ax.axis([0, 1, 0, 1])
            ys, xs = ay_ax_list[k]
            ax.scatter(xs, ys, c=range(time_step), cmap='jet')
            if k == 0:
                ax.set_title("GT")
                # Axes.set_axis_bgcolor was removed in Matplotlib 2.2.
                ax.set_facecolor('gray')
            else:
                ax.set_title("pred %d" % k)
    return fig
Beispiel #3
0
def waveform_plot(waveform, delta_t):
    """Plot a uniformly sampled waveform against time.

    Args:
        waveform: sequence of samples; sample ``i`` is drawn at ``i * delta_t``.
        delta_t: sampling interval.

    Returns:
        The 3x3-inch figure created by ``tfplot.subplots``.
    """
    fig, axis = tfplot.subplots(figsize=(3, 3))
    sample_index = np.arange(len(waveform))
    axis.plot(sample_index * delta_t, waveform)
    axis.set_ylabel('signal')
    axis.set_xlabel('time')
    return fig
Beispiel #4
0
def _create_bar_stats_figure(labels, probs):
    """Draw a bar chart of ``probs`` with vertical tick labels.

    Args:
        labels: iterable of str/bytes/path-like labels, each decoded with
            ``os.fsdecode``.
        probs: array of bar heights; ``probs.size`` bars of width 0.35 are drawn.

    Returns:
        The figure created by ``tfplot.subplots``, with tight layout enabled.
    """
    decoded = np.array([os.fsdecode(label) for label in labels])
    fig, axis = tfplot.subplots()
    positions = np.arange(probs.size)
    axis.bar(positions, probs, 0.35)
    axis.set_xticks(np.arange(decoded.size))
    axis.set_xticklabels(decoded, rotation=+90, ha='center')
    fig.set_tight_layout(True)
    return fig
Beispiel #5
0
 def draw_act_hist(h, grid_shape):
     """Plot the normalized histogram of activation values as a bar chart.

     Args:
         h: activations; must hold exactly ``grid_shape[0] * grid_shape[1]``
             elements, since it is flattened with ``np.reshape``.
         grid_shape: sequence whose first two entries give the grid size.

     Returns:
         The 4x4-inch figure created by ``tfplot.subplots``.
     """
     fig, ax = tfplot.subplots(figsize=(4, 4))
     h = np.reshape(h, [grid_shape[0] * grid_shape[1]])
     hist, bins = np.histogram(h)
     # Heights are per-bin fractions of all samples (they sum to 1).
     # align='edge' makes each bar span [bins[i], bins[i+1]];
     # Matplotlib >= 2.0 would otherwise center bars on the left edges.
     ax.bar(bins[:-1],
            hist.astype(np.float32) / hist.sum(),
            width=(bins[1] - bins[0]),
            align='edge',
            color='blue')
     # Axes.plot does not label axes (and rejects x=/y= keywords).
     ax.set_xlabel('Activation values')
     ax.set_ylabel('Probability')
     return fig
Beispiel #6
0
def _create_wrong_example_plot(wis, whs, winst, wtypes, wids):
    """Plot one randomly chosen example from the ``w*`` batch arrays.

    After decoding ``winst`` and ``wtypes`` with ``os.fsdecode``, a random
    index ``idx`` is drawn along the first axis of ``wis``. The figure shows:

    * top: ``wis[idx, 0, :, 0]`` and ``wis[idx, 0, :, 1]`` as line plots,
      titled "<winst> | <wtypes> | id: <wids>" for that example;
    * bottom: ``whs[idx, :, :, 0]`` as an image.

    Returns:
        The figure created by ``tfplot.subplots``, resized to 8x8 inches.
    """
    decoded_inst = list(map(os.fsdecode, winst))
    decoded_types = list(map(os.fsdecode, wtypes))

    idx = np.random.randint(0, wis.shape[0])

    fig, (top, bottom) = tfplot.subplots(2, 1)
    for channel in (0, 1):
        top.plot(wis[idx, 0, :, channel])
    title_parts = [decoded_inst[idx], decoded_types[idx], 'id: ' + str(wids[idx])]
    top.set_title(' | '.join(title_parts))
    bottom.imshow(whs[idx, :, :, 0])
    fig.set_size_inches(8, 8)
    return fig
            def figure_prediction(pred_x, pred_y, test_x, test_y):
                """Plot predictions against DFT values in two side-by-side panels.

                The left panel ("Batch") scatters ``pred_x`` against ``pred_y``.
                The right panel ("Test") scatters ``test_x`` against ``test_y``.
                Each panel has a dashed y = x reference line spanning the range of
                its y data, and axis limits set to the data's x and y ranges.

                Returns:
                    The 8x4-inch figure created by ``tfplot.subplots``.
                """
                fig, (left, right) = tfplot.subplots(1, 2, figsize=(8, 4))

                panels = ((left, pred_x, pred_y, "Batch"),
                          (right, test_x, test_y, "Test"))
                for axis, xs, ys, title in panels:
                    axis.plot(xs, ys, '.b', alpha=.05)
                    y_hi = max(ys)
                    y_lo = min(ys)
                    diagonal = [y_hi, y_lo]
                    axis.plot(diagonal, diagonal, '--k', alpha=.5)
                    axis.axis([min(xs), max(xs), y_lo, y_hi])
                    axis.set_xlabel("Prediction [eV]")
                    axis.set_title(title)

                right.yaxis.tick_right()
                left.set_ylabel("DFT value [eV]")
                return fig
Beispiel #8
0
def draw_trajectory(gt, pred):
    """Scatter-plot a ground-truth trajectory next to one prediction.

    Columns 2/3 of each array are read as (y, x) when both arrays have them.
    Otherwise columns 0/1 are used.

    Args:
        gt: 2-D array of shape (time_step, coord_size).
        pred: 2-D array indexed the same way as ``gt``.

    Returns:
        The figure created by ``tfplot.subplots``. The left panel is "GT" and
        the right panel is "pred". Both span [0, 1] x [0, 1], with points
        coloured by time step using the 'jet' colormap.
    """
    a, b = gt, pred
    time_step, _ = a.shape
    try:
        ay, ax = a[:, 2], a[:, 3]  # just draw the first player
        by, bx = b[:, 2], b[:, 3]  # just draw the first player
    except IndexError:
        # Fewer than four columns: fall back to columns 0/1. Only IndexError
        # is caught so KeyboardInterrupt and genuine bugs still propagate.
        ay, ax = a[:, 0], a[:, 1]
        by, bx = b[:, 0], b[:, 1]

    fig, axes = tfplot.subplots(1, 2, figsize=(8, 4))
    panels = (((ax, ay), "GT"), ((bx, by), "pred"))
    for axis, ((xs, ys), title) in zip(axes, panels):
        axis.axis([0, 1, 0, 1])
        axis.scatter(xs, ys, c=range(time_step), cmap='jet')
        axis.set_title(title)
    return fig
Beispiel #9
0
 def draw_act_hist(h, grid_shape):
     """Plot the normalized histogram of activation values as a bar chart.

     Args:
         h: activations; must hold exactly ``grid_shape[0] * grid_shape[1]``
             elements, since it is flattened with ``np.reshape``.
         grid_shape: sequence whose first two entries give the grid size.

     Returns:
         The 4x4-inch figure created by ``tfplot.subplots``.
     """
     fig, ax = tfplot.subplots(figsize=(4, 4))
     h = np.reshape(h, [grid_shape[0] * grid_shape[1]])
     hist, bins = np.histogram(h)
     # Heights are per-bin fractions of all samples (they sum to 1).
     # align='edge' makes each bar span [bins[i], bins[i+1]];
     # Matplotlib >= 2.0 would otherwise center bars on the left edges.
     ax.bar(bins[:-1],
            hist.astype(np.float32) / hist.sum(),
            width=(bins[1] - bins[0]),
            align='edge',
            color='blue')
     # Axes.plot does not label axes (and rejects x=/y= keywords).
     ax.set_xlabel('Activation values')
     ax.set_ylabel('Probability')
     return fig
Beispiel #10
0
def heatmap_overlay(data,
                    overlay_image=None,
                    cmap='jet',
                    cbar=False,
                    show_axis=False,
                    alpha=0.5,
                    **kwargs):
    """Draw ``data`` as a seaborn heatmap, optionally together with an image.

    Args:
        data: 2-D array passed to ``sns.heatmap``.
        overlay_image: optional image rendered with ``ax.imshow`` in the same
            axes. When it is omitted, the heatmap is drawn with alpha 1.0.
        cmap: colormap for the heatmap.
        cbar: whether seaborn draws a colorbar; also widens the figure to 5x4.
        show_axis: turn the axis back on and restore margins for tick labels.
        alpha: heatmap opacity, used only when ``overlay_image`` is given.
        **kwargs: forwarded to ``sns.heatmap``.

    Returns:
        The figure created by ``tfplot.subplots``.
    """
    figsize = (5, 4) if cbar else (4, 4)
    fig, ax = tfplot.subplots(figsize=figsize)
    fig.subplots_adjust(0, 0, 1, 1)  # no margins unless show_axis is set
    ax.axis('off')

    has_overlay = overlay_image is not None
    heat_alpha = alpha if has_overlay else 1.0
    sns.heatmap(data, ax=ax, alpha=heat_alpha, cmap=cmap, cbar=cbar, **kwargs)

    if has_overlay:
        n_rows, n_cols = data.shape
        # NOTE(review): imshow's extent is (left, right, bottom, top), so this
        # puts the row count on x and the column count on y. That presumably
        # only lines up with the heatmap cells for square ``data`` — confirm.
        ax.imshow(overlay_image, extent=[0, n_rows, 0, n_cols])

    if show_axis:
        ax.axis('on')
        fig.subplots_adjust(left=0.1, bottom=0.1, right=0.95, top=0.95)
    return fig
Beispiel #11
0
        def draw_iqa(img, q, target_a, pred_a, weights):
            """Render an image with an averaged attention map overlaid.

            ``weights`` is reshaped to a (d*d, d*d) matrix, where ``d = self.d``
            comes from the enclosing scope. Its row means and column means are
            each reshaped to a (d, d) grid and averaged. The result is scaled
            so its maximum is 1, then drawn at alpha 0.5 over ``img`` with a
            colorbar. The question becomes the title, and the target and
            predicted answers form the x-label.

            Returns:
                The 6x6-inch figure created by ``tfplot.subplots``.
            """
            d = self.d
            H, W = img.shape[:2]

            weights = weights.reshape(d * d, d * d)
            # Reshape to (d, d) instead of a hard-coded (4, 4) so any grid
            # size stored in self.d works.
            weights_a2b = np.mean(weights, axis=1).reshape(d, d)
            weights_b2a = np.mean(np.transpose(weights), axis=1).reshape(d, d)
            mean_w = (weights_a2b + weights_b2a) / 2
            mean_w = mean_w / np.max(mean_w)

            # imshow's extent is (left, right, bottom, top). Putting width W on
            # x and height H on y keeps the aspect ratio of non-square images.
            extent = [0, W, 0, H]
            fig, ax = tfplot.subplots(figsize=(6, 6))
            ax.imshow(img, extent=extent)
            mid = ax.imshow(mean_w, cmap='jet', alpha=0.5, extent=extent)
            fig.colorbar(mid)
            ax.set_title(question2str(q))
            ax.set_xlabel(answer2str(target_a) + answer2str(pred_a, 'Predicted'))
            return fig
Beispiel #12
0
def calc_heat_map(data, cmap='jet'):
    """Render ``data`` as an image using colormap ``cmap``.

    Returns:
        The figure created by ``tfplot.subplots``.
    """
    figure, axis = tfplot.subplots()
    axis.imshow(data, cmap=cmap)
    return figure
Beispiel #13
0
def _create_confusion_image(confusion_data):
    """Show ``confusion_data`` (presumably a confusion matrix) as an image.

    Returns:
        The figure created by ``tfplot.subplots``, resized to 8x8 inches.
    """
    figure, axis = tfplot.subplots()
    axis.imshow(confusion_data)
    figure.set_size_inches(8, 8)
    return figure
Beispiel #14
0
def figure_heatmap(data, cmap=matplotlib.cm.Set1):  # colour can be adjusted here
    """Render integer-valued ``data`` with a discrete colormap.

    A ``BoundaryNorm`` with edges -0.5, 0.5, ..., 7.5 gives each integer
    0..7 its own bin, spread over ``cmap.N`` colours.

    Returns:
        The figure created by ``tfplot.subplots``.
    """
    # NOTE(review): np.arange excludes its stop value, so the last edge is 7.5
    # and a value of 8 falls outside the bins. Confirm whether nine classes
    # (0..8, matching Set1's nine colours) were intended.
    edges = np.arange(-0.5, 8.5, 1)
    fig, ax = tfplot.subplots()
    discrete_norm = matplotlib.colors.BoundaryNorm(edges, cmap.N)
    ax.imshow(data, cmap=cmap, norm=discrete_norm)
    return fig
 def draw_act_vis(h, grid_shape):
     """Show activations ``h`` reshaped to ``grid_shape`` as an image.

     Returns:
         The 4x4-inch figure created by ``tfplot.subplots``, with a colorbar.
     """
     fig, ax = tfplot.subplots(figsize=(4, 4))
     grid = h.reshape(grid_shape)
     image = ax.imshow(grid)
     fig.colorbar(image)
     return fig
Beispiel #16
0
 def figure_attention(activation):
     """Plot ``activation`` as a 'jet' heat map with a colorbar.

     Returns:
         The figure created by ``tfp.subplots``. ``tfp`` is presumably an alias
         for tfplot; TODO confirm against this file's imports.
     """
     figure, axis = tfp.subplots()
     mappable = axis.imshow(activation, cmap='jet')
     figure.colorbar(mappable)
     return figure
def plot_num_incorrect_edges_per_graph(n_nodes, n_incorrect):
    """Scatter the number of incorrect edges against the number of nodes.

    Args:
        n_nodes: x values, one per graph.
        n_incorrect: y values, one per graph.

    Returns:
        The 12x9-inch figure holding the blue scatter plot.
    """
    fig, axis = tfplot.subplots(figsize=(12, 9))
    points = axis.scatter(n_nodes, n_incorrect, c='blue')
    # The scatter collection belongs to ``fig``; return it via the artist.
    return points.figure
Beispiel #18
0
def figure_attention(attention):
    """Show ``attention`` as an image on a 4x3-inch figure.

    Returns:
        The figure created by ``tfplot.subplots``.
    """
    figure, axis = tfplot.subplots(figsize=(4, 3))
    axis.imshow(attention)
    return figure