コード例 #1
0
def mmm_nested_split(_graph,
                     _iters,
                     _layer_specs,
                     _title='',
                     random_seed=0,
                     figsize=(30, 20),
                     theme=None,
                     path=None,
                     dark=False):
    """Draw three side-by-side layer-cake simulations for comparison.

    Panels, left to right:
      0. ``mmm_layercake_phd`` run on ``_layer_specs[0]`` alone;
      1. ``mmm_layercake_phd`` run on the whole ``_layer_specs`` sequence
         (capacitance maps summed over axis 0, both land-use maps shown);
      2. ``mmm_layercake_d`` run on ``_layer_specs[1]`` alone.

    Every panel title is ``_title`` followed by the rounded (2 dp) total of
    that run's land-use maps.

    Args:
        _graph: graph whose ``nodes(data=True)`` entries carry an ``'x'`` value.
        _iters: forwarded unchanged to the simulation runners, ``plotter``
            and ``style_ax``.
        _layer_specs: list or tuple of layer specs; indices 0 and 1 are used.
        _title: prefix for every panel title.
        random_seed: seed forwarded to each simulation run.
        figsize: matplotlib figure size.
        theme: optional figure-level suptitle.
        path: if given, the figure is saved there at 300 dpi.
        dark: forwarded to ``plt_setup`` and ``style_ax``.
    """
    xs = [node_data['x'] for _, node_data in _graph.nodes(data=True)]

    util_funcs.plt_setup(dark=dark)

    assert isinstance(_layer_specs, (list, tuple))
    fig, axes = plt.subplots(1, 3, figsize=figsize)

    def _draw_panel(ax, plot_maps, plot_colours, landuse_maps):
        # render one panel, then title it with the land-use total of its run
        plotter(ax, _iters, xs, _res_factor=1, _plot_maps=plot_maps,
                _plot_colours=plot_colours)
        lu_total = np.round(np.sum(landuse_maps), 2)
        style_ax(ax, _title + f'l.u: {lu_total}', _iters, dark=dark)

    # left: first layer spec on its own
    pop_map, lu_maps, cap_maps, _ = mmm_layercake_phd(_graph,
                                                      _iters,
                                                      _layer_specs=_layer_specs[0],
                                                      random_seed=random_seed)
    _draw_panel(axes[0],
                [pop_map, cap_maps[0], lu_maps[0]],
                ['#555555', '#d32f2f', '#2e7d32'],
                lu_maps)

    # middle: all layer specs simulated together
    pop_map, lu_maps, cap_maps, _ = mmm_layercake_phd(_graph,
                                                      _iters,
                                                      _layer_specs=_layer_specs,
                                                      random_seed=random_seed)
    _draw_panel(axes[1],
                [pop_map, np.sum(cap_maps, axis=0), lu_maps[0], lu_maps[1]],
                ['#555555', '#d32f2f', '#2e7d32', '#0064b7'],
                lu_maps)

    # right: second layer spec on its own
    # NOTE(review): this panel calls mmm_layercake_d whereas the other two call
    # mmm_layercake_phd - confirm that is intended.
    pop_map, lu_maps, cap_maps, _ = mmm_layercake_d(_graph,
                                                    _iters,
                                                    _layer_specs=_layer_specs[1],
                                                    random_seed=random_seed)
    _draw_panel(axes[2],
                [pop_map, cap_maps[0], lu_maps[0]],
                ['#555555', '#d32f2f', '#0064b7'],
                lu_maps)

    if theme is not None:
        fig.suptitle(theme)

    if path is not None:
        plt.savefig(path, dpi=300)
コード例 #2
0
def mmm_single(_graph,
               _iters,
               _layer_specs,
               _title='',
               random_seed=0,
               figsize=(12, 20),
               theme=None,
               path=None,
               dark=False):
    """Run ``mmm_layercake_phd`` and plot the result, one panel per layer spec.

    Args:
        _graph: graph whose ``nodes(data=True)`` entries carry an ``'x'`` value.
        _iters: forwarded unchanged to ``mmm_layercake_phd``, ``plotter`` and
            ``style_ax``.
        _layer_specs: a single layer-spec ``dict`` (one panel) or a
            list/tuple of them (one panel each, side by side). A one-element
            list/tuple is supported.
        _title: panel title prefix. For a list/tuple of specs it may also be
            a list/tuple holding one title per spec.
        random_seed: seed forwarded to every simulation run.
        figsize: matplotlib figure size.
        theme: optional figure-level suptitle.
        path: if given, the figure is saved there at 300 dpi.
        dark: forwarded to ``plt_setup`` and ``style_ax``.

    Raises:
        AssertionError: if ``_layer_specs`` is neither a dict nor a
            list/tuple, or if a title sequence does not have one entry per spec.
    """
    xs = [d['x'] for _, d in _graph.nodes(data=True)]

    util_funcs.plt_setup(dark=dark)

    if isinstance(_layer_specs, dict):
        pop_map, landuse_maps, capacitance_maps, flow_maps = mmm_layercake_phd(_graph,
                                                                               _iters,
                                                                               _layer_specs=_layer_specs,
                                                                               random_seed=random_seed)
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        caps = capacitance_maps[0]
        lus = landuse_maps[0]
        flows = flow_maps[0]
        plotter(ax, _iters, xs, _res_factor=1,
                _plot_maps=[pop_map, caps, lus, flows],
                _plot_scales=(1, 1.5, 0.75, 1))
        title = _title + f'l.u: {np.round(np.sum(landuse_maps), 2)}'
        style_ax(ax, title, _iters, dark=dark)

    else:
        assert isinstance(_layer_specs, (list, tuple))
        n_ax = len(_layer_specs)
        # squeeze=False always yields a 2-D array of axes; without it a single
        # spec produces a bare Axes object and the zip below raises TypeError
        fig, axes = plt.subplots(1, n_ax, figsize=figsize, squeeze=False)
        if isinstance(_title, str):
            _title = [_title] * n_ax
        else:
            assert isinstance(_title, (list, tuple))
            assert len(_title) == len(_layer_specs)
        for ax, title, layer_spec in zip(axes[0], _title, _layer_specs):
            pop_map, landuse_maps, capacitance_maps, flow_maps = mmm_layercake_phd(_graph,
                                                                                   _iters,
                                                                                   _layer_specs=layer_spec,
                                                                                   random_seed=random_seed)
            caps = capacitance_maps[0]
            lus = landuse_maps[0]
            flows = flow_maps[0]
            plotter(ax, _iters, xs, _res_factor=1, _plot_maps=[pop_map, caps, lus, flows])
            title = title + f'l.u: {np.round(np.sum(landuse_maps), 2)}'
            style_ax(ax, title, _iters, dark=dark)

    if theme is not None:
        fig.suptitle(theme)

    if path is not None:
        plt.savefig(path, dpi=300)
コード例 #3
0
def plot_latents(df,
                 X_trans,
                 Z_mu,
                 labels,
                 distances,
                 latent_titles=None,
                 suptitle=None,
                 path=None):
    '''
    Plot the first eight latent dimensions on a 3 x 8 grid.

    Row 0: for each latent, the heatmap returned by
    ``plot_funcs.correlate_heatmap`` for ``X_trans`` against that latent
    column of ``Z_mu``, laid out as ``labels`` x ``distances``. Only the
    first column shows row labels.
    Rows 1 and 2: scatter maps of each latent over a fixed window around
    Cambridge and Milton Keynes respectively (+/-2500 in x, +/-4500 in y
    around a hard-coded centre, using ``df.x`` / ``df.y`` coordinates).

    If ``latent_titles`` is given, it supplies the title for each heatmap.
    If ``path`` is given, the figure is saved there at 300 dpi; otherwise it
    is shown.
    '''
    util_funcs.plt_setup()
    fig, axes = plt.subplots(3, 8, figsize=(12, 8))
    heat_row, *map_rows = axes

    # heatmaps: one per latent dimension
    for idx, ax in enumerate(heat_row):
        corrs = plot_funcs.correlate_heatmap(len(labels),
                                             len(distances),
                                             X_trans,
                                             Z_mu[:, idx])
        plot_funcs.plot_heatmap(ax,
                                heatmap=corrs,
                                row_labels=labels,
                                col_labels=distances,
                                set_row_labels=idx == 0,
                                cbar=True)
        if latent_titles is not None:
            ax.set_title(latent_titles[idx])

    # maps: one row per city, centred on a hard-coded coordinate
    cities = (('Cambridge', 545700, 258980),
              ('Milton Keynes', 485970, 236920))
    for row, (name, centre_x, centre_y) in zip(map_rows, cities):
        x_window = (centre_x - 2500, centre_x + 2500)
        y_window = (centre_y - 4500, centre_y + 4500)
        for idx, ax in enumerate(row):
            plot_funcs.plot_scatter(ax,
                                    df.x,
                                    df.y,
                                    Z_mu[:, idx],
                                    x_extents=x_window,
                                    y_extents=y_window,
                                    relative_extents=False,
                                    s_min=0,
                                    s_max=0.8,
                                    c_exp=5,
                                    s_exp=3.5)
            ax.set_xlabel(f'{name}  #{idx}')

    if suptitle is not None:
        plt.suptitle(suptitle)
    if path is None:
        plt.show()
    else:
        plt.savefig(path, dpi=300)
コード例 #4
0
def plot_prob_clusters(X_raw,
                       cluster_probs,
                       n_components,
                       path_theme,
                       xs,
                       ys,
                       max_only=False,
                       plt_cmap='gist_ncar',
                       shape_exp=0.5,
                       suptitle='GMM VaDE',
                       rasterized=True):
    """Plot one scatter map per cluster on a 3 x 7 grid and save it as a PDF.

    Clusters are drawn in the order given by ``make_mu_means`` and coloured
    via ``plt_cmap`` from its normalised means. Point sizes are the cluster's
    probability column of ``cluster_probs``.

    Args:
        X_raw: data passed to ``make_mu_means`` for the colour ordering.
        cluster_probs: 2-D array, one column per cluster (argmax over axis 1
            gives each point's assignment).
        n_components: forwarded to ``make_mu_means``.
        path_theme: prefix of the output file name.
        xs, ys: point coordinates forwarded to ``plot_scatter``.
        max_only: if True, only points whose argmax assignment is the cluster
            get a non-zero size, rescaled as ``p * 0.75 + 0.25``.
        plt_cmap: matplotlib colormap name.
        shape_exp: forwarded to ``make_mu_means``.
        suptitle: figure title.
        rasterized: forwarded to ``plot_scatter``.

    The figure is saved to
    ``../phd-doc/doc/images/signatures/<path_theme>_cluster_composite[_max].pdf``.
    Grid cells beyond the number of clusters are left empty.
    """
    # get the assignments based on maximum probability
    cluster_assignments = np.argmax(cluster_probs, axis=1)
    # get the colours for each cluster based on mean mixed uses
    _mu_means, mu_means_normalised, sorted_cluster_idx = make_mu_means(
        X_raw, cluster_assignments, n_components, shape_exp)
    # plot the axes
    util_funcs.plt_setup()
    fig, axes = plt.subplots(3, 7, figsize=(7, 10))
    # plt.get_cmap works across matplotlib versions; plt.cm.get_cmap
    # (matplotlib.cm.get_cmap) was removed in matplotlib 3.9
    cmap = plt.get_cmap(plt_cmap)
    n_clusters = cluster_probs.shape[1]
    # axes.flat walks the grid row by row, matching the cluster order
    for counter, ax in enumerate(axes.flat):
        if counter >= n_clusters:
            continue
        cluster_idx = sorted_cluster_idx[counter]
        c = cmap(mu_means_normalised[cluster_idx])
        vals = cluster_probs[:, cluster_idx]
        if max_only:
            is_assigned = cluster_assignments == cluster_idx
            # zero everywhere except the assigned points, whose sizes are
            # lifted by 0.25 so low-probability members stay visible
            s = np.zeros(len(cluster_probs))
            s[is_assigned] = vals[is_assigned] * 0.75 + 0.25
        else:
            # NOTE: a view into cluster_probs - never modify it in place
            s = vals
        # override vals with explicit "c" and "s"
        plot_scatter(fig, ax, xs, ys, c=c, s=s, rasterized=rasterized)
        ax.set_xlabel(f'Cluster #{cluster_idx + 1}')
    plt.suptitle(suptitle)
    path = f'../phd-doc/doc/images/signatures/{path_theme}_cluster_composite'
    if max_only:
        path += '_max'
    path += '.pdf'
    plt.savefig(path, dpi=300)
コード例 #5
0
def simple_plot(_G, _path=None, plot_geoms=True):
    """Plot graph ``_G`` on a new figure, clipped to module-level extents.

    The axis limits come from ``min_x``/``max_x`` and ``min_y``/``max_y``,
    which are not parameters and so must exist in the enclosing module scope.

    Args:
        _G: graph passed to ``plot.plot_nX``.
        _path: if given, save the figure there with a ``'#2e2e2e'`` face
            colour and return ``None``.
        plot_geoms: forwarded to ``plot.plot_nX``.

    Returns:
        The target axes when ``_path`` is ``None``.
    """
    util_funcs.plt_setup()
    _fig, ax = plt.subplots(1, 1)
    x_bounds = (min_x, max_x)
    y_bounds = (min_y, max_y)
    plot.plot_nX(_G,
                 labels=False,
                 plot_geoms=plot_geoms,
                 node_size=10,
                 edge_width=1,
                 x_lim=x_bounds,
                 y_lim=y_bounds,
                 ax=ax)
    if _path is None:
        return ax
    plt.savefig(_path, facecolor='#2e2e2e')
コード例 #6
0
def _theme_stem(theme):
    """Return ``theme`` without its trailing ``'_{dist}'`` placeholder.

    ``str.strip('_{dist}')`` is not a substitute: it removes any of the
    characters ``_ { d i s t }`` from *both* ends, so e.g. ``'dens_{dist}'``
    would become ``'en'``.
    """
    suffix = '_{dist}'
    if theme.endswith(suffix):
        return theme[:-len(suffix)]
    return theme


def _label_towns(ax, t_xs, t_ys, t_ids, t_names, colour, top_ids,
                 upper_y_end, lower_y_end, bottom_ids=None):
    """Draw a rotated town name plus a leader line for selected towns.

    Towns whose id is in ``top_ids`` are labelled at ``upper_y_end`` (text
    top-aligned). Otherwise they are labelled at ``lower_y_end``
    (bottom-aligned), but only if ``bottom_ids`` is None or contains the id;
    towns in neither list are skipped.
    """
    for t_x, t_y, t_id, t_n in zip(t_xs, t_ys, t_ids, t_names):
        # to avoid overlap, label certain towns from the top and others from the bottom
        if t_id in top_ids:
            align = 'top'
            y_end = upper_y_end
        elif bottom_ids is None or t_id in bottom_ids:
            align = 'bottom'
            y_end = lower_y_end
        else:
            continue
        ax.text(t_x * 1.02,
                y_end,
                t_n,
                rotation=90,
                verticalalignment=align,
                fontdict={'size': 5},
                color=colour)
        ax.vlines(t_x,
                  ymin=t_y,
                  ymax=y_end,
                  color=colour,
                  lw=0.5,
                  alpha=0.4)


def pop_corr_plot(city_data, theme_1, theme_2, towns_data, xlabel, sup_title):
    """Plot per-town Spearman correlations between two themed columns.

    One axis is drawn per distance (200 and 1600). On each axis:

    * every town with rows in ``city_data`` is a point: its population on x
      and the Spearman rho between the two columns on y. New towns are red;
      all other towns are blue.
    * horizontal lines show the pooled Spearman rho over all rows of the new
      towns, and over all rows of the other towns within the new towns'
      population range. Pooling avoids averaging per-town correlations.

    Args:
        city_data: DataFrame with a ``city_pop_id`` column plus the columns
            named by ``theme_1`` / ``theme_2`` after ``.format(dist=...)``.
        theme_1, theme_2: column-name templates with a ``{dist}`` placeholder.
        towns_data: DataFrame with ``pop_id``, ``city_type``,
            ``city_population`` and ``city_name`` columns, one row per town.
        xlabel: x-axis label for both axes.
        sup_title: figure title.

    The figure is saved to
    ``../phd-doc/doc/images/predictive/corr/<theme_1 stem>_<theme_2 stem>.pdf``.
    """
    # classify towns from the towns_data argument (not a global)
    new_towns = []
    other_towns = []
    for _, row in towns_data.iterrows():
        if row['city_type'] in ['New Town']:
            new_towns.append(row['pop_id'])
        else:
            other_towns.append(row['pop_id'])
    # set for O(1) membership tests inside the per-town loop
    new_town_set = set(new_towns)

    nt_col = '#d32f2f'
    o_col = '#0064b7'

    util_funcs.plt_setup()
    fig, axes = plt.subplots(2, 1, figsize=(7, 5))

    max_pop_id = city_data.city_pop_id.max()
    for ax, dist in zip(axes, ['200', '1600']):

        key_1 = theme_1.format(dist=dist)
        key_2 = theme_2.format(dist=dist)

        ax.set_ylabel(r'spearman $\rho$' + r' $d_{max}=' + f'{dist}m$')
        ax.set_xlabel(xlabel)

        # per-town results: other towns (x, y, s, o_*) and new towns (nt_*)
        x, y, s, o_id, o_n = [], [], [], [], []
        nt_x, nt_y, nt_s, nt_id, nt_n = [], [], [], [], []
        for pop_id in reversed(range(1, int(max_pop_id) + 1)):
            city_rows = city_data[city_data.city_pop_id == pop_id]
            d_1 = city_rows[key_1]
            d_2 = city_rows[key_2]
            if not len(d_1):
                continue
            # filtering returns a dataframe (pandas doesn't know that this is
            # a single row), so index the value from .values
            town_row = towns_data[towns_data.pop_id == pop_id]
            pop = town_row['city_population'].values[0]
            t_n = town_row['city_name'].values[0]
            p_r, _p = stats.spearmanr(d_1, d_2)
            # lower pop_ids get larger markers
            size = ((1 - pop_id / max_pop_id) * 20 + 5)
            if pop_id in new_town_set:
                nt_x.append(pop)
                nt_y.append(p_r)
                nt_s.append(size)
                nt_id.append(pop_id)
                nt_n.append(t_n)
            else:
                x.append(pop)
                y.append(p_r)
                s.append(size)
                o_id.append(pop_id)
                o_n.append(t_n)

        # filter other towns to the population range spanned by the new towns
        poly_min_x = np.nanmin(nt_x)
        poly_max_x = np.nanmax(nt_x)
        other_towns_filtered = []
        for o_t in other_towns:
            o_t_num = towns_data[towns_data.pop_id ==
                                 o_t]['city_population'].values[0]
            if poly_min_x <= o_t_num <= poly_max_x:
                other_towns_filtered.append(o_t)

        # pooled correlations - computed directly rather than averaging
        # per-town correlations, to avoid ecological correlation
        nt_rows = city_data[city_data['city_pop_id'].isin(new_towns)]
        nt_corr, _p = stats.spearmanr(nt_rows[key_1], nt_rows[key_2])
        other_rows = city_data[city_data['city_pop_id'].isin(
            other_towns_filtered)]
        other_corr, _p = stats.spearmanr(other_rows[key_1], other_rows[key_2])

        # plot
        ax.scatter(nt_x,
                   nt_y,
                   c=nt_col,
                   s=nt_s,
                   alpha=0.7,
                   marker='o',
                   edgecolors='white',
                   linewidths=0.3,
                   zorder=3)
        ax.scatter(x,
                   y,
                   c=o_col,
                   s=s,
                   alpha=0.4,
                   marker='o',
                   edgecolors='white',
                   linewidths=0.3,
                   zorder=2)

        # add lines
        ax.hlines(nt_corr,
                  xmin=poly_min_x,
                  xmax=poly_max_x,
                  colors=nt_col,
                  lw=2,
                  alpha=0.5,
                  linestyle='-',
                  zorder=4,
                  label=f'new towns $r={round(nt_corr, 3)}$')
        ax.hlines(other_corr,
                  xmin=poly_min_x,
                  xmax=poly_max_x,
                  colors=o_col,
                  lw=2,
                  alpha=0.5,
                  linestyle='-',
                  zorder=4,
                  label=f'equiv. towns $r={round(other_corr, 3)}$')

        # y extents are derived from the other towns' correlations
        y_range = np.nanmax(y) - np.nanmin(y)
        y_axes_cushion = y_range * 0.1
        y_text_cushion = y_range * 0.085
        # for axes extents
        upper_y_extent = np.nanmax(y) + y_axes_cushion
        lower_y_extent = np.nanmin(y) - y_axes_cushion
        # for text and lines
        upper_y_end = np.nanmax(y) + y_text_cushion
        lower_y_end = np.nanmin(y) - y_text_cushion

        # background polygon over the new towns' population range
        ax.fill(
            [poly_min_x, poly_min_x, poly_max_x, poly_max_x],
            [lower_y_extent, upper_y_extent, upper_y_extent, lower_y_extent],
            c='grey',
            lw=0,
            alpha=0.1,
            zorder=1)

        # new towns: every one is labelled, selected ids from the top
        _label_towns(ax, nt_x, nt_y, nt_id, nt_n, '#D3A1A6',
                     [29, 39, 63, 80, 126, 153, 194, 244],
                     upper_y_end, lower_y_end)
        # other towns: only the listed ids are labelled
        _label_towns(ax, x, y, o_id, o_n, 'silver',
                     [3, 6, 8, 10, 12],
                     upper_y_end, lower_y_end,
                     bottom_ids=[1, 2, 4, 5, 7, 9, 11, 13])

        ax.set_xlim(left=5000, right=10**7)
        ax.set_ylim(bottom=lower_y_extent, top=upper_y_extent)
        ax.set_xscale('log')
        ax.legend(loc=2)
    fig.suptitle(sup_title)
    path = f'../phd-doc/doc/images/predictive/corr/{_theme_stem(theme_1)}_{_theme_stem(theme_2)}.pdf'
    plt.savefig(path, dpi=300)
コード例 #7
0
    ['pop_id',
     'city_name',
     'city_type',
     'city_area',
     'city_population',
     'city_species_count',
     'city_species_unique',
     'city_streets_len',
     'city_intersections_count'],
    'analysis.city_boundaries_150',
    'WHERE city_population IS NOT NULL ORDER BY pop_id')
# style helper from util_funcs (presumably consumed by later plotting code)
Style = util_funcs.Style()

#  %% plot population vs. total number of POI
# clear previous figures and set matplotlib defaults
util_funcs.plt_setup()
# two side-by-side axes sharing the y axis
fig, axes = plt.subplots(1, 2, sharey='row', figsize=(7, 3.5))
# data
# d = df.dropna(how='all')
pop = df.city_population.values
species_count = df.city_species_count.values
# population per unit of POI species count
count_dens = pop / species_count
# sizes
pop_log = np.log10(pop)
# Normalize() without vmin/vmax autoscales to the data, mapping it to [0, 1]
pop_log_norm = plt.Normalize()(pop_log)
count_dens_norm = plt.Normalize()(count_dens)
# curve fit to a powerlaw and plot
# x values for the fitted curve: every integer from min to max population.
# NOTE(review): with step 1 this can be a very long array for large
# populations - np.linspace may suffice for drawing a smooth curve.
x_fit_line = np.arange(pop.min(), pop.max())


def powerFunc(x, c, z):
コード例 #8
0
            dir_sub_path = pathlib.Path(dir_path / f'anim_ml_{anim_key}')
            if not dir_sub_path.exists():
                dir_sub_path.mkdir(exist_ok=True, parents=True)
            np.save(str(dir_sub_path / f'{title}_{data_key}.npy'), data)
# otherwise load
else:
    anim_a = []
    anim_b = []
    anim_c = []
    for anim, anim_key in zip([anim_a, anim_b, anim_c], ['a', 'b', 'c']):
        dir_sub_path = pathlib.Path(dir_path / f'anim_ml_{anim_key}')
        for data_key in ['dwell_map', 'landuse_maps', 'capacitance_maps']:
            anim.append(np.load(str(dir_sub_path / f'{title}_{data_key}.npy')))

#  %% plot animation
# light-theme matplotlib defaults; one axis per graph variant below
util_funcs.plt_setup(dark=False)
fig, axes = plt.subplots(1, 3, figsize=(24, 8))
# prepare lists for data
# when loaded from disk above, each anim_* list holds the dwell_map,
# landuse_maps and capacitance_maps arrays, in that order
anims = [anim_a, anim_b, anim_c]  # one data list per panel
gs = [graph_a, graph_b, graph_c]  # one graph per panel
ax_themes = ['Historic', 'Mesh-like', 'Tree-like']
# per-panel accumulators, presumably filled by the code that follows this chunk
xs = []
ys = []
ls = []
dwells = []
caps = []
lus = []
txts_iters = []
txts_lus = []
txts_caps = []
コード例 #9
0
def mmm_single(_graph,
               _iters,
               _layer_specs,
               _title='',
               seed=False,
               random_seed=0,
               figsize=(12, 20),
               theme=None,
               path=None,
               dark=False):
    """Run ``mmm_layercake_b`` and plot the result, one panel per layer spec.

    Land use is displayed multiplied by the first capacitance map. Each
    panel title is the prefix plus the raw sum of the first land-use map.

    Args:
        _graph: graph whose ``nodes(data=True)`` entries carry an ``'x'`` value.
        _iters: forwarded unchanged to ``mmm_layercake_b``, ``plotter`` and
            ``style_ax``.
        _layer_specs: a single layer-spec ``dict`` (one panel) or a
            list/tuple of them (one panel each). A one-element list/tuple is
            supported.
        _title: panel title prefix. For a list/tuple of specs it may also be
            a list/tuple holding one title per spec.
        seed, random_seed: forwarded to ``mmm_layercake_b``.
        figsize: matplotlib figure size.
        theme: optional figure-level suptitle.
        path: if given, the figure is saved there (300 dpi, keeping the
            figure face colour).
        dark: here only selects the figure background colour.

    Raises:
        AssertionError: if ``_layer_specs`` is neither a dict nor a
            list/tuple, or if a title sequence does not have one entry per spec.
    """
    xs = [d['x'] for _, d in _graph.nodes(data=True)]

    background_col = '#2e2e2e' if dark else '#ffffff'

    util_funcs.plt_setup()

    if isinstance(_layer_specs, dict):
        pop_map, landuse_maps, capacitance_maps = mmm_layercake_b(
            _graph,
            _iters,
            _layer_specs=_layer_specs,
            seed=seed,
            random_seed=random_seed)
        fig, ax = plt.subplots(1, 1, figsize=figsize, facecolor=background_col)
        caps = capacitance_maps[0]
        lus = landuse_maps[0] * caps
        plotter(ax, _iters, xs, _res_factor=1, _plot_maps=[pop_map, caps, lus])
        title = _title + f'l.u: {landuse_maps[0].sum()}'
        style_ax(ax, title, _iters)

    else:
        assert isinstance(_layer_specs, (list, tuple))
        n_ax = len(_layer_specs)
        # squeeze=False always yields a 2-D array of axes; without it a single
        # spec produces a bare Axes object and the zip below raises TypeError
        fig, axes = plt.subplots(1,
                                 n_ax,
                                 figsize=figsize,
                                 facecolor=background_col,
                                 squeeze=False)
        if isinstance(_title, str):
            _title = [_title] * n_ax
        else:
            assert isinstance(_title, (list, tuple))
            assert len(_title) == len(_layer_specs)
        for ax, title, layer_spec in zip(axes[0], _title, _layer_specs):
            pop_map, landuse_maps, capacitance_maps = mmm_layercake_b(
                _graph,
                _iters,
                _layer_specs=layer_spec,
                seed=seed,
                random_seed=random_seed)
            caps = capacitance_maps[0]
            lus = landuse_maps[0] * caps
            plotter(ax,
                    _iters,
                    xs,
                    _res_factor=1,
                    _plot_maps=[pop_map, caps, lus])
            title = title + f'l.u: {landuse_maps[0].sum()}'
            style_ax(ax, title, _iters)

    if theme is not None:
        fig.suptitle(theme)

    if path is not None:
        plt.savefig(path,
                    facecolor=fig.get_facecolor(),
                    edgecolor='none',
                    dpi=300,
                    transparent=True)

    plt.show()
コード例 #10
0
def plot_components(component_idxs,
                    feature_labels,
                    distances,
                    X,
                    X_latent,
                    xs,
                    ys,
                    title_tags=None,
                    corr_tags=None,
                    map_tags=None,
                    loadings=None,
                    dark=False,
                    label_all=True,
                    s_min=0,
                    s_max=0.6,
                    c_exp=1,
                    s_exp=1,
                    cbar=False,
                    figsize=None,
                    rasterized=True):
    """Plot a heatmap and a map for each selected latent component.

    Column ``n`` of the 2-row grid shows component ``component_idxs[n]``:

    * top: a ``len(feature_labels)`` x ``len(distances)`` heatmap. It is
      either that component's row of ``loadings`` reshaped to this grid, or
      ``correlate_heatmap`` of ``X`` against ``X_latent[:, comp_idx]``.
    * bottom: a scatter of ``X_latent[:, comp_idx]`` at (``xs``, ``ys``).

    Args:
        component_idxs: column indices into ``X_latent`` (and rows of
            ``loadings``); a single index is supported.
        title_tags, corr_tags, map_tags: optional per-column heatmap title,
            heatmap x-label and map x-label.
        label_all: show row labels on every heatmap (otherwise only the first).
        s_min, s_max, c_exp, s_exp, rasterized, dark: forwarded to the
            plotting helpers.
        cbar: add one shared colorbar for the heatmap row, using the mappable
            returned by the last ``plot_heatmap`` call.
        figsize: defaults to ``(1.5 * n_components, 8)``.
    """
    n_rows = 2
    n_cols = len(component_idxs)
    if figsize is None:
        figsize = (n_cols * 1.5, 8)
    util_funcs.plt_setup(dark=dark)
    # squeeze=False keeps axes 2-D even for a single component; otherwise
    # axes[0][n] would try to index into a bare Axes object
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    # create heatmaps for original vectors plotted against the top PCA components
    for n, comp_idx in enumerate(component_idxs):
        print(f'processing component {comp_idx + 1}')
        # loadings / correlations
        if loadings is not None:
            heatmap_corr = loadings[comp_idx].reshape(len(feature_labels),
                                                      len(distances))
        else:
            heatmap_corr = correlate_heatmap(len(feature_labels),
                                             len(distances), X,
                                             X_latent[:, comp_idx])
        heatmap_ax = axes[0][n]
        if title_tags is not None:
            heatmap_ax.set_title(title_tags[n])
        im = plot_heatmap(heatmap_ax,
                          heatmap_corr,
                          row_labels=feature_labels,
                          col_labels=distances,
                          set_row_labels=(label_all or n == 0),
                          set_col_labels=True,
                          dark=dark)
        if corr_tags is not None:
            heatmap_ax.set_xlabel(corr_tags[n])
        # map
        map_ax = axes[1][n]
        col_data = X_latent[:, comp_idx]
        plot_scatter(fig,
                     map_ax,
                     xs,
                     ys,
                     col_data,
                     dark=dark,
                     s_min=s_min,
                     s_max=s_max,
                     c_exp=c_exp,
                     s_exp=s_exp,
                     rasterized=rasterized)
        if map_tags is not None:
            map_ax.set_xlabel(map_tags[n])
    if cbar:
        fig.colorbar(im,
                     ax=axes[0],
                     aspect=50,
                     pad=0.01,
                     orientation='vertical',
                     ticks=[-1, 0, 1],
                     shrink=0.5)
コード例 #11
0
 def epoch_writes(self, epoch_step):
     """Write per-epoch summaries for the validation set.

     Nothing happens when ``self.writer`` is ``None``. Otherwise, inside
     ``self.writer.as_default()``, this calls the parent ``epoch_writes`` and
     then logs:

     * an image comparing the mean of ``self.X_val``, the mean reconstruction
       and their difference, each reshaped to labels x distances;
     * an image with one ``plot_funcs.correlate_heatmap`` panel per latent
       dimension of ``Z_mu`` (a single latent dimension is supported);
     * weight/bias histograms of the sampling layer, when the model has one.

     Args:
         epoch_step: step value attached to every summary.
     """
     if self.writer is not None:
         with self.writer.as_default():
             super().epoch_writes(epoch_step)
             # encode the validation set; only the latent means are used
             Z_mu, _Z_log_var, _Z = self.model.encode(self.X_val,
                                                      training=False)
             x_hat = self.model.decode(Z_mu, training=False)
             # images of mean x vs x_hat vs diff
             x_img = np.mean(self.X_val,
                             axis=0).reshape(len(self.labels),
                                             len(self.distances))
             x_hat_img = np.mean(x_hat,
                                 axis=0).reshape(len(self.labels),
                                                 len(self.distances))
             util_funcs.plt_setup()
             fig, axes = plt.subplots(1, 3, figsize=(6, 8))
             plot_funcs.plot_heatmap(axes[0],
                                     x_img,
                                     row_labels=self.labels,
                                     col_labels=self.distances)
             plot_funcs.plot_heatmap(axes[1],
                                     x_hat_img,
                                     set_row_labels=False,
                                     col_labels=self.distances)
             plot_funcs.plot_heatmap(axes[2],
                                     x_img - x_hat_img,
                                     set_row_labels=False,
                                     col_labels=self.distances)
             tf.summary.image('x | x hat | diff',
                              plot_to_image(fig),
                              step=epoch_step)
             # release the figure so open figures don't pile up across
             # epochs; a no-op if plot_to_image already closed it
             plt.close(fig)
             # images of latents
             latent_dim = Z_mu.shape[1]
             util_funcs.plt_setup()
             # squeeze=False keeps axes 2-D, so a single latent still indexes
             fig, axes = plt.subplots(1, latent_dim, figsize=(12, 8),
                                      squeeze=False)
             for l_idx in range(latent_dim):
                 corr = plot_funcs.correlate_heatmap(
                     len(self.labels), len(self.distances), self.X_val,
                     Z_mu[:, l_idx])
                 plot_funcs.plot_heatmap(axes[0][l_idx],
                                         corr,
                                         row_labels=self.labels,
                                         set_row_labels=l_idx == 0,
                                         col_labels=self.distances)
             tf.summary.image('latents',
                              plot_to_image(fig),
                              step=epoch_step)
             plt.close(fig)
             # histograms
             if hasattr(self.model, 'sampling'):
                 tf.summary.histogram(
                     'Z mu biases',
                     self.model.sampling.Z_mu_layer.weights[1],
                     step=epoch_step)
                 tf.summary.histogram(
                     'Z mu weights',
                     self.model.sampling.Z_mu_layer.weights[0],
                     step=epoch_step)
                 tf.summary.histogram(
                     'Z logvar biases',
                     self.model.sampling.Z_logvar_layer.weights[1],
                     step=epoch_step)
                 tf.summary.histogram(
                     'Z logvar weights',
                     self.model.sampling.Z_logvar_layer.weights[0],
                     step=epoch_step)
コード例 #12
0
                                                                     iters,
                                                                     _layer_specs=layer_specs,
                                                                     random_seed=0)

# %%
from tqdm import tqdm
import numpy as np
from src import util_funcs
import matplotlib.pyplot as plt
from celluloid import Camera
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.collections import LineCollection

# two-colour colormap named 'cityseer', running from '#64c1ff' to '#d32f2f'
cmap = LinearSegmentedColormap.from_list('cityseer', ['#64c1ff', '#d32f2f'])

# dark-theme defaults on a single square figure; celluloid's Camera wraps it
# (presumably to snapshot animation frames later in the script)
util_funcs.plt_setup(dark=True)
fig, ax = plt.subplots(1, 1, figsize=(20, 20))
camera = Camera(fig)

# node id -> (x, y), plus flat coordinate arrays in the same node order
pos = {}
for n, d in graph.nodes(data=True):
    pos[n] = (d['x'], d['y'])
xs = np.array([v[0] for v in pos.values()])
ys = np.array([v[1] for v in pos.values()])

# edge segments, presumably for the LineCollection imported above; filled by
# the edge loop that continues beyond this chunk
lines = []
for s, e in graph.edges():
    s_x = graph.nodes[s]['x']
    s_y = graph.nodes[s]['y']
    e_x = graph.nodes[e]['x']
    e_y = graph.nodes[e]['y']