示例#1
0
def make_posemap(img, pose, video_num, img_size, f_num, gauss_size):
    """Render a pose-map overlay figure and save it as a numbered PNG frame.

    The figure has three panels: the colour-mapped pose map blended onto
    ``img`` (left, large), the raw ``img`` (top right) and the colour-mapped
    pose map with a colour bar (bottom right). The file is written to
    ``../../../demo/images/posemap/posemap_<video_num>/<f_num zero-padded to 5>.png``.

    Args:
        img: Image to overlay. NOTE(review): the blend adds a [0, 1] float map
            to it, so ``img`` is presumably a float RGB array in [0, 1] whose
            shape matches ``img_size`` — confirm against callers.
        pose: Pose data forwarded to ``pose_map``.
        video_num: Video identifier used in the output directory name; it is
            converted with ``str`` so ints work as well as strings.
        img_size: Size forwarded to ``pose_map``.
        f_num: Frame number; zero-padded to 5 digits for the file name.
        gauss_size: Gaussian size forwarded to ``pose_map``.
    """
    posemap = pose_map(pose, img_size, gauss_size)
    # applyColorMap needs 8-bit input and returns BGR; convert to RGB for
    # matplotlib and rescale back to [0, 1] floats.
    posemap = np.uint8(255 * posemap)
    posemap = cv2.applyColorMap(posemap, cv2.COLORMAP_JET)
    posemap = cv2.cvtColor(posemap, cv2.COLOR_BGR2RGB)
    posemap = posemap / 255
    s_img = posemap * 1.0 + img

    fig = plt.figure(figsize=(16, 10))
    gs = GridSpec(4, 5, left=0.13, right=0.9)
    gs.update(wspace=-0.01)
    # Left: blended image spanning the first three grid columns.
    gs_1 = GridSpecFromSubplotSpec(nrows=4, ncols=3, subplot_spec=gs[0:4, 0:3])
    fig.add_subplot(gs_1[:, :])
    delete_line()
    plt.imshow(s_img)
    # Top right: the raw image.
    gs_2 = GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=gs[0:2, 3:5])
    fig.add_subplot(gs_2[:, :])
    delete_line()
    plt.imshow(img)
    # Bottom right: the colour-mapped pose map with a [0, 1] colour bar.
    gs_3 = GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=gs[2:4, 3:5])
    fig.add_subplot(gs_3[:, :])
    delete_line()
    plt.imshow(posemap, cmap='jet')
    plt.clim(0, 1)
    plt.colorbar()

    SAVE_PATH = "../../../demo/images/posemap"
    save_dir = os.path.join(SAVE_PATH, "posemap_" + str(video_num))
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(save_dir, exist_ok=True)

    fig.savefig(os.path.join(save_dir, str(f_num).zfill(5) + ".png"))
    # Close this specific figure so per-frame calls do not accumulate figures.
    plt.close(fig)
示例#2
0
def plot_all_measures(data, meta, kind='mouse', title=''):
    """
    Plot spectra, exponent distributions and inter-spectra correlation in one figure.

    meta should have: 'n_iters', 'n_pc', 'f_range', 'subsetsizes', 'pc_range'
    (only 'subsetsizes' and 'n_pc' are read here).
    n_pc: integer -> number of PCs plotted per subset, capped at the subset size;
          float   -> fraction of the subset size.
    kind: 'mouse' (includes sum & nonsum PSD data), 'mouse_old', or 'sim' (ising or shuffle).
          For 'mouse' the PSD exponent and correlation keys carry a '1' suffix.
    plot layout:
    [ ES  ]  [ subset size vs ES exponent  ]   [ correlation ]
    [ PSD ]  [ subset size vs PSD exponent ]

    Raises:
        TypeError: if meta['n_pc'] is neither an integer nor a float.
    """
    subsetsizes = meta['subsetsizes']
    n_pc = meta['n_pc']
    n = len(subsetsizes)
    dsuffix = '1' if kind == 'mouse' else ''
    ## Plot style
    fig = plt.figure(figsize=(19, 8))
    subset_fractions = np.linspace(0, 1, n)
    cmap = plt.cm.cool(subset_fractions)
    # One colour per subset size; successive ax.plot calls cycle through them.
    plt.rcParams["axes.prop_cycle"] = plt.cycler("color", cmap)

    gs = GridSpec(2, 3, width_ratios=[1, 1, 4], wspace=.8)
    gs1 = GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[:, :2], hspace=0.5, wspace=0.3)
    gs2 = GridSpecFromSubplotSpec(2, 3, subplot_spec=gs[:, 2])

    ## Spectra
    spectra = [('eigs', ['log ES', 'PC dimension', 'Variance']),
               ('pows', ['log PSD', 'Frequency (Hz)', 'Power'])]
    # PSD x values: k/120 for k = 0..60 (exact, unlike a float-step arange).
    psd_xvals = np.arange(61) / 120
    for ip, (spec, labs) in enumerate(spectra):
        ax = fig.add_subplot(gs1[ip, 0])
        for i, n_i in enumerate(subsetsizes):
            if spec == 'eigs':
                if isinstance(n_pc, (int, np.integer)):
                    # NOTE(review): when n_pc >= n_i the original left n_pc_curr
                    # unbound/stale; capping at the subset size assumes no more
                    # than n_i eigenvalues are stored per subset — confirm.
                    n_pc_curr = min(n_pc, n_i)
                elif isinstance(n_pc, float):
                    n_pc_curr = int(n_pc * n_i)
                else:
                    raise TypeError("meta['n_pc'] must be an int or a float, got %r"
                                    % (type(n_pc),))
                xvals = np.arange(1, n_pc_curr + 1) / n_pc_curr
            else:
                xvals = psd_xvals
            # Eigenspectrum
            ax.plot(xvals, data[spec][i]) #KEEP THIS LINE: proportion of PCs
        logaxes()
        pltlabel(*labs)

    ## Exponent distributions
    exponents = [('pca_m', ['ES exponent \n at each subset size', '', 'Exponent']),
                 ('ft_m' + dsuffix, ['PSD exponent \n at each subset size', '', 'Exponent'])]
    for ip, (exp, labs) in enumerate(exponents):
        ax = fig.add_subplot(gs1[ip, 1])
        exp_plot(data, exp, ax=ax)
        pltlabel(*labs)

    ## colorbar for first two cols
    cax = fig.add_axes([0.45, 0.3, 0.01, 0.35])
    # Endpoints of the sampled 'cool' colours as hex strings for solo_colorbar.
    end_hexes = [mpl.colors.rgb2hex(cmap[0]), mpl.colors.rgb2hex(cmap[-1])]
    solo_colorbar(end_hexes, subset_fractions, 'fraction sampled',
                  orientation='vertical', cax=cax)

    ## Interspec Correlation
    ax = fig.add_subplot(gs2[:])
    corr_plot(data['pearson_corr' + dsuffix], 'Pearson',
              p_vals=data['pearson_p' + dsuffix], ax=ax)

    plt.suptitle(title)
    plt.tight_layout()
示例#3
0
def kriging_plot(dataset,
                 xy,
                 v,
                 param,
                 selected_model,
                 save_name=None,
                 front_color="rb",
                 title=None,
                 show=True):
    """Draw polygons coloured by kriging results, with a colour bar beside them.

    ``dataset`` is read as a dict with keys ``"back"`` and ``"front"`` (each
    mapping an id to a polygon given as (x, y) vertex pairs), plus ``"data"``
    and ``"double"``, which are passed to ``rb`` to obtain the colours.
    Background polygons are drawn as outlines. Foreground polygon ``i`` is
    filled with ``color_dict[i]``.

    Args:
        dataset: Input dict as described above.
        xy, v, param, selected_model: Not used by this function; kept for
            interface compatibility.
        save_name: If given, the figure is saved to this path.
        front_color: Foreground colour scheme; only ``"rb"`` is supported.
        title: Optional title for the main axes.
        show: Whether to call ``plt.show()``.

    Raises:
        ValueError: if ``front_color`` is not a supported scheme.
    """
    # Validate first so a bad argument neither leaves a half-built figure
    # behind nor terminates the whole interpreter.
    if front_color != "rb":
        raise ValueError("unsupported front_color %r; only 'rb' is supported"
                         % (front_color,))

    fig = plt.figure(figsize=(9, 8))
    gs_master = GridSpec(nrows=1, ncols=2, width_ratios=[8, 1])
    gs_1 = GridSpecFromSubplotSpec(nrows=1,
                                   ncols=1,
                                   subplot_spec=gs_master[0, 0])
    ax = fig.add_subplot(gs_1[:, :])
    gs_2 = GridSpecFromSubplotSpec(nrows=1,
                                   ncols=1,
                                   subplot_spec=gs_master[0, 1])
    ax_c = fig.add_subplot(gs_2[:, :])

    # Background polygons: outlines only.
    for i in dataset["back"].keys():
        ax.add_patch(plt.Polygon(dataset["back"][i], fill=False))

    color_dict, color_bar, min_y, max_y, dy = rb(dataset["data"],
                                                 dataset["double"])

    # Foreground polygons: filled with the colour rb assigned to each id.
    for i in dataset["front"].keys():
        ax.add_patch(plt.Polygon(dataset["front"][i], fc=color_dict[i]))

    # Fit the main axes to the extent of the background polygons.
    x_back = [x for i in dataset["back"].keys() for x, _ in dataset["back"][i]]
    y_back = [y for i in dataset["back"].keys() for _, y in dataset["back"][i]]
    ax.set_xlim([min(x_back), max(x_back)])
    ax.set_ylim([min(y_back), max(y_back)])
    if title:
        ax.set_title(title)

    # Colour bar: one strip of height dy per colour, starting at 256 evenly
    # spaced values between min_y and max_y.
    for i, c in zip(np.linspace(min_y, max_y, 256), color_bar):
        poly = plt.Polygon(((0, i), (1, i), (1, i + dy), (0, i + dy)), fc=c)
        ax_c.add_patch(poly)
    ax_c.set_xlim([0, 1])
    ax_c.tick_params(labelbottom=False)
    ax_c.set_ylim([min_y, max_y])

    if save_name:
        plt.savefig(save_name)
    if show:
        plt.show()
示例#4
0
def heatmap(dataset, save_name=None, front_color="rb", back_color="k", title=None, show=True, school=False):
    """Draw filled polygons coloured from ``dataset`` with a colour bar beside them.

    ``dataset`` is read as a dict with keys ``"back"`` and ``"front"`` (each
    mapping an id to a polygon given as (x, y) vertex pairs), plus ``"data"``
    and ``"double"``, which are passed to ``rb`` to obtain the colours.

    Args:
        dataset: Input dict as described above.
        save_name: If given, the figure is saved to this path.
        front_color: Foreground colour scheme; only ``"rb"`` is supported.
        back_color: Fill colour of the background polygons.
        title: Optional title for the main axes.
        show: Whether to call ``plt.show()``.
        school: If true, overlay the points returned by
            ``load_toyama_second_pos()`` as circle markers.

    Raises:
        ValueError: if ``front_color`` is not a supported scheme.
    """
    # Validate first so a bad argument neither leaves a half-built figure
    # behind nor terminates the whole interpreter.
    if front_color != "rb":
        raise ValueError("unsupported front_color %r; only 'rb' is supported"
                         % (front_color,))

    fig = plt.figure(figsize=(9, 8))
    gs_master = GridSpec(nrows=1, ncols=2, width_ratios=[8, 1])
    gs_1 = GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=gs_master[0, 0])
    ax = fig.add_subplot(gs_1[:, :])
    gs_2 = GridSpecFromSubplotSpec(nrows=1, ncols=1, subplot_spec=gs_master[0, 1])
    ax_c = fig.add_subplot(gs_2[:, :])

    for i in dataset["back"].keys():
        ax.add_patch(plt.Polygon(dataset["back"][i], fc=back_color))

    color_dict, color_bar, min_y, max_y, dy = rb(dataset["data"], dataset["double"])

    for i in dataset["front"].keys():
        ax.add_patch(plt.Polygon(dataset["front"][i], fc=color_dict[i]))

    # Fit the main axes to the extent of the background polygons.
    x_back = [x for i in dataset["back"].keys() for x, _ in dataset["back"][i]]
    y_back = [y for i in dataset["back"].keys() for _, y in dataset["back"][i]]
    ax.set_xlim([min(x_back), max(x_back)])
    ax.set_ylim([min(y_back), max(y_back)])
    if title:
        ax.set_title(title)

    # Colour bar: one strip of height dy per colour, starting at 256 evenly
    # spaced values between min_y and max_y.
    for i, c in zip(np.linspace(min_y, max_y, 256), color_bar):
        poly = plt.Polygon(((0, i), (1, i), (1, i + dy), (0, i + dy)), fc=c)
        ax_c.add_patch(poly)
    ax_c.set_xlim([0, 1])
    ax_c.tick_params(labelbottom=False)
    ax_c.set_ylim([min_y, max_y])

    if school:
        # Each value's first two items are used as the (x, y) marker position.
        positions = list(load_toyama_second_pos().values())
        x = [pos[0] for pos in positions]
        y = [pos[1] for pos in positions]
        ax.plot(x, y, marker="o", linestyle="None")

    if save_name:
        plt.savefig(save_name)
    if show:
        plt.show()
示例#5
0
def density_projection(dens_x):
    """Project ``dens_x`` onto 8 principal components and save two diagnostic figures.

    Writes to the current working directory:
      * ``principal_components.png``: a scatter of the first two projected
        coordinates, plus a histogram of each of the 8 projected coordinates.
        Every histogram reuses the 20 bins computed from the first
        coordinate, so all panels share one x range (values of the other
        coordinates outside that range are not counted).
      * ``transfer_matrix.png``: the first four columns of the PCA transfer
        matrix plotted against a uniform grid on [0, 1].

    Args:
        dens_x: 2-D array. NOTE(review): rows appear to be samples and columns
            grid points (``dens_x.shape[1]`` sets the plotting grid) — confirm.

    Returns:
        0, unconditionally.
    """
    from ofdft_ml.statslib.pca import PrincipalComponentAnalysis
    from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

    pca = PrincipalComponentAnalysis(8)
    dens_t = pca.fit_transform(dens_x)
    tr_mat = pca.tr_mat_
    # plots
    ## density data distribution
    # Figure numbers 0 and 1 are fixed, so clear them: without clear=True a
    # second call would draw over the axes left from the previous call.
    dens_fig = plt.figure(0, figsize=(10, 5), clear=True)
    dens_gs = GridSpec(1, 2)
    ### scatter plot
    gs_0 = GridSpecFromSubplotSpec(1, 1, dens_gs[0])
    ax_0 = dens_fig.add_subplot(gs_0[0])
    ax_0.scatter(dens_t[:, 0], dens_t[:, 1])
    ax_0.set_xlabel('principal #1')
    ax_0.set_ylabel('principal #2')
    ### histogram plot
    _, bins_p0 = np.histogram(dens_t[:, 0], bins=20, density=False)
    x_min, x_max = np.amin(dens_t[:, 0]), np.amax(dens_t[:, 0])
    x_range_labels = ['%.1f' % x_min, '%.1f' % x_max]
    gs_1 = GridSpecFromSubplotSpec(4, 2, dens_gs[1], wspace=0.1)
    dens_axes = [dens_fig.add_subplot(gs_1[i, j]) for i in range(4) for j in range(2)]
    for i, ax in enumerate(dens_axes):
        counts, _, _ = ax.hist(dens_t[:, i], bins=bins_p0)
        # Only the two range-end ticks are kept; their labels are shown on the
        # bottom row alone (set below).
        ax.xaxis.set_major_locator(FixedLocator([x_min, x_max]))
        ax.xaxis.set_major_formatter(NullFormatter())
        ax.yaxis.set_major_formatter(NullFormatter())
        ax.set_xlim([x_min - 0.5, x_max + 0.5])
        ax.set_ylim([0, np.amax(counts)])
    for ax in dens_axes[6:]:
        ax.xaxis.set_major_formatter(FixedFormatter(x_range_labels))
    dens_fig.text(0.7, 0.04, 'range')
    dens_fig.text(0.51, 0.5, 'fraction', va='center', rotation='vertical')
    dens_fig.savefig('principal_components.png')
    ## transfer matrix plot
    X = np.linspace(0, 1, dens_x.shape[1])
    mat_fig = plt.figure(1, figsize=(5, 5), clear=True)
    mat_gs = GridSpec(2, 2, wspace=0.2)
    mat_axes = [mat_fig.add_subplot(mat_gs[i, j]) for i in range(2) for j in range(2)]
    for i, ax in enumerate(mat_axes):
        vec = tr_mat[:, i]
        ax.plot(X, vec, label='#%s' % (i))
        vec_min, vec_max = np.amin(vec), np.amax(vec)
        ax.legend()
        # Ticks only at the ends of each axis; y labels show the column's range.
        ax.xaxis.set_major_locator(FixedLocator([0, 1]))
        ax.yaxis.set_major_locator(FixedLocator([vec_min, vec_max]))
        ax.xaxis.set_major_formatter(NullFormatter())
        ax.yaxis.set_major_formatter(FixedFormatter(['%.2f' % (vec_min), '%.2f' % (vec_max)]))
    for ax in mat_axes[2:]:
        ax.xaxis.set_major_formatter(FixedFormatter(['0', '1']))
    mat_fig.savefig('transfer_matrix.png')
    return 0
def annotated_heatmap(ax, data, colors_x, colors_y,
                      annotate=True, cmap='binary', show_ylabel=True,
                      annotate_size=LEGEND_TEXT_SIZE):
    """Replace ``ax`` with a column-normalised, optionally annotated heat map.

    ``ax`` is hidden and its area is split into a large heat-map panel, a thin
    colour strip to its left (one unit bar per entry of ``colors_y``) and a
    thin strip below it (one unit bar per entry of ``colors_x``). Tick labels
    are handed from the heat map to the strips via ``transfer_ticks``.

    Args:
        ax: Host axes whose subplot area is reused.
        data: Table of counts. NOTE(review): ``fillna`` and ``fmt='d'`` imply an
            integer pandas DataFrame — confirm.
        colors_x: Colours for the strip under the heat map.
        colors_y: Colours for the strip left of the heat map.
        annotate: If true, each cell is annotated with its raw count.
        cmap: Colormap for the heat map.
        show_ylabel: Whether to label the left strip 'Prediction'.
        annotate_size: Font size of the cell annotations.

    Returns:
        The original (now hidden) ``ax``.
    """
    ax.axis('off')
    fig = ax.figure
    layout = GridSpecFromSubplotSpec(
        2, 2,
        subplot_spec=ax,
        wspace=0, hspace=0,
        height_ratios=[49, 1],
        width_ratios=[1, 49],
    )
    heat_ax = fig.add_subplot(layout[0, 1])
    left_strip = fig.add_subplot(layout[0, 0])
    bottom_strip = fig.add_subplot(layout[1, 1])

    # Each column is scaled to sum to 1; all-zero columns give NaN, shown as 0.
    fractions = (data / data.sum(axis=0)).fillna(0)
    cell_labels = data if annotate else False
    sns.heatmap(fractions, annot=cell_labels, fmt='d',
                ax=heat_ax, square=False, cmap=cmap, cbar=False,
                xticklabels=1, yticklabels=1,
                annot_kws=dict(fontsize=annotate_size))

    for position, colour in enumerate(colors_x):
        bottom_strip.barh(0, 1, left=position, color=colour, height=1)
    for position, colour in enumerate(colors_y):
        left_strip.bar(0, 1, bottom=position, color=colour, width=1)

    transfer_ticks(heat_ax, bottom_strip, which='x', rotation=90)
    transfer_ticks(heat_ax, left_strip, which='y')

    for strip in (bottom_strip, left_strip):
        fix_spines(strip, [], keep_ticks=True)

    bottom_strip.set_xlabel('Reference')
    if show_ylabel:
        left_strip.set_ylabel('Prediction')

    return ax
    def set_layout(self):
        """Create ``self.fig`` with one tall panel on the left and two stacked on the right.

        A 2 x 3 grid is nested inside a single outer cell. ``self.subfig_1``
        holds the axes spanning both rows of column 0. ``self.subfig_2`` holds
        the top and bottom axes spanning columns 1-2.
        """
        self.fig = plt.figure()
        fig = self.fig

        # Figure margins and inter-subplot spacing, as fractions of the figure.
        plt.subplots_adjust(left=0.12, bottom=0.15, right=0.88, top=0.9,
                            wspace=0.6, hspace=0.6)

        outer = GridSpec(nrows=1, ncols=1)
        inner = GridSpecFromSubplotSpec(2, 3, subplot_spec=outer[0, 0],
                                        wspace=0.1, hspace=0.1)

        left_panel = fig.add_subplot(inner[:, 0])
        right_top = fig.add_subplot(inner[0, 1:])
        right_bottom = fig.add_subplot(inner[1, 1:])

        self.subfig_1 = [left_panel]
        self.subfig_2 = [right_top, right_bottom]
示例#8
0
 def plot_slides(self,
                 coronal: int,
                 sagittal: int,
                 ss: SubplotSpec = None,
                 fig: Any = None,
                 return_figure: Any = False) -> Any:
     """Draw two orthogonal slices side by side, each marking the other's index.

     When ``self.is_label`` and ``self.reference`` are both truthy, the call is
     delegated to ``self.colored.plot_slides`` (with ``contour=False``) and its
     result is returned. Otherwise ``self[coronal, :, :]`` is drawn on the
     left with a yellow vertical line at x = ``sagittal``, and
     ``self[:, :, sagittal].T`` on the right with a line at x = ``coronal``.

     Args:
         coronal: Index along the first axis.
         sagittal: Index along the last axis.
         ss: Subplot spec to split into the two panels; defaults to the cell
             of a fresh 1 x 1 grid.
         fig: Target figure; defaults to the current figure.
         return_figure: If true, the figure is cleared first and returned.

     Returns:
         The figure when ``return_figure`` is true, otherwise None (or the
         delegated call's result).
     """
     if self.is_label and self.reference:
         return self.colored.plot_slides(coronal=coronal,
                                         sagittal=sagittal,
                                         contour=False,
                                         ss=ss,
                                         fig=fig,
                                         return_figure=return_figure)

     target = plt.gcf() if fig is None else fig
     cell = plt.GridSpec(1, 1)[0] if ss is None else ss
     if return_figure:
         target.clear()

     halves = GridSpecFromSubplotSpec(1, 2, subplot_spec=cell)

     left_ax = target.add_subplot(halves[0])
     left_ax.imshow(self[coronal, :, :], cmap="gray")
     left_ax.axvline(sagittal, c='y', lw=1)

     right_ax = target.add_subplot(halves[1])
     right_ax.imshow(self[:, :, sagittal].T, cmap="gray")
     right_ax.axvline(coronal, c='y', lw=1)

     return target if return_figure else None
示例#9
0
 def plot_image_grid(self, cell, metric_name, metric):
     """Draw the images in ``metric['data']`` as a grid inside subplot spec ``cell``.

     Pixel values are clipped to [0, 1]. Each image gets its own axes with
     axis decorations turned off, and the first one is titled ``metric_name``.
     4-D data is drawn as colour images, anything else in grayscale.

     The grid uses ``ceil(sqrt(n))`` columns and as many rows as needed. This
     is the same square grid as before for perfect squares, and it now also
     fits every image when ``n`` is not a perfect square. Nothing is drawn
     when there are no images.
     """
     imgs = np.clip(metric['data'], 0., 1.)
     n_images = len(imgs)
     if n_images == 0:
         # A 0 x 0 nested grid is invalid; there is nothing to draw anyway.
         return
     n_cols = int(np.ceil(np.sqrt(n_images)))
     n_rows = int(np.ceil(n_images / n_cols))
     inner_grid = GridSpecFromSubplotSpec(n_rows,
                                          n_cols,
                                          cell,
                                          wspace=0.1,
                                          hspace=0.1)
     for i in range(n_images):
         inner_ax = plt.subplot(inner_grid[i])
         if i == 0:
             inner_ax.set_title(metric_name)
         if imgs.ndim == 4:
             inner_ax.imshow(imgs[i, :, :, :],
                             interpolation='none',
                             vmin=0.0,
                             vmax=1.0)
         else:
             inner_ax.imshow(imgs[i, :, :],
                             cmap='gray',
                             interpolation='none',
                             vmin=0.0,
                             vmax=1.0)
         inner_ax.axis('off')
示例#10
0
def plot_lda_topics(documents,
                    nrows,
                    ncols,
                    with_colorbar=True,
                    topic_mixtures=None,
                    cmap='Viridis',
                    dpi=160):
    """Lay out one ``_document_with_topic`` panel per document on an nrows x ncols grid.

    Cells are filled row by row, and each cell is split into a 6 x 5 nested
    grid for the helper. All panels share the colour range
    ``[0, documents.max()]``. If the grid has more cells than there are
    documents, the extra cells are left empty.

    Args:
        documents: Indexable documents supporting ``len()`` and ``.max()``.
            NOTE(review): presumably a numpy array — confirm.
        nrows, ncols: Grid shape.
        with_colorbar, cmap, dpi: Not used by this function.
            NOTE(review): 'Viridis' is not a valid matplotlib colormap name
            ('viridis' is), should it ever be wired through.
        topic_mixtures: Optional per-document topic mixtures. When omitted,
            ``topic_mixture`` is not passed, so the helper's own default applies.

    Returns:
        The created figure.
    """
    fig = plt.figure()
    gs = GridSpec(nrows, ncols)

    vmin, vmax = (0, documents.max())
    # Do not index past the last document when the grid is larger.
    n_panels = min(nrows * ncols, len(documents))

    for index in range(n_panels):
        i, j = divmod(index, ncols)
        gsi = GridSpecFromSubplotSpec(6, 5, subplot_spec=gs[i, j])
        mixture_kwargs = ({} if topic_mixtures is None
                          else {'topic_mixture': topic_mixtures[index]})
        _document_with_topic(fig,
                             gsi,
                             index,
                             documents[index],
                             vmin=vmin,
                             vmax=vmax,
                             **mixture_kwargs)

    return fig
示例#11
0
    def plot(self, **pca_plotting_kwargs):
        """Plot feature-importance density estimates next to the PCA plot.

        Fits the model first if needed. The plotting area is divided into an
        18 x 12 grid. Rows 5-9 of the first two columns hold
        ``self.plot_scores`` and the remaining ten columns hold the plot from
        ``self.do_pca``.

        Keyword Args:
            ax: Optional axes whose area is reused. It is hidden and replaced
                by the two panels. When omitted, a new 18 x 8 inch figure is
                created.
            show_vectors: Forwarded to ``do_pca``; defaults to True.
            Any other keyword arguments are forwarded to ``do_pca``.

        Returns:
            Whatever ``self.do_pca`` returns.
        """
        if not self.has_been_fit:
            self.fit()

        gs_x = 18
        gs_y = 12

        ax = pca_plotting_kwargs.get('ax')
        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(18, 8))
            gs = GridSpec(gs_x, gs_y)
        else:
            fig = ax.figure
            gs = GridSpecFromSubplotSpec(gs_x, gs_y, ax.get_subplotspec())
        # The placeholder axes only reserves space; hide it so its frame and
        # ticks do not show through behind the panels added below.
        ax.axis('off')

        # Build the panels on ax's own figure, not whatever figure is current.
        ax_scores = fig.add_subplot(gs[5:10, :2])
        ax_scores.set_xlabel("Feature Importance")
        ax_scores.set_ylabel("Density Estimate")

        pca_plotting_kwargs.setdefault('show_vectors', True)

        ax_pca = fig.add_subplot(gs[:, 2:])
        pca_plotting_kwargs['ax'] = ax_pca

        self.plot_scores(ax=ax_scores)
        pcaviz = self.do_pca(**pca_plotting_kwargs)
        fig.tight_layout()

        return pcaviz
示例#12
0
    def _draw_MR(self, p, dp=10, subplot_spec=None):
        """Draw the MR curve with dp % variation in the parameter p.

        Left panel: mass [solar masses] vs radius [km]. Right panel: mass vs
        Lambda. Each curve gets error ellipses built from the derivatives with
        respect to ``p``, multiplied by the absolute step
        ``dp / 100 * self.params[ip]``.

        Args:
            p: Parameter name; only its first whitespace-separated token is
                looked up in ``self.param_names``.
            dp: Variation in percent of the parameter's current value.
            subplot_spec: Optional subplot spec to split into the two panels.
                When omitted, a new 10 x 5 inch figure is created.
        """
        u = constants
        p = p.split()[0]
        parameter_variation = (p, dp/100.0)

        if subplot_spec:
            gs = GridSpecFromSubplotSpec(1, 2, subplot_spec=subplot_spec)
        else:
            gs = GridSpec(1, 2)
            plt.figure(figsize=(10, 5))

        plt.subplot(gs[0])
        derivatives = self.derivatives._asdict()
        ip = self.param_names.index(parameter_variation[0])
        # Absolute parameter step: dp percent of the current parameter value.
        dp_ = parameter_variation[1]*self.params[ip]
        dR = derivatives['R'][:, ip] * dp_
        dM = derivatives['M'][:, ip] * dp_
        plt.plot(self.R/u.km, self.M/u.M0)
        error_ellipse(self.R/u.km, self.M/u.M0, xerr=dR/u.km, yerr=dM/u.M0,
                      alpha=0.2)
        plt.xlabel('R [km]')
        plt.ylabel('M [solar_mass]')

        plt.subplot(gs[1])
        # Scale by the absolute step dp_, like dR and dM above. Multiplying by
        # the raw percentage dp inflated the Lambda error by 100/params[ip].
        # NOTE(review): assumes self.dLambda holds d(Lambda)/dp in the same
        # convention as self.derivatives — confirm.
        dLambda = self.dLambda[:, ip] * dp_
        plt.plot(self.Lambda, self.M/u.M0)
        error_ellipse(self.Lambda, self.M/u.M0, xerr=dLambda, yerr=dM/u.M0,
                      alpha=0.2)
        plt.xlabel('Lambda')
        plt.ylabel('M [solar_mass]')
def barplot(ax, data, colors, colors_bar, labels, xlabel):
    """Replace ``ax`` with a horizontal bar chart plus a thin colour strip on its left.

    Bar ``k`` (value ``data[k]``, colour ``colors[k]``) is drawn at y = k. The
    strip shows one unit-high block per entry of ``colors_bar``, centred on
    the bars, and carries the y tick labels ``labels``. Both y axes are
    inverted, so the first bar appears at the top.

    Returns:
        The original (now hidden) ``ax``.
    """
    ax.axis('off')
    fig = ax.figure
    grid = GridSpecFromSubplotSpec(1, 2, subplot_spec=ax, wspace=0, hspace=0,
                                   width_ratios=[1, 49])
    bars_ax = fig.add_subplot(grid[0, 1])
    strip_ax = fig.add_subplot(grid[0, 0])

    n_labels = len(labels)
    positions = np.arange(n_labels)
    # zorder 3 keeps the bars above the grid lines (zorder 0) drawn below.
    bars_ax.barh(positions, data, color=colors, height=0.9, zorder=3)

    transfer_ticks(bars_ax, strip_ax, which='y')

    for row, colour in enumerate(colors_bar):
        # Start half a unit lower so each block is centred on bar `row`.
        strip_ax.bar(0, 1, bottom=row - 0.5, color=colour, width=1)

    strip_ax.set_yticks(positions)
    strip_ax.set_yticklabels(labels)

    bars_ax.set_xlabel(xlabel)
    bars_ax.grid(True, which='major', zorder=0)

    bars_ax.set_ylim(-0.5, n_labels - 0.5)
    strip_ax.set_ylim(*bars_ax.get_ylim())

    fix_spines(bars_ax, ['left'], keep_ticks=True)
    fix_spines(strip_ax, [], keep_ticks=True)

    for panel in (bars_ax, strip_ax):
        panel.invert_yaxis()

    return ax
def marginal_heatmap(ax, data_m, data_x, data_y,
                     colors_x, colors_y, colors_x_bar, colors_y_bar,
                     marg_xlabel='Samples per class', marg_ylabel='Pos Pred Value',
                     cmap='binary', annotate=True, annotate_size=LEGEND_TEXT_SIZE):
    """Replace ``ax`` with a column-normalised heat map framed by marginal bar charts.

    Layout inside ``ax``'s area (a 4 x 4 nested grid):
      * centre: heat map of ``data_m`` with each column scaled to sum to 1,
        optionally annotated with the raw values;
      * top: vertical bars of ``data_x``; right: horizontal bars of ``data_y``;
      * thin colour strips from ``colors_x_bar`` (below and above the heat
        map) and ``colors_y_bar`` (left and right of it);
      * cell (i, j) gets a translucent overlay when ``colors_x_bar[i]`` equals
        ``colors_y_bar[j]``.
    Tick labels are handed to the bottom and left strips via ``transfer_ticks``.

    Returns:
        The original (now hidden) ``ax``.
    """
    # This is the function for figure 1
    ax.axis('off')
    fig = ax.figure
    grid = GridSpecFromSubplotSpec(4, 4, subplot_spec=ax, wspace=0.01, hspace=0.01,
                                   height_ratios=[25, 1, 49, 1],
                                   width_ratios=[1, 49, 1, 25])
    # Creation order matters: the heat map must exist before axes share with it.
    heat = fig.add_subplot(grid[2, 1])
    strip_bottom = fig.add_subplot(grid[3, 1])
    strip_top = fig.add_subplot(grid[1, 1], sharex=heat)
    marg_top = fig.add_subplot(grid[0, 1], sharex=heat)
    strip_left = fig.add_subplot(grid[2, 0])
    strip_right = fig.add_subplot(grid[2, 2], sharey=heat)
    marg_right = fig.add_subplot(grid[2, 3], sharey=heat)

    # All-zero columns give NaN after the division; show them as 0.
    fractions = (data_m / data_m.sum(axis=0)).fillna(0)
    cell_labels = data_m if annotate else False
    sns.heatmap(fractions, annot=cell_labels, fmt='d',
                ax=heat, square=False, cmap=cmap, cbar=False,
                xticklabels=1, yticklabels=1,
                annot_kws=dict(fontsize=annotate_size))

    # The +0.5 offset puts each bar at the centre of its heat-map cell.
    marg_top.bar(np.arange(data_x.shape[0]) + 0.5, data_x,
                 width=0.9, color=colors_x)
    marg_right.barh(np.arange(data_y.shape[0]) + 0.5, data_y,
                    height=0.9, color=colors_y)

    for k, colour in enumerate(colors_x_bar):
        for strip in (strip_bottom, strip_top):
            strip.barh(0, 1, left=k, color=colour, height=1)
    for k, colour in enumerate(colors_y_bar):
        for strip in (strip_left, strip_right):
            strip.bar(0, 1, bottom=k, color=colour, width=1)

    pairs = product(enumerate(colors_x_bar), enumerate(colors_y_bar))
    for (i, ci), (j, cj) in pairs:
        if ci != cj:
            continue
        heat.add_artist(mpatches.Rectangle([i, j], 1, 1, color=ci,
                                           alpha=0.2, lw=0))

    transfer_ticks(heat, strip_bottom, which='x', rotation=90)
    transfer_ticks(heat, strip_left, which='y')

    fix_spines(strip_bottom, [], keep_ticks=True)
    fix_spines(strip_left, [], keep_ticks=True)
    fix_spines(strip_top, [], keep_ticks=False)
    fix_spines(strip_right, [], keep_ticks=False)
    fix_spines(marg_top, ['left'], keep_ticks=False)
    fix_spines(marg_right, ['bottom'], keep_ticks=False)

    strip_bottom.set_xlabel('Reference')
    strip_left.set_ylabel('Prediction')

    marg_top.set_ylabel(marg_xlabel)
    marg_right.set_xlabel(marg_ylabel)

    return ax
示例#15
0
def make_hand_image(img, l_img, r_img, pose, video_num, f_num, cutout_size):
    """Render a frame with two hand boxes plus the two hand crops, and save it as PNG.

    The main panel shows ``img`` with two purple rectangles, one built from
    ``pose[0:2]`` and one from ``pose[2:4]``. ``l_img`` and ``r_img`` are shown
    in two smaller panels. The figure is written to
    ``../../../demo/images/hand/hand_<video_num>/<f_num zero-padded to 5>.png``.

    Args:
        img: Frame shown in the main panel.
        l_img, r_img: Crops shown in the two side panels.
        pose: Sequence whose first four entries are two (x, y) points.
            NOTE(review): the 224/1080 scale and 87-pixel shift look like a
            1080-pixel frame resized and cropped to 224 — confirm.
        video_num: Video identifier used in the directory name (``str``-converted).
        f_num: Frame number; zero-padded to 5 digits for the file name.
        cutout_size: Half the box side before the 224/540 rescale.
    """
    fig = plt.figure(figsize=(16, 10))
    gs = GridSpec(4, 5, left=0.06, right=1.2)
    gs.update(wspace=0.2)
    gs_1 = GridSpecFromSubplotSpec(nrows=4, ncols=3, subplot_spec=gs[0:4, 0:3])
    # Draw on the axes just added; the previous plt.axes() call created a
    # second, full-window axes and left this one empty behind it.
    ax = fig.add_subplot(gs_1[:, :])

    half_side = int(cutout_size * 224 / 540)
    side = int(cutout_size * 2 * 224 / 540)

    def _hand_box(x, y):
        # Rescale the point by 224/1080, shift it 87 px left, and use it
        # (minus half the rescaled cutout) as the box's lower-left corner.
        return patches.Rectangle(
            xy=(int(x * 224 / 1080) - 87 - half_side,
                int(y * 224 / 1080) - half_side),
            width=side,
            height=side,
            ec='#AD13E5',
            linewidth=4.0,
            fill=False)

    ax.add_patch(_hand_box(pose[0], pose[1]))
    ax.add_patch(_hand_box(pose[2], pose[3]))
    delete_line()
    plt.imshow(img)
    # NOTE(review): this panel's grid column 0 lies inside the main panel's
    # span (columns 0-2), so the two overlap — confirm that is intended.
    gs_2 = GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=gs[1:3, 0:1])
    fig.add_subplot(gs_2[:, :])
    delete_line()
    plt.imshow(l_img)
    gs_3 = GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=gs[1:3, 3:4])
    fig.add_subplot(gs_3[:, :])
    delete_line()
    plt.imshow(r_img)

    # Make the directory if it doesn't exist (exist_ok avoids a check/create race).
    SAVE_PATH = "../../../demo/images/hand"
    save_dir = os.path.join(SAVE_PATH, "hand_" + str(video_num))
    os.makedirs(save_dir, exist_ok=True)

    fig.savefig(os.path.join(save_dir, str(f_num).zfill(5) + ".png"))
    # Close this specific figure so per-frame calls do not accumulate figures.
    plt.close(fig)
示例#16
0
 def plot(self, subspec):
     """Place the children side by side inside ``subspec`` and let each draw itself.

     ``subspec`` is split into one row of ``self.nchildren`` columns whose
     widths follow ``self.ratios``. ``self.spacing`` supplies extra keyword
     arguments for the nested grid. The child found with
     ``self.children.get(col)`` plots into cell ``[0, col]``; columns without
     a child stay empty.
     """
     columns = GridSpecFromSubplotSpec(1,
                                       self.nchildren,
                                       subplot_spec=subspec,
                                       width_ratios=self.ratios,
                                       **self.spacing)
     for col in range(self.nchildren):
         child = self.children.get(col)
         if child is None:
             continue
         child.plot(columns[0, col])
示例#17
0
 def plot(self, subspec):
     """Stack the children vertically inside ``subspec`` and let each draw itself.

     ``subspec`` is split into ``self.nchildren`` rows of one column whose
     heights follow ``self.ratios``. ``self.spacing`` supplies extra keyword
     arguments for the nested grid. The child found with
     ``self.children.get(row)`` plots into cell ``[row, 0]``; rows without a
     child stay empty.
     """
     rows = GridSpecFromSubplotSpec(self.nchildren,
                                    1,
                                    subplot_spec=subspec,
                                    height_ratios=self.ratios,
                                    **self.spacing)
     for row in range(self.nchildren):
         child = self.children.get(row)
         if child is None:
             continue
         child.plot(rows[row, 0])
示例#18
0
 def ini_inner_grids(self):
     """Build a 2 x 1 nested grid inside each host's cell of the outer grid.

     Returns:
         dict mapping each host in ``self.hosts`` to a
         ``GridSpecFromSubplotSpec`` over ``self.outer_grid[self.d_outer_coor[host]]``,
         with spacing ``self.inner_wspace`` / ``self.inner_hspace``.
     """
     return {
         host: GridSpecFromSubplotSpec(
             2,
             1,
             subplot_spec=self.outer_grid[self.d_outer_coor[host]],
             wspace=self.inner_wspace,
             hspace=self.inner_hspace)
         for host in self.hosts
     }
示例#19
0
    def __init__(self, data, X=None, Y=None):
        """Build an interactive viewer for a 2-D array with profile side panels.

        The layout is a 4 x 4 grid:
          * the main axes (rows/cols 0-2), with limits set to the X/Y ranges;
          * a horizontal profile panel below it, sharing its x axis;
          * a vertical profile panel to its right, sharing its y axis;
          * in the bottom-right cell, vmin/vmax sliders and Colorbar/Reset
            buttons.
        Clicks on the canvas are routed to ``self._plot_clicked``.

        Args:
            data: 2-D array; columns correspond to ``X`` and rows to ``Y``.
            X: Column coordinates; defaults to ``0 .. data.shape[1] - 1``.
            Y: Row coordinates; defaults to ``0 .. data.shape[0] - 1``.
        """
        self.data = data
        self.X = np.arange(self.data.shape[1]) if X is None else X
        self.Y = np.arange(self.data.shape[0]) if Y is None else Y

        vmin = np.min(self.data)
        vmax = np.max(self.data)

        # Create and display the figure object.
        self.fig = plt.figure(figsize=(8,8), frameon=True, tight_layout=True)

        # Create grid for layout
        grid = GridSpec(4,4)

        self.ax_main = self.fig.add_subplot(grid[0:3,0:3])
        self.ax_main.autoscale(enable=False)
        self.ax_main.set_xlim(np.min(self.X), np.max(self.X))
        self.ax_main.set_ylim(np.min(self.Y), np.max(self.Y))
        # Use 'auto' to adjust the aspect ratio to fill the figure window, 'equal' to fix it.
        # adjustable='box-forced' was removed from matplotlib; 'box' replaces it.
        self.ax_main.set_aspect('auto', adjustable='box')

        # Horizontal profile panel; its y range is fixed to the data range.
        self.ax_h = self.fig.add_subplot(grid[3,0:3], sharex=self.ax_main)
        # set_axis_bgcolor was removed from matplotlib; set_facecolor replaces it.
        self.ax_h.set_facecolor('0.8')
        self.ax_h.autoscale(False)
        self.ax_h.set_ylim(vmin, vmax)

        # Vertical profile panel; x limits are (vmax, vmin), so larger values
        # sit closer to the main image on its left.
        self.ax_v = self.fig.add_subplot(grid[0:3,3], sharey=self.ax_main)
        self.ax_v.set_facecolor('0.8')
        self.ax_v.autoscale(False)
        self.ax_v.set_xlim(vmax, vmin)

        self.prev_pt = None
        self.ax_cb = None

        self.cursor = MultiCursor(self.fig.canvas, (self.ax_main, self.ax_h, self.ax_v),
                                  horizOn=True, vertOn=True, color='white', ls='--', lw=1)
        self.fig.canvas.mpl_connect('button_press_event', self._plot_clicked)

        # Setup control buttons
        btn_grid = GridSpecFromSubplotSpec(4, 1, subplot_spec=grid[3,3])
        self.btn_colorbar = Button(self.fig.add_subplot(btn_grid[2,0]), 'Colorbar')
        self.btn_colorbar.on_clicked(self._plot_colorbar)
        self.btn_reset = Button(self.fig.add_subplot(btn_grid[3,0]), 'Reset')
        self.btn_reset.on_clicked(self._plot_reset)

        # Setup color range sliders; each slider limits the other so vmin <= vmax.
        # NOTE(review): newer matplotlib deprecates Slider's slidermin/slidermax —
        # verify against the pinned version.
        self.slider_vmin = Slider(self.fig.add_subplot(btn_grid[0,0]), "vmin", vmin, vmax, valinit=vmin)
        self.slider_vmin.on_changed(self._plot_rangechanged)
        self.slider_vmax = Slider(self.fig.add_subplot(btn_grid[1,0]), "vmax", vmin, vmax, valinit=vmax, slidermin=self.slider_vmin)
        self.slider_vmax.on_changed(self._plot_rangechanged)
        self.slider_vmin.slidermax = self.slider_vmax

        self.fig.canvas.draw()
示例#20
0
def make_attention_map(img, heatmap, video_num, f_num, save_name):
    """Render an attention-map overlay figure and save it as a numbered PNG frame.

    The figure has three panels: the colour-mapped attention map blended onto
    ``img`` (left, large), the raw ``img`` (top right) and the colour-mapped
    attention map with a colour bar (bottom right). Output path:
    ``../../../demo/images/attention/<save_name>_<video_num>/<f_num zero-padded to 5>.png``.

    Args:
        img: Image to overlay. NOTE(review): the blend adds a [0, 1] float map
            to it, so presumably a float RGB array in [0, 1] — confirm.
        heatmap: Attention tensor; must provide ``.numpy()``. It is averaged
            over its first axis, normalised with ``util.normalize_heatmap`` and
            resized to ``img``'s height and width.
        video_num: Video identifier used in the directory name (``str``-converted).
        f_num: Frame number; zero-padded to 5 digits for the file name.
        save_name: Prefix of the output directory name.
    """
    # attention map
    heatmap = heatmap.numpy()
    heatmap = np.average(heatmap, axis=0)
    heatmap = util.normalize_heatmap(heatmap)
    # Resize the heatmap to the original image size (cv2.resize takes (width, height)).
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    # NOTE(review): the 255 scaling assumes normalize_heatmap maps into [0, 1].
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = heatmap / 255
    # heatmap * image
    s_img = heatmap * 0.7 + img

    fig = plt.figure(figsize=(16, 10))
    gs = GridSpec(4, 5, left=0.13, right=0.9)
    gs.update(wspace=-0.01)
    # Left: blended image spanning the first three grid columns.
    gs_1 = GridSpecFromSubplotSpec(nrows=4, ncols=3, subplot_spec=gs[0:4, 0:3])
    fig.add_subplot(gs_1[:, :])
    util.delete_line()
    plt.imshow(s_img)
    # Top right: the raw image.
    gs_2 = GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=gs[0:2, 3:5])
    fig.add_subplot(gs_2[:, :])
    util.delete_line()
    plt.imshow(img)
    # Bottom right: the colour-mapped attention map with a [0, 1] colour bar.
    gs_3 = GridSpecFromSubplotSpec(nrows=2, ncols=2, subplot_spec=gs[2:4, 3:5])
    fig.add_subplot(gs_3[:, :])
    util.delete_line()
    plt.imshow(heatmap, cmap='jet')
    plt.clim(0, 1)
    plt.colorbar()

    # Make the directory if it doesn't exist (exist_ok avoids a check/create race).
    SAVE_PATH = "../../../demo/images/attention"
    save_dir = os.path.join(SAVE_PATH, save_name + "_" + str(video_num))
    os.makedirs(save_dir, exist_ok=True)

    fig.savefig(os.path.join(save_dir, str(f_num).zfill(5) + ".png"))
    # Close this specific figure so per-frame calls do not accumulate figures.
    plt.close(fig)
示例#21
0
def visualize_seir_computation(results: pd.DataFrame,
                               compartments: List[Any],
                               show_individual_compartments=False):
    """Visualizes the SEIR computation

    Always plots the 'exposed', 'infected (active)', 'hospitalized (active)',
    'in ICU' and 'deaths' columns of ``results`` against ``results['time']``
    in one figure (deaths in black), then calls ``plt.show()``.

    If ``show_individual_compartments`` is true, an extra figure is built
    beforehand. It has four stacked panels sharing the x axis (susceptible,
    exposed, infected (active), deaths), each drawn by
    ``_plot_compartment_subplot``, plus a side panel whose legend pairs the
    lines returned for the last panel with ``compartments``.
    """
    if show_individual_compartments:
        w, h = plt.figaspect(2)
        detail_fig = plt.figure(figsize=(w, h))

        outer = GridSpec(1, 2, detail_fig, width_ratios=[5, 1])
        stack = GridSpecFromSubplotSpec(4, 1, outer[0], hspace=0)

        panel = None
        for row, compartment in enumerate(('susceptible', 'exposed',
                                           'infected (active)', 'deaths')):
            # The first panel gets sharex=None; each later one shares with
            # the panel above it.
            panel = detail_fig.add_subplot(stack[row], sharex=panel)
            lines = _plot_compartment_subplot(panel, compartment, results)

        panel.yaxis.set_major_formatter(EngFormatter())
        legend_ax = detail_fig.add_subplot(outer[1])
        legend_ax.legend(lines, compartments)
        legend_ax.set_xticks(())
        legend_ax.set_yticks(())
        legend_ax.set_axis_off()
        detail_fig.tight_layout()

    fig = plt.figure()
    ax = fig.add_subplot(111)
    series = (('exposed', {}),
              ('infected (active)', {}),
              ('hospitalized (active)', {}),
              ('in ICU', {}),
              ('deaths', {'color': 'k'}))
    for column, line_style in series:
        ax.plot(results['time'], results[column], label=column, **line_style)
    ax.legend()
    ax.set_xlabel('time (days)')
    ax.set_ylabel('# of people')
    ax.yaxis.set_major_formatter(EngFormatter())
    fig.tight_layout()
    plt.show()
def subtype_combined_barplot(ax, data, colors, colors_bar, labels, bar_labels, xlabel, bar_text=True,
                             show_ylabels=True):
    """Draw a horizontal bar chart with a thin colour strip along its left edge.

    The bounding axes ``ax`` is hidden and its slot is split 1:49 into a
    narrow strip axes (left) and the main bar axes (right).

    Args:
        ax: bounding axes; it is switched off and only its slot is used.
        data: bar lengths, one per entry of ``bar_labels``.
        colors: bar colours (also used to choose the in-bar text colour).
        colors_bar: colours of the unit-height cells in the left strip.
        labels: group label per row; the counts of each (sorted) label place
            the strip's tick labels.
        bar_labels: text written inside each bar when ``bar_text`` is true.
        xlabel: x-axis label of the bar axes.
        bar_text: write ``bar_labels`` inside the bars.
        show_ylabels: keep the group tick labels on the strip.

    Returns:
        The bounding axes ``ax``.
    """
    ax.axis('off')
    # Pass the SubplotSpec itself; GridSpecFromSubplotSpec is documented to
    # take a SubplotSpec, not an Axes.
    _gs = GridSpecFromSubplotSpec(1, 2, subplot_spec=ax.get_subplotspec(), wspace=0, hspace=0,
                                  width_ratios=[1, 49])
    ax_m = ax.figure.add_subplot(_gs[0, 1])
    ax_v = ax.figure.add_subplot(_gs[0, 0])

    L = len(bar_labels)
    y = np.arange(L)
    ax_m.barh(y, data, color=colors, height=0.9, zorder=3)

    transfer_ticks(ax_m, ax_v, which='y')

    # One unit-high cell per strip colour, centred on the bar rows.
    for k, color in enumerate(colors_bar):
        ax_v.bar(0, 1, bottom=k-0.5, color=color, width=1)

    if bar_text:
        bars = ax_m.patches
        for k, (color, bar, label) in enumerate(zip(colors, bars, bar_labels)):
            bar_width = bar.get_width()
            lum = sns.utils.relative_luminance(color)
            # Dark text on light or very short bars, white text otherwise.
            text_color = ".15" if (lum > .308) or (bar_width < 0.05) else "w"

            t = ax_m.text(0.02, k, label, color=text_color, va='center', size=TICK_TEXT_SIZE,
                          clip_on=True, wrap=True)
            # Clip the text to its bar only when the bar is wide enough to hold it.
            if bar_width > 0.1:
                t.set_clip_path(bar)

    # pd.value_counts() is deprecated since pandas 2.1; Series.value_counts
    # returns the same counts.
    label_counts = pd.Series(labels).value_counts().sort_index()
    # NOTE(review): ticks are placed 2 rows before each group's end, which is
    # only the group centre for groups of 3 rows; assumes rows are ordered by
    # sorted label — confirm against callers.
    ax_v.set_yticks(label_counts.cumsum() - 2)
    ax_v.set_yticklabels(label_counts.index)

    ax_m.set_xlabel(xlabel)
    # Fractions get a fixed 0-1 axis with 0.2 steps.
    if data.max() < 1.01:
        ax_m.set_xlim(0, 1.01)
        ax_m.set_xticks(np.arange(0, 1.1, 0.2))
    ax_m.grid(True, which='major', zorder=0)

    ax_m.set_ylim(0-0.5, L-0.5)
    ax_v.set_ylim(*ax_m.get_ylim())

    fix_spines(ax_m, ['left'], keep_ticks=True)
    fix_spines(ax_v, ['left'], keep_ticks=True)

    # First row at the top.
    ax_m.invert_yaxis()
    ax_v.invert_yaxis()

    if not show_ylabels:
        ax_v.set_yticklabels([])

    return ax
示例#23
0
    def set_layout(self):
        """Create the figure and split it into three vertically stacked areas.

        An 11-row outer grid is used:

        * rows 0-4: a 3x3 sub-grid holding two main axes (sub-rows 1-2, in
          column 0 and columns 1-2) topped by two header axes (sub-row 0).
          Stored as ``self.subfig_1 = [main_left, main_right, head_left,
          head_right]``.
        * rows 5-7 and rows 8-10: each a 5x1 sub-grid whose first four rows
          form one axes, stored as ``self.subfig_2`` and ``self.subfig_3``.

        The figure itself is stored as ``self.fig``.
        """
        self.fig = fig = plt.figure()

        plt.subplots_adjust(left=0.13, bottom=0.08, right=0.88, top=0.93,
                            wspace=.5, hspace=2.4)

        outer = GridSpec(nrows=11, ncols=1)

        top = GridSpecFromSubplotSpec(3, 3, subplot_spec=outer[0:5, 0],
                                      wspace=0.07, hspace=0.07)
        main_left = fig.add_subplot(top[1:, 0])
        main_right = fig.add_subplot(top[1:, 1:])
        head_left = fig.add_subplot(top[0, 0])
        head_right = fig.add_subplot(top[0, 1:])
        self.subfig_1 = [main_left, main_right, head_left, head_right]

        def lower_panel(rows):
            # 5x1 sub-grid over the given outer rows; the axes spans 4 of 5.
            sub = GridSpecFromSubplotSpec(5, 1, subplot_spec=outer[rows, 0],
                                          wspace=0.1, hspace=1.)
            return fig.add_subplot(sub[0:4, 0])

        self.subfig_2 = lower_panel(slice(5, 8))
        self.subfig_3 = lower_panel(slice(8, None))
示例#24
0
    def plot_slides(self,
                    coronal: int,
                    sagittal: int,
                    contour: bool = False,
                    ss: SubplotSpec = None,
                    fig: Any = None,
                    return_figure: bool = False) -> Any:
        """Show a coronal and a sagittal slice side by side.

        The left panel shows ``self[coronal, :, :]`` with a yellow vertical
        line at column ``sagittal``. The right panel shows
        ``self[:, :, sagittal]`` with its first two axes swapped and a yellow
        vertical line at ``coronal``. With ``contour`` set, white outlines are
        drawn for every channel that ``one_hot_encoding`` yields from the
        matching slice of ``self.vol_data`` (presumably a label volume;
        confirm).

        Args:
            coronal: index along axis 0.
            sagittal: index along axis 2.
            contour: overlay label outlines.
            ss: subplot slot to draw into; a fresh 1x1 grid cell by default.
            fig: target figure; the current figure by default.
            return_figure: clear ``fig`` before drawing and return it.

        Returns:
            ``fig`` when ``return_figure`` is true, otherwise ``None``.
        """
        def draw_outlines(axis, labels, x_col, y_col):
            # Trace the 0.5 iso-line of each one-hot channel on top of the image.
            for mask in one_hot_encoding(labels):
                for outline in find_contours(mask, 0.5):
                    axis.plot(outline[:, x_col],
                              outline[:, y_col],
                              lw=0.8,
                              color="w",
                              zorder=1000)

        fig = plt.gcf() if fig is None else fig
        ss = plt.GridSpec(1, 1)[0] if ss is None else ss
        if return_figure:
            fig.clear()
        halves = GridSpecFromSubplotSpec(1, 2, subplot_spec=ss)

        left = fig.add_subplot(halves[0])
        left.imshow(self[coronal, :, :])
        left.axvline(sagittal, c='y', lw=1)
        if contour:
            # find_contours gives (row, col); the image x axis is the column.
            draw_outlines(left, self.vol_data[coronal, :, :], 1, 0)

        right = fig.add_subplot(halves[1])
        right.imshow(np.transpose(self[:, :, sagittal], (1, 0, 2)))
        right.axvline(coronal, c='y', lw=1)
        if contour:
            # The image is transposed, so the contour row index becomes x.
            draw_outlines(right, self.vol_data[:, :, sagittal], 0, 1)

        return fig if return_figure else None
示例#25
0
    def __call__(self,
                 ax=None,
                 title='',
                 show_point_labels=False,
                 show_vectors=True,
                 show_vector_labels=True,
                 markersize=10,
                 legend=True):
        """Draw the sample plot and the two loading panels side by side.

        The drawing area is split into a 14-row x 12-column grid. The area is
        a new 25x12 figure when ``ax`` is None, otherwise the subplot slot of
        ``ax``. The panels are placed as follows:

        * columns 0-4: ``self.plot_samples``
        * columns 6-7: loadings of ``self.x_pc``
        * columns 10-11: loadings of ``self.y_pc``

        If ``self.DataModel`` is set and ``self.featurewise`` is false,
        ``self.plot_violins()`` is called afterwards.

        Parameters
        ----------
        ax : matplotlib Axes, optional
            Axes whose subplot slot bounds the layout. It only marks the
            region and is removed from its figure.
        title, show_point_labels, show_vectors, show_vector_labels, markersize, legend
            Forwarded to ``self.plot_samples``.

        Returns
        -------
        self
        """
        gs_x = 14  # grid rows
        gs_y = 12  # grid columns

        if ax is None:
            self.reduced_fig, ax = plt.subplots(1, 1, figsize=(25, 12))
            gs = GridSpec(gs_x, gs_y)
        else:
            gs = GridSpecFromSubplotSpec(gs_x, gs_y, ax.get_subplotspec())
            # Use the figure that owns `ax`, not whichever one is current.
            self.reduced_fig = ax.figure
        fig = self.reduced_fig

        # `ax` only defines the bounding box. pyplot.subplot used to delete
        # overlapping axes automatically (deprecated in Matplotlib 3.6), so
        # remove it explicitly instead of leaving an empty frame underneath.
        ax.remove()

        ax_components = fig.add_subplot(gs[:, :5])
        ax_loading1 = fig.add_subplot(gs[:, 6:8])
        ax_loading2 = fig.add_subplot(gs[:, 10:gs_y])

        self.plot_samples(show_point_labels=show_point_labels,
                          title=title,
                          show_vectors=show_vectors,
                          show_vector_labels=show_vector_labels,
                          markersize=markersize,
                          legend=legend,
                          ax=ax_components)
        self.plot_loadings(pc=self.x_pc, ax=ax_loading1)
        self.plot_loadings(pc=self.y_pc, ax=ax_loading2)
        sns.despine(fig=fig)
        fig.tight_layout()

        if self.DataModel is not None and not self.featurewise:
            self.plot_violins()
        return self
示例#26
0
    def plot(self, ax=None, n_cols=2):
        """Draw one line plot per feature in a grid inside ``ax``'s slot.

        ``ax`` is moved into the first grid cell. One new axes is added for
        each further feature. For each feature, ``values[0]`` is plotted
        against ``avg_preds[0].ravel()`` from ``self.pd_results_``.

        Args:
            ax: bounding axes; a new figure and axes are created when None.
            n_cols: maximum number of grid columns.

        Returns:
            self, with ``axes_`` (feature name -> axes), ``artists_``
            (feature name -> Line2D) and ``figure_`` set.

        Raises:
            ValueError: if there are no results to plot.
        """
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot()
        else:
            fig = ax.figure

        pd_results_ = self.pd_results_
        n_features = len(pd_results_)
        feature_names_ = self.feature_names_
        if n_features == 0:
            raise ValueError("No partial dependence results to plot.")

        n_cols = min(n_cols, n_features)
        n_rows = int(np.ceil(n_features / float(n_cols)))

        # GridSpecFromSubplotSpec takes (nrows, ncols, ...).
        gs = GridSpecFromSubplotSpec(n_rows,
                                     n_cols,
                                     subplot_spec=ax.get_subplotspec())

        # Reuse the bounding axes as the first cell. Compute its position from
        # the new spec instead of the removed `update_params`/`figbox` API.
        first_spec = gs[0]
        ax.set_subplotspec(first_spec)
        ax.set_position(first_spec.get_position(fig))

        axes = {feature_names_[0]: ax}
        for i in range(1, n_features):
            axes[feature_names_[i]] = fig.add_subplot(gs[i])

        artists = {}
        for feature_name, (avg_preds, values) in zip(feature_names_,
                                                     pd_results_):
            cur_ax = axes[feature_name]
            artists[feature_name] = cur_ax.plot(values[0],
                                                avg_preds[0].ravel())[0]
            cur_ax.set_xlabel(feature_name)

        self.axes_ = axes
        self.artists_ = artists
        self.figure_ = ax.get_figure()
        return self
示例#27
0
def showKernel(model,
               layer,
               index_filter=-1,
               index_channel=-1,
               figsize=(25, 25)):
    """Display a convolution layer's kernel slices as grayscale images.

    The kernel tensor is taken from ``model.get_layer(layer).get_weights()[0]``
    and indexed as ``weights[:, :, channel, filter]``. This presumably assumes
    a Keras-style (kh, kw, in_channels, filters) layout; confirm. Each filter
    gets one row of the figure, and its channel slices are laid out six per
    line.

    Args:
        model: model exposing ``get_layer(name).get_weights()``.
        layer: name passed to ``model.get_layer``.
        index_filter: show only this filter; -1 shows all filters.
        index_channel: show only this channel; -1 shows all channels.
        figsize: matplotlib figure size.
    """
    # Fetch the kernel weights.
    weights = model.get_layer(layer).get_weights()[0]
    n_channels = weights.shape[2]
    n_filters = weights.shape[3]

    ncols = 6
    nrows = int(np.ceil(n_channels / ncols))

    # Plot configuration: one outer row per filter.
    fig = plt.figure(figsize=figsize)
    grid = GridSpec(nrows=n_filters, ncols=1, figure=fig)

    filters = [index_filter] if index_filter > -1 else range(n_filters)
    channels = [index_channel] if index_channel > -1 else range(n_channels)

    # Show the weights of each selected kernel. The nested grid is kept per
    # filter; indexing a list by the filter number failed for index_filter > 0.
    for i in filters:
        nested_grid = GridSpecFromSubplotSpec(nrows=nrows,
                                              ncols=ncols,
                                              subplot_spec=grid[i])
        for j in channels:
            row, col = divmod(j, ncols)
            ax = fig.add_subplot(nested_grid[row, col])
            ax.imshow(weights[:, :, j, i], cmap='gray')
            ax.set_title(f'Kernel {i+1}, channel {j+1}')
            ax.set_xticks([])
            ax.set_yticks([])

    plt.show()
示例#28
0
    def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None):
        """Plot partial dependence plots.

        Parameters
        ----------
        ax : Matplotlib axes or array-like of Matplotlib axes, default=None
            - If a single axis is passed in, it is treated as a bounding axes
              and a grid of partial dependence plots will be drawn within
              these bounds. The `n_cols` parameter controls the number of
              columns in the grid.
            - If an array-like of axes are passed in, the partial dependence
              plots will be drawn directly into these axes.
            - If `None`, a figure and a bounding axes is created and treated
              as the single axes case.

        n_cols : int, default=3
            The maximum number of columns in the grid plot. Only active when
            `ax` is a single axes or `None`.

        line_kw : dict, default=None
            Dict with keywords passed to the `matplotlib.pyplot.plot` call.
            For one-way partial dependence plots.

        contour_kw : dict, default=None
            Dict with keywords passed to the `matplotlib.pyplot.contourf`
            call for two-way partial dependence plots.

        Returns
        -------
        display : :class:`~sklearn.inspection.PartialDependenceDisplay`

        Raises
        ------
        ValueError
            If a single `ax` has already been switched off by an earlier call,
            or if an array-like `ax` does not hold exactly one axes per
            feature.
        """

        # matplotlib is imported here, after the support check, rather than
        # at module level.
        check_matplotlib_support("plot_partial_dependence")
        import matplotlib.pyplot as plt  # noqa
        from matplotlib.gridspec import GridSpecFromSubplotSpec  # noqa

        if line_kw is None:
            line_kw = {}
        if contour_kw is None:
            contour_kw = {}

        if ax is None:
            _, ax = plt.subplots()

        # Caller-supplied keywords override these defaults.
        default_contour_kws = {"alpha": 0.75}
        contour_kw = {**default_contour_kws, **contour_kw}

        default_line_kws = {
            "color": "C0",
            "label": "average" if self.kind == "both" else None,
        }
        line_kw = {**default_line_kws, **line_kw}

        # Individual (ICE) lines reuse the average-line style without a legend
        # label, and are drawn thinner and semi-transparent.
        individual_line_kw = line_kw.copy()
        del individual_line_kw["label"]

        if self.kind == "individual" or self.kind == "both":
            individual_line_kw["alpha"] = 0.3
            individual_line_kw["linewidth"] = 0.5

        # Number of line artists per axes: the ICE curves (count from
        # _get_sample_count, presumably after any subsampling), plus one
        # average curve when kind == "both".
        n_features = len(self.features)
        if self.kind in ("individual", "both"):
            n_ice_lines = self._get_sample_count(len(self.pd_results[0].individual[0]))
            if self.kind == "individual":
                n_lines = n_ice_lines
            else:
                n_lines = n_ice_lines + 1
        else:
            n_ice_lines = 0
            n_lines = 1

        if isinstance(ax, plt.Axes):
            # If ax was set off, it has most likely been set to off
            # by a previous call to plot.
            if not ax.axison:
                raise ValueError(
                    "The ax was already used in another plot "
                    "function, please set ax=display.axes_ "
                    "instead"
                )

            # The bounding axes is hidden; the grid of plots is created
            # inside its subplot slot.
            ax.set_axis_off()
            self.bounding_ax_ = ax
            self.figure_ = ax.figure

            n_cols = min(n_cols, n_features)
            n_rows = int(np.ceil(n_features / float(n_cols)))

            self.axes_ = np.empty((n_rows, n_cols), dtype=object)
            if self.kind == "average":
                self.lines_ = np.empty((n_rows, n_cols), dtype=object)
            else:
                self.lines_ = np.empty((n_rows, n_cols, n_lines), dtype=object)
            self.contours_ = np.empty((n_rows, n_cols), dtype=object)

            axes_ravel = self.axes_.ravel()

            # Only the first n_features cells get an axes; unused trailing
            # grid cells stay None in self.axes_.
            gs = GridSpecFromSubplotSpec(
                n_rows, n_cols, subplot_spec=ax.get_subplotspec()
            )
            for i, spec in zip(range(n_features), gs):
                axes_ravel[i] = self.figure_.add_subplot(spec)

        else:  # array-like
            ax = np.asarray(ax, dtype=object)
            if ax.size != n_features:
                raise ValueError(
                    "Expected ax to have {} axes, got {}".format(n_features, ax.size)
                )

            # n_cols is only known for a 2-D array of axes.
            if ax.ndim == 2:
                n_cols = ax.shape[1]
            else:
                n_cols = None

            self.bounding_ax_ = None
            self.figure_ = ax.ravel()[0].figure
            self.axes_ = ax
            if self.kind == "average":
                self.lines_ = np.empty_like(ax, dtype=object)
            else:
                self.lines_ = np.empty(ax.shape + (n_lines,), dtype=object)
            self.contours_ = np.empty_like(ax, dtype=object)

        # create contour levels for two-way plots
        # All two-way plots share the same 8 levels spanning self.pdp_lim[2]
        # (presumably the (min, max) of the two-way predictions).
        # NOTE(review): Z_level is only bound when 2 in self.pdp_lim; a
        # two-way feature without that key would raise NameError below.
        if 2 in self.pdp_lim:
            Z_level = np.linspace(*self.pdp_lim[2], num=8)

        self.deciles_vlines_ = np.empty_like(self.axes_, dtype=object)
        self.deciles_hlines_ = np.empty_like(self.axes_, dtype=object)

        for pd_plot_idx, (axi, feature_idx, pd_result) in enumerate(
            zip(self.axes_.ravel(), self.features, self.pd_results)
        ):
            # Select which curves to draw according to self.kind.
            avg_preds = None
            preds = None
            feature_values = pd_result["values"]
            if self.kind == "individual":
                preds = pd_result.individual
            elif self.kind == "average":
                avg_preds = pd_result.average
            else:  # kind='both'
                avg_preds = pd_result.average
                preds = pd_result.individual

            # One grid of values -> one-way line plot; otherwise two-way contour.
            if len(feature_values) == 1:
                self._plot_one_way_partial_dependence(
                    preds,
                    avg_preds,
                    feature_values[0],
                    feature_idx,
                    n_ice_lines,
                    axi,
                    n_cols,
                    pd_plot_idx,
                    n_lines,
                    individual_line_kw,
                    line_kw,
                )
            else:
                self._plot_two_way_partial_dependence(
                    avg_preds,
                    feature_values,
                    feature_idx,
                    axi,
                    pd_plot_idx,
                    Z_level,
                    contour_kw,
                )

        return self
示例#29
0
def main():
    """Train a DCGAN on cat-face images and save progress every epoch.

    Loads ``./data/cat_face.npy``, scales pixels to [-1, 1], and builds the
    generator, discriminator and combined GAN via the project helpers. It then
    runs ``epochs`` passes of alternating discriminator and generator updates.

    After each epoch it:

    * prints loss/accuracy statistics;
    * saves a grid of 30 images generated from a fixed noise batch next to
      6 fixed real images;
    * rewrites the loss-history JSON;
    * saves the generator model.

    Raises:
        ValueError: if the dataset is smaller than one batch.
    """
    import os
    import time
    import json

    import matplotlib.pyplot as plt
    import numpy as np
    from keras.losses import binary_crossentropy
    from keras.optimizers import Adam
    from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

    def ensure_parent_dir(path):
        # Create the directory that will hold `path`; no-op if it exists.
        os.makedirs(os.path.dirname(path), exist_ok=True)

    def to_uint8(image):
        # Undo the [-1, 1] scaling applied to the training data.
        return (image * 127.5 + 127.5).astype(np.uint8)

    train_x = np.load('./data/cat_face.npy')

    # Rescale pixel values from [0, 255] to [-1, 1].
    train_x = (train_x.astype('float32') - 127.5) / 127.5

    train_size = train_x.shape[0]

    generator = build_generator()
    print('Generator model:')
    print(generator.summary())

    discriminator = build_discriminator()
    print('Discriminator model:')
    print(discriminator.summary())

    gan = build_gan(generator, discriminator)
    gan.compile(
        loss=binary_crossentropy,
        optimizer=Adam(lr=2e-4, beta_1=.5),
        metrics=['accuracy']
    )

    train_discriminator = build_train_discriminator(discriminator)
    train_discriminator.compile(
        loss=binary_crossentropy,
        optimizer=Adam(lr=2e-4, beta_1=.5),
        metrics=['accuracy']
    )

    batch_size = 128
    epochs = 200
    steps_per_epoch = train_size // batch_size
    if steps_per_epoch == 0:
        # Otherwise every epoch would run no steps and report NaN statistics.
        raise ValueError(
            f'Need at least {batch_size} training images, got {train_size}'
        )

    print(
        f'Train size: {train_size}, '
        f'Batch size: {batch_size}, '
        f'Epochs: {epochs}, '
        f'Total Steps: {steps_per_epoch*epochs}'
    )

    # Target labels: 1 for real images, 0 for generated ones.
    p_r = np.ones((batch_size, 1), dtype=np.float32)
    p_f = np.zeros((batch_size, 1), dtype=np.float32)

    # Fixed noise and real samples so per-epoch image grids are comparable.
    np.random.seed(0)
    test_noise = np.random.uniform(-1, 1, size=(30, 100)).astype(np.float32)
    test_indices = np.random.randint(0, train_size, size=6)
    plot_real_images = train_x[test_indices]
    np.random.seed(None)

    # Per-step loss history across all epochs (dumped to JSON every epoch).
    d_loss, g_loss = [], []

    for epoch in range(epochs):

        print(f'Epoch {epoch+1}/{epochs}')
        start = time.time()

        # Shuffle the training data in place.
        np.random.shuffle(train_x)

        d_loss_epoch, d_acc_epoch = [], []
        g_loss_epoch, g_acc_epoch = [], []

        for step in range(steps_per_epoch):

            # Take the next batch of real images.
            real_images = train_x[step*batch_size:(step+1)*batch_size]

            # Generate one batch of random noise.
            noise = np.random.uniform(-1, 1, size=(batch_size, 100))
            noise = noise.astype(np.float32)

            # Generate fake images from the noise.
            fake_images = generator.predict(noise)

            # Discriminator update on real and fake batches.
            d_history = train_discriminator.train_on_batch(
                [real_images, fake_images], [p_r, p_f]
            )
            # NOTE(review): indices 3 and 4 are assumed to be the real/fake
            # accuracies of the two-output model; confirm against its metrics.
            tmp_d_loss = float(d_history[0]/2)
            tmp_d_acc = float((d_history[3]+d_history[4])/2)
            d_loss.append(tmp_d_loss)
            d_loss_epoch.append(tmp_d_loss)
            d_acc_epoch.append(tmp_d_acc)

            # Fresh noise for the generator update.
            noise = np.random.uniform(-1, 1, size=(batch_size, 100))
            noise = noise.astype(np.float32)

            # Generator update: train the combined model to label fakes as real.
            g_history = gan.train_on_batch(noise, p_r)
            tmp_g_loss = float(g_history[0])
            tmp_g_acc = float(g_history[1])
            g_loss.append(tmp_g_loss)
            g_loss_epoch.append(tmp_g_loss)
            g_acc_epoch.append(tmp_g_acc)

        d_loss_std = np.std(d_loss_epoch)
        g_loss_std = np.std(g_loss_epoch)

        d_loss_mean = np.mean(d_loss_epoch)
        g_loss_mean = np.mean(g_loss_epoch)

        d_acc_std = np.std(d_acc_epoch)
        g_acc_std = np.std(g_acc_epoch)

        d_acc_mean = np.mean(d_acc_epoch)
        g_acc_mean = np.mean(g_acc_epoch)

        print(
            f'd_loss: {d_loss_mean:.4f}, '
            f'd_loss_std: {d_loss_std:.4f}, '
            f'd_acc: {d_acc_mean:.2f}, '
            f'd_acc_std: {d_acc_std:.2f}, '
        )

        print(
            f'g_loss: {g_loss_mean:.4f}, '
            f'g_loss_std: {g_loss_std:.4f}, '
            f'g_acc: {g_acc_mean:.2f}, '
            f'g_acc_std: {g_acc_std:.2f}, '
        )

        print(
            f'time: {int(time.time() - start)} s'
        )

        generated_images = generator.predict(test_noise)

        # Left: 6x5 grid of generated images; right: a column of 6 real ones.
        plt.figure(figsize=(6, 6))
        grid0 = GridSpec(1, 2, width_ratios=(5, 1))
        grid00 = GridSpecFromSubplotSpec(6, 5, subplot_spec=grid0[0])
        for i in range(generated_images.shape[0]):
            plt.subplot(grid00[i])
            plt.imshow(to_uint8(generated_images[i, :, :, :]))
            plt.axis('off')

        grid01 = GridSpecFromSubplotSpec(6, 1, subplot_spec=grid0[1])
        for i in range(plot_real_images.shape[0]):
            plt.subplot(grid01[i])
            plt.imshow(to_uint8(plot_real_images[i, :, :, :]))
            plt.axis('off')

        plt.subplots_adjust(left=0, right=1, bottom=0, top=1,
                            hspace=0, wspace=0)

        path = f'figure/dcgan_cat_face/image_{epoch+1:03d}.png'
        ensure_parent_dir(path)
        plt.savefig(path)
        plt.close()

        # The full loss history so far; the file is overwritten every epoch.
        path = 'var/log/dcgan_cat_face_history.json'
        ensure_parent_dir(path)
        history = [d_loss, g_loss]
        with open(path, 'w', encoding='UTF-8') as f:
            json.dump(history, f)

        path = f'model/dcgan_cat_face/model_{epoch+1:03d}.h5'
        ensure_parent_dir(path)
        generator.save(path)
示例#30
0
                         xyz_guided.reshape(-1, 3).tolist())

    for s in [idx, rmse_rec, rmse_pred, rmse_guided]:
        result_str += str(s) + ','
    result_str = result_str[:-1] + '\n'

    # layout
    fig = plt.figure(figsize=(10, 5))
    plt.rcParams["font.size"] = 18
    gs_master = GridSpec(nrows=2,
                         ncols=2,
                         height_ratios=[1, 1],
                         width_ratios=[3, 0.1])
    gs_1 = GridSpecFromSubplotSpec(nrows=1,
                                   ncols=4,
                                   subplot_spec=gs_master[0, 0],
                                   wspace=0.05,
                                   hspace=0)
    gs_2 = GridSpecFromSubplotSpec(nrows=1,
                                   ncols=4,
                                   subplot_spec=gs_master[1, 0],
                                   wspace=0.05,
                                   hspace=0)
    gs_3 = GridSpecFromSubplotSpec(nrows=2,
                                   ncols=1,
                                   subplot_spec=gs_master[0:1, 1])

    ax_enh0 = fig.add_subplot(gs_1[0, 0])
    ax_enh1 = fig.add_subplot(gs_1[0, 1])
    ax_enh2 = fig.add_subplot(gs_1[0, 2])
    ax_enh3 = fig.add_subplot(gs_1[0, 3])