Ejemplo n.º 1
0
def features_dend_panel( fig, Z, Z2, width, lw ):
    """Draw a feature dendrogram (orientation 'left') in a new side axes.

    Args:
        fig: Figure that receives a new frameless axes placed at
            ``[-width, 0.0, width, 1.0]`` in figure coordinates.
        Z: Presumably a scipy linkage matrix. Only column 2 (its maximum is
            used as the plot height) and ``Z.shape[0] + 1`` (passed as the
            ``p``/``n`` arguments) are read.
        Z2: Dendrogram dict with keys ``'icoord'``, ``'dcoord'``, ``'ivl'``
            and ``'color_list'``.
            NOTE: ``Z2['color_list']`` is rewritten IN PLACE ('b' -> 'k', then
            'x' -> 'b'); calling this twice on the same dict remaps the colours
            a second time.
        width: Width of the new axes in figure coordinates.
        lw: Line width applied to every drawn link.
    """
    ax1 = fig.add_axes([-width, 0.0, width, 1.0], frameon=False)
    # Two-step remap: blue links become black, then 'x' placeholders become blue.
    Z2['color_list'] = [c.replace('b', 'k').replace('x', 'b') for c in Z2['color_list']]
    mh = Z[:, 2].max()
    # Draw explicitly into ax1 instead of relying on pyplot's current axes.
    sch._plot_dendrogram(Z2['icoord'], Z2['dcoord'], Z2['ivl'],
                         Z.shape[0] + 1, Z.shape[0] + 1, mh, 'left',
                         no_labels=True, color_list=Z2['color_list'],
                         ax=ax1)
    # Use the public setter; assigning the private ``_linewidths`` attribute
    # bypasses matplotlib's own state handling.
    for coll in ax1.collections:
        coll.set_linewidth(lw)
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_xticklabels([])
def features_dend_panel( fig, Z, Z2, width, lw ):
    """Draw a feature dendrogram (orientation 'right') in a new side axes.

    Args:
        fig: Figure that receives a new frameless axes placed at
            ``[-width, 0.0, width, 1.0]`` in figure coordinates.
        Z: Presumably a scipy linkage matrix. Only column 2 (its maximum is
            used as the plot height) and ``Z.shape[0] + 1`` (passed as the
            ``p``/``n`` arguments) are read.
        Z2: Dendrogram dict with keys ``'icoord'``, ``'dcoord'``, ``'ivl'``
            and ``'color_list'``.
            NOTE: ``Z2['color_list']`` is rewritten IN PLACE ('b' -> 'k', then
            'x' -> 'b'); calling this twice on the same dict remaps the colours
            a second time.
        width: Width of the new axes in figure coordinates.
        lw: Line width applied to every drawn link.
    """
    ax1 = fig.add_axes([-width, 0.0, width, 1.0], frameon=False)
    # Two-step remap: blue links become black, then 'x' placeholders become blue.
    Z2['color_list'] = [c.replace('b', 'k').replace('x', 'b') for c in Z2['color_list']]
    mh = Z[:, 2].max()
    # Draw explicitly into ax1 instead of relying on pyplot's current axes.
    sch._plot_dendrogram(Z2['icoord'], Z2['dcoord'], Z2['ivl'],
                         Z.shape[0] + 1, Z.shape[0] + 1, mh, 'right',
                         no_labels=True, color_list=Z2['color_list'],
                         ax=ax1)
    # Use the public setter; assigning the private ``_linewidths`` attribute
    # bypasses matplotlib's own state handling.
    for coll in ax1.collections:
        coll.set_linewidth(lw)
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_xticklabels([])
Ejemplo n.º 3
0
def samples_dend_panel( fig, Z, Z2, ystart, ylen, lw ):
    """Draw a sample dendrogram (orientation 'top') in a new axes above the plot.

    Args:
        fig: Figure that receives a new frameless axes placed at
            ``[0.0, 1.0 + ystart, 1.0, ylen]`` in figure coordinates.
        Z: Presumably a scipy linkage matrix. Only column 2 (its maximum is
            used as the plot height) and ``Z.shape[0] + 1`` (passed as the
            ``p``/``n`` arguments) are read.
        Z2: Dendrogram dict with keys ``'icoord'``, ``'dcoord'``, ``'ivl'``
            and ``'color_list'``.
            NOTE: ``Z2['color_list']`` is rewritten IN PLACE ('b' -> 'k').
        ystart: Vertical offset of the axes above the top edge (1.0).
        ylen: Height of the new axes in figure coordinates.
        lw: Line width applied to every drawn link.
    """
    ax2 = fig.add_axes([0.0, 1.0 + ystart, 1.0, ylen], frameon=False)
    Z2['color_list'] = [c.replace('b', 'k') for c in Z2['color_list']]
    mh = Z[:, 2].max()
    # Draw explicitly into ax2 instead of relying on pyplot's current axes.
    sch._plot_dendrogram(Z2['icoord'], Z2['dcoord'], Z2['ivl'],
                         Z.shape[0] + 1, Z.shape[0] + 1,
                         mh, 'top', no_labels=True,
                         color_list=Z2['color_list'],
                         ax=ax2)
    # Use the public setter; assigning the private ``_linewidths`` attribute
    # bypasses matplotlib's own state handling.
    for coll in ax2.collections:
        coll.set_linewidth(lw)
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.set_xticklabels([])
def samples_dend_panel( fig, Z, Z2, ystart, ylen, lw ):
    """Draw a sample dendrogram (orientation 'top') in a new axes above the plot.

    Args:
        fig: Figure that receives a new frameless axes placed at
            ``[0.0, 1.0 + ystart, 1.0, ylen]`` in figure coordinates.
        Z: Presumably a scipy linkage matrix. Only column 2 (its maximum is
            used as the plot height) and ``Z.shape[0] + 1`` (passed as the
            ``p``/``n`` arguments) are read.
        Z2: Dendrogram dict with keys ``'icoord'``, ``'dcoord'``, ``'ivl'``
            and ``'color_list'``.
            NOTE: ``Z2['color_list']`` is rewritten IN PLACE ('b' -> 'k').
        ystart: Vertical offset of the axes above the top edge (1.0).
        ylen: Height of the new axes in figure coordinates.
        lw: Line width applied to every drawn link.
    """
    ax2 = fig.add_axes([0.0, 1.0 + ystart, 1.0, ylen], frameon=False)
    Z2['color_list'] = [c.replace('b', 'k') for c in Z2['color_list']]
    mh = Z[:, 2].max()
    # Draw explicitly into ax2 instead of relying on pyplot's current axes.
    sch._plot_dendrogram(Z2['icoord'], Z2['dcoord'], Z2['ivl'],
                         Z.shape[0] + 1, Z.shape[0] + 1,
                         mh, 'top', no_labels=True,
                         color_list=Z2['color_list'],
                         ax=ax2)
    # Use the public setter; assigning the private ``_linewidths`` attribute
    # bypasses matplotlib's own state handling.
    for coll in ax2.collections:
        coll.set_linewidth(lw)
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.set_xticklabels([])
Ejemplo n.º 5
0
    def draw(self):
        """Render the clustered heatmap and either show it or save it.

        The figure is laid out on a 6x4 GridSpec holding:
        - the heatmap (``gs[23]``);
        - a sample-metadata strip (``gs[15]``);
        - a horizontal colorbar (``gs[3]``);
        - the sample dendrogram (``gs[11]``), unless
          ``self.args.no_sclustering`` is set;
        - the feature dendrogram (``gs[22]``), unless
          ``self.args.no_fclustering`` is set.

        If ``self.args.out`` is empty, ``plt.show()`` is called. Otherwise the
        figure is saved to ``self.args.out``. If metadata colours were drawn,
        ``self.make_legend`` is also called with ``self.args.legend_file``.

        A sample name that is a tuple with more than one element is treated
        as ``(label, metadata_1, metadata_2, ...)``.

        NOTE(review): ``snames`` and ``fnames`` are not defined in this method
        and must be resolvable from module scope; confirm with the caller.
        """
        # Keep the longer figure side at image_size and honour the requested
        # cell aspect ratio.
        rat = float(self.ns) / self.nf * self.args.cell_aspect_ratio
        if rat < 1:
            x, y = self.args.image_size, rat * self.args.image_size
        else:
            x, y = self.args.image_size / rat, self.args.image_size
        fig = plt.figure(figsize=(x, y), facecolor='w')

        cm = pylab.get_cmap(self.args.colormap)
        # Default under/over colours are taken from the colormap's end
        # segments. NOTE(review): ``_segmentdata`` is private and only present
        # on segment-based colormaps; confirm for the colormaps in use.
        channels = ('red', 'green', 'blue')
        bottom_col = [cm._segmentdata[ch][0][1] for ch in channels]
        if self.args.bottom_c:
            bottom_col = self.args.bottom_c
        cm.set_under(bottom_col)
        top_col = [cm._segmentdata[ch][-1][1] for ch in channels]
        if self.args.top_c:
            top_col = self.args.top_c
        cm.set_over(top_col)

        if self.args.nan_c:
            cm.set_bad(self.args.nan_c)

        def make_ticklabels_invisible(ax):
            # Hide tick labels and drop all ticks on both axes.
            for tl in ax.get_xticklabels() + ax.get_yticklabels():
                tl.set_visible(False)
            ax.set_xticks([])
            ax.set_yticks([])

        def remove_splines(ax):
            # Make all four axes spines invisible.
            for v in ['right', 'left', 'top', 'bottom']:
                ax.spines[v].set_color('none')

        def shrink_labels(labels, n):
            # Abbreviate labels longer than n characters as "head [...] tail".
            # Floor division: ``n / 2`` is a float on Python 3 and cannot be
            # used as a slice index.
            half = n // 2
            shrink = lambda x: x[:half] + " [...] " + x[-half:]
            return [(shrink(str(l)) if len(str(l)) > n else l) for l in labels]

        # For very wide layouts (rat < 1/8), enlarge the buffer rows/columns.
        buf_space = 0.05
        min_buf = min([buf_space * 8, 8 * rat * buf_space])
        if min_buf < 0.05:
            buf_space /= min_buf / 0.05
        has_metadata = type(snames[0]) is tuple and len(snames[0]) > 1
        metadata_height = self.args.metadata_height if has_metadata else 0.000001
        gs = gridspec.GridSpec(6,
                               4,
                               width_ratios=[
                                   buf_space, buf_space * 2,
                                   .08 * self.args.fdend_width, 0.9
                               ],
                               height_ratios=[
                                   buf_space, buf_space * 2,
                                   .08 * self.args.sdend_height,
                                   metadata_height,
                                   self.args.metadata_separation, 0.9
                               ],
                               wspace=0.0,
                               hspace=0.0)

        # ``axisbg`` was removed in matplotlib 2.2; ``facecolor`` replaces it.
        ax_hm = plt.subplot(gs[23], facecolor=bottom_col)
        ax_metadata = plt.subplot(gs[15], facecolor=bottom_col)
        ax_hm_y2 = ax_hm.twinx()

        norm_f = matplotlib.colors.Normalize
        if self.args.log_scale:
            norm_f = matplotlib.colors.LogNorm
        elif self.args.sqrt_scale:
            norm_f = SqrtNorm

        maps, values, ndv = [], [], 0
        if has_metadata:
            # Give each distinct value of each metadata field its own integer
            # code. Codes keep increasing across fields, so every
            # (field, value) pair gets a distinct colour.
            for m in zip(*[list(s[1:]) for s in snames]):
                mmap = {val: ndv + i for i, val in enumerate(set(m))}
                values.append([mmap[v] for v in m])
                ndv += len(mmap)
                maps.append(mmap)
            mdmat = np.matrix(values)
            # Cycle the discrete palette until it covers every code.
            dcols = []
            while len(dcols) < ndv:
                dcols += self.dcols
            cmap = matplotlib.colors.ListedColormap(dcols[:ndv])
            bounds = [float(f) - 0.5 for f in range(ndv + 1)]
            ax_metadata.imshow(
                mdmat,
                interpolation='nearest',
                aspect='auto',
                extent=[0, self.nf, 0, self.ns],
                cmap=cmap,
                vmin=bounds[0],
                vmax=bounds[-1],
            )
            remove_splines(ax_metadata)
            ax_metadata_y2 = ax_metadata.twinx()
            ax_metadata.set_yticks([])
            ax_metadata_y2.set_ylim(0, len(self.fnames_meta))
            ax_metadata_y2.tick_params(length=0)
            ax_metadata_y2.set_yticks(np.arange(len(self.fnames_meta)) + 0.5)
            ax_metadata_y2.set_yticklabels(self.fnames_meta[::-1],
                                           va='center',
                                           size=self.args.flabel_size)
        else:
            ax_metadata.set_yticks([])

        ax_metadata.set_xticks([])

        # Put the user colour limits into the norm itself. matplotlib >= 3.5
        # rejects vmin/vmax passed together with norm=, and older versions
        # applied them to that same norm anyway.
        norm = norm_f(vmin=self.args.minv, vmax=self.args.maxv)
        im = ax_hm.imshow(
            self.numpy_matrix,
            interpolation='nearest',
            aspect='auto',
            extent=[0, self.nf, 0, self.ns],
            cmap=cm,
            norm=norm)

        ax_hm.set_xticks(np.arange(len(snames)) + 0.5)
        if not self.args.no_slabels:
            slabels = [s[0] for s in snames] if type(snames[0]) is tuple else snames
            snames_short = shrink_labels(slabels, self.args.max_slabel_len)
            ax_hm.set_xticklabels(snames_short,
                                  rotation=90,
                                  va='top',
                                  ha='center',
                                  size=self.args.slabel_size)
        else:
            ax_hm.set_xticklabels([])
        ax_hm_y2.set_ylim([0, self.ns])
        ax_hm_y2.set_yticks(np.arange(len(fnames)) + 0.5)
        if not self.args.no_flabels:
            fnames_short = shrink_labels(fnames, self.args.max_flabel_len)
            ax_hm_y2.set_yticklabels(fnames_short,
                                     va='center',
                                     size=self.args.flabel_size)
        else:
            ax_hm_y2.set_yticklabels([])
        ax_hm.set_yticks([])
        remove_splines(ax_hm)
        ax_hm.tick_params(length=0)
        ax_hm_y2.tick_params(length=0)

        ax_cm = plt.subplot(gs[3], facecolor='r', frameon=False)
        fig.colorbar(im,
                     ax_cm,
                     orientation='horizontal',
                     spacing='proportional' if self.args.sqrt_scale else
                     'uniform')

        if not self.args.no_sclustering:
            ax_den_top = plt.subplot(gs[11], facecolor='r', frameon=False)
            sph._plot_dendrogram(self.sdendrogram['icoord'],
                                 self.sdendrogram['dcoord'],
                                 self.sdendrogram['ivl'],
                                 self.ns + 1,
                                 self.nf + 1,
                                 1,
                                 'top',
                                 no_labels=True,
                                 color_list=self.sdendrogram['color_list'],
                                 ax=ax_den_top)
            ymax = max([max(a) for a in self.sdendrogram['dcoord']])
            ax_den_top.set_ylim([0, ymax])
            make_ticklabels_invisible(ax_den_top)
        if not self.args.no_fclustering:
            ax_den_right = plt.subplot(gs[22], facecolor='b', frameon=False)
            sph._plot_dendrogram(self.fdendrogram['icoord'],
                                 self.fdendrogram['dcoord'],
                                 self.fdendrogram['ivl'],
                                 self.ns + 1,
                                 self.nf + 1,
                                 1,
                                 'right',
                                 no_labels=True,
                                 color_list=self.fdendrogram['color_list'],
                                 ax=ax_den_right)
            xmax = max([max(a) for a in self.fdendrogram['dcoord']])
            # Reversed x-limits mirror the dendrogram so its leaves face the heatmap.
            ax_den_right.set_xlim([xmax, 0])
            make_ticklabels_invisible(ax_den_right)

        if not self.args.out:
            plt.show()
        else:
            fig.savefig(self.args.out, bbox_inches='tight', dpi=self.args.dpi)
            if maps:
                # ``maps`` is only non-empty when the metadata branch ran, and
                # that branch already reads ``self.fnames_meta``.
                self.make_legend(maps, self.fnames_meta, self.args.legend_file)
Ejemplo n.º 6
0
# "Gold Tree" figure. ``icoord``, ``dcoord``, ``ivl``, ``_plot_dendrogram``,
# ``save_dir`` and ``threshold`` are defined earlier in the script (not shown
# here).
plt.title("Gold Tree", fontsize=26)
plt.yticks([])
plt.tick_params(labelsize=16)
# Positional arguments for the dendrogram plotter:
# p, n, mh (maximum link height), orientation, no_labels, color_list.
n = 18
mh = 2
color_list = ['c', 'c', 'c', 'c', 'g', 'k', 'r', 'r', 'r', 'r', 'm', 'm', 'm', 'm', 'k', 'k']
p = 30
orientation = 'top'
no_labels = False
print("Plotting")
with plt.rc_context({'lines.linewidth': 2.0}):
    _plot_dendrogram(icoord, dcoord, ivl, p, n, mh, orientation,
                     no_labels, color_list,
                     leaf_font_size=26.,
                     leaf_rotation=None,
                     contraction_marks=None,
                     ax=None,
                     above_threshold_color="k")
# Compute the layout only after the dendrogram and its 26 pt leaf labels
# exist. Calling tight_layout before plotting ignores them and can clip the
# labels in the saved image.
plt.tight_layout()

plt.savefig(save_dir + "goldtree_sentences" + str(int(threshold * 100)) + ".png")

# Rabinovych tree
# Coordinates for the tree: presumably the ``icoord`` list for the dendrogram
# plotter. Each inner list holds the four x positions of one U-shaped link,
# and leaves sit at x = 5, 15, 25, ... The matching ``dcoord``/``ivl`` are not
# part of this snippet; TODO confirm.
# NOTE(review): entries such as 41.87, 60.93 and 138.13 look like link
# midpoints rounded to two decimals (e.g. (26.25 + 57.5) / 2 = 41.875), so
# those links may be drawn a few thousandths off-centre.
icoord = [[5.0, 5.0, 15.0, 15.0], [10.0, 10.0, 25.0, 25.0], [17.5, 17.5, 35.0, 35.0], [45.0, 45.0, 55.0, 55.0],
          [50.0, 50.0, 65.0, 65.0], [26.25, 26.25, 57.5, 57.5], [75.0, 75.0, 85.0, 85.0],
          [41.87, 41.87, 80.0, 80.0], [95.0, 95.0, 105.0, 105.0], [100.0, 100.0, 115.0, 115.0],
          [125.0, 125.0, 135.0, 135.0], [107.5, 107.5, 130.0, 130.0], [145.0, 145.0, 155.0, 155.0],
          [150.0, 150.0, 165.0, 165.0], [118.75, 118.75, 157.5, 157.5], [60.93, 60.93, 138.13, 138.13]]
Ejemplo n.º 7
0
def dendrogram_plot(matrix, output_dir, factor_names):
    """Make dendrogram plot recording at which threshold factors and codes merge.

    This plotting function produces a dendrogram plot recording at which
    factors of variation and latent codes are most related. It does this by
    running the union-find algorithm
    https://en.wikipedia.org/wiki/Disjoint-set_data_structure on the matrix
    relating factors of variation and latent codes.

    Args:
      matrix: Input matrix of shape [num_factors, num_codes] encoding the
        statistical relation between factors and codes.
      output_dir: Path prefix for the plot. ".png" is appended to it to form
        the output file name, so it is not treated as a directory.
      factor_names: Labels for the factors of variation to be used in the plot.

    Returns:
      The result of ``report_merges(z, num_factors)``: a dictionary containing
      the threshold ID of each merging event and which factors were merged.
    """
    tmp = pd.melt(pd.DataFrame(matrix).reset_index(), id_vars="index")
    # The columns of the dataframe are: index, variable and value.
    tmp = tmp.to_numpy()
    # Sort the rows by threshold (the value column), largest first.
    tmp = tmp[tmp[:, -1].argsort()[::-1]]
    # The codes have index code + num_factors.
    tmp[:, 1] += matrix.shape[0]

    # Initialize dictionaries for cluster IDs and size.
    size = {}
    cluster_id = {}
    for i in range(matrix.shape[0]):
        size[i] = 1
        cluster_id[i] = i
    # Initialize dendrogram matrix. Each row is an event, and each event is
    # [cluster_id_1, cluster_id_2, threshold, size of the new cluster].
    z = np.zeros([matrix.shape[0] - 1, 4])
    # Each factor of variation is in its own tree, so the maximum cluster ID
    # we have is matrix.shape[0]-1.
    n_clusters = matrix.shape[0] - 1
    nodes = list(range(matrix.shape[0] + matrix.shape[1]))
    idx_found = 0
    discovered = {}
    # Run the Union-Find Algorithm.
    for id_i, i in enumerate(tmp):
        # Record if we just discovered a new factor of variation.
        if i[0] not in discovered:
            discovered[i[0]] = id_i
        # Merge trees.
        z, cluster_id, size, n_clusters, idx_found = _union(
            nodes, i[0], i[1], id_i, z, cluster_id, size, matrix, n_clusters,
            idx_found)

    fig, ax = plt.subplots()
    try:
        # Obtain the dendrogram plot data structure from the matrix z.
        dn = hierarchy.dendrogram(z, ax=ax, orientation="left", no_plot=True)
        # Map each leaf's plot position (5, 15, 25, ...) to its leaf index.
        id_to_leaf = {}
        id_conv = 5
        for l in dn["leaves"]:
            id_to_leaf[id_conv] = l
            id_conv += 10
        # Move leaf endpoints (height 0) to when the factor was first discovered.
        for d, i in zip(dn["dcoord"], dn["icoord"]):
            if d[0] == 0:
                idx = id_to_leaf[i[0]]
                d[0] = discovered[idx]
            if d[-1] == 0:
                idx = id_to_leaf[i[-1]]
                d[-1] = discovered[idx]
        # Set colors to be all the same.
        dn["color_list"] = ["b"] * len(dn["color_list"])
        dn["ivl"] = np.array(factor_names)[dn["leaves"]]

        hierarchy._plot_dendrogram(  # pylint: disable=protected-access
            dn["icoord"],
            dn["dcoord"],
            dn["ivl"],
            p=30,
            n=z.shape[0] + 1,
            mh=max(z[:, 2]),
            orientation="right",
            no_labels=False,
            color_list=dn["color_list"],
            leaf_font_size=None,
            leaf_rotation=None,
            contraction_marks=None,
            ax=ax,
            above_threshold_color="b")
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Factor")
        # Label every 10th event position with its (rounded) threshold value.
        thresholds = tmp[:, 2]
        thresholds_ids = range(0, thresholds.shape[0], 10)
        ax.set_xticks(thresholds_ids)
        ax.set_xticklabels(
            np.around(np.array(thresholds, dtype="float32")[thresholds_ids], 2))
        output_path = output_dir + ".png"
        with tf.gfile.Open(output_path, "wb") as path:
            fig.savefig(path, bbox_inches="tight")
    finally:
        # Release the pyplot figure; otherwise every call leaks one figure.
        plt.close(fig)
    return report_merges(z, matrix.shape[0])
Ejemplo n.º 8
0
    def draw(self):
        """Render the clustered heatmap and either show it or save it.

        The figure is laid out on a 6x4 GridSpec holding:
        - the heatmap (``gs[23]``);
        - a sample-metadata strip (``gs[15]``);
        - a horizontal colorbar (``gs[3]``);
        - the sample dendrogram (``gs[11]``), unless
          ``self.args.no_sclustering`` is set;
        - the feature dendrogram (``gs[22]``), unless
          ``self.args.no_fclustering`` is set.

        If ``self.args.out`` is empty, ``plt.show()`` is called. Otherwise the
        figure is saved to ``self.args.out``. If metadata colours were drawn,
        ``self.make_legend`` is also called with ``self.args.legend_file``.

        A sample name that is a tuple with more than one element is treated
        as ``(label, metadata_1, metadata_2, ...)``.

        NOTE(review): ``snames`` and ``fnames`` are not defined in this method
        and must be resolvable from module scope; confirm with the caller.
        """
        # Keep the longer figure side at image_size and honour the requested
        # cell aspect ratio.
        rat = float(self.ns) / self.nf * self.args.cell_aspect_ratio
        if rat < 1:
            x, y = self.args.image_size, rat * self.args.image_size
        else:
            x, y = self.args.image_size / rat, self.args.image_size
        fig = plt.figure(figsize=(x, y), facecolor='w')

        cm = pylab.get_cmap(self.args.colormap)
        # Default under/over colours are taken from the colormap's end
        # segments. NOTE(review): ``_segmentdata`` is private and only present
        # on segment-based colormaps; confirm for the colormaps in use.
        channels = ('red', 'green', 'blue')
        bottom_col = [cm._segmentdata[ch][0][1] for ch in channels]
        if self.args.bottom_c:
            bottom_col = self.args.bottom_c
        cm.set_under(bottom_col)
        top_col = [cm._segmentdata[ch][-1][1] for ch in channels]
        if self.args.top_c:
            top_col = self.args.top_c
        cm.set_over(top_col)

        if self.args.nan_c:
            cm.set_bad(self.args.nan_c)

        def make_ticklabels_invisible(ax):
            # Hide tick labels and drop all ticks on both axes.
            for tl in ax.get_xticklabels() + ax.get_yticklabels():
                tl.set_visible(False)
            ax.set_xticks([])
            ax.set_yticks([])

        def remove_splines(ax):
            # Make all four axes spines invisible.
            for v in ['right', 'left', 'top', 'bottom']:
                ax.spines[v].set_color('none')

        def shrink_labels(labels, n):
            # Abbreviate labels longer than n characters as "head [...] tail".
            # Floor division: ``n / 2`` is a float on Python 3 and cannot be
            # used as a slice index.
            half = n // 2
            shrink = lambda x: x[:half] + " [...] " + x[-half:]
            return [(shrink(str(l)) if len(str(l)) > n else l) for l in labels]

        # For very wide layouts (rat < 1/8), enlarge the buffer rows/columns.
        buf_space = 0.05
        min_buf = min([buf_space * 8, 8 * rat * buf_space])
        if min_buf < 0.05:
            buf_space /= min_buf / 0.05
        has_metadata = type(snames[0]) is tuple and len(snames[0]) > 1
        metadata_height = self.args.metadata_height if has_metadata else 0.000001
        gs = gridspec.GridSpec(6,
                               4,
                               width_ratios=[
                                   buf_space, buf_space * 2,
                                   .08 * self.args.fdend_width, 0.9
                               ],
                               height_ratios=[
                                   buf_space, buf_space * 2,
                                   .08 * self.args.sdend_height,
                                   metadata_height,
                                   self.args.metadata_separation, 0.9
                               ],
                               wspace=0.0,
                               hspace=0.0)

        # ``axisbg`` was removed in matplotlib 2.2; ``facecolor`` replaces it.
        ax_hm = plt.subplot(gs[23], facecolor=bottom_col)
        ax_metadata = plt.subplot(gs[15], facecolor=bottom_col)
        ax_hm_y2 = ax_hm.twinx()

        norm_f = matplotlib.colors.Normalize
        if self.args.log_scale:
            norm_f = matplotlib.colors.LogNorm
        elif self.args.sqrt_scale:
            norm_f = SqrtNorm

        maps, values, ndv = [], [], 0
        if has_metadata:
            # Give each distinct value of each metadata field its own integer
            # code. Codes keep increasing across fields, so every
            # (field, value) pair gets a distinct colour.
            for m in zip(*[list(s[1:]) for s in snames]):
                mmap = {val: ndv + i for i, val in enumerate(set(m))}
                values.append([mmap[v] for v in m])
                ndv += len(mmap)
                maps.append(mmap)
            mdmat = np.matrix(values)
            # Cycle the discrete palette until it covers every code.
            dcols = []
            while len(dcols) < ndv:
                dcols += self.dcols
            cmap = matplotlib.colors.ListedColormap(dcols[:ndv])
            bounds = [float(f) - 0.5 for f in range(ndv + 1)]
            ax_metadata.imshow(
                mdmat,
                interpolation='nearest',
                aspect='auto',
                extent=[0, self.nf, 0, self.ns],
                cmap=cmap,
                vmin=bounds[0],
                vmax=bounds[-1],
            )
            remove_splines(ax_metadata)
            ax_metadata_y2 = ax_metadata.twinx()
            ax_metadata.set_yticks([])
            ax_metadata_y2.set_ylim(0, len(self.fnames_meta))
            ax_metadata_y2.tick_params(length=0)
            ax_metadata_y2.set_yticks(np.arange(len(self.fnames_meta)) + 0.5)
            ax_metadata_y2.set_yticklabels(self.fnames_meta[::-1],
                                           va='center',
                                           size=self.args.flabel_size)
        else:
            ax_metadata.set_yticks([])

        ax_metadata.set_xticks([])

        # Put the user colour limits into the norm itself. matplotlib >= 3.5
        # rejects vmin/vmax passed together with norm=, and older versions
        # applied them to that same norm anyway.
        norm = norm_f(vmin=self.args.minv, vmax=self.args.maxv)
        im = ax_hm.imshow(
            self.numpy_matrix,
            interpolation='nearest',
            aspect='auto',
            extent=[0, self.nf, 0, self.ns],
            cmap=cm,
            norm=norm)

        ax_hm.set_xticks(np.arange(len(snames)) + 0.5)
        if not self.args.no_slabels:
            slabels = [s[0] for s in snames] if type(snames[0]) is tuple else snames
            snames_short = shrink_labels(slabels, self.args.max_slabel_len)
            ax_hm.set_xticklabels(snames_short,
                                  rotation=90,
                                  va='top',
                                  ha='center',
                                  size=self.args.slabel_size)
        else:
            ax_hm.set_xticklabels([])
        ax_hm_y2.set_ylim([0, self.ns])
        ax_hm_y2.set_yticks(np.arange(len(fnames)) + 0.5)
        if not self.args.no_flabels:
            fnames_short = shrink_labels(fnames, self.args.max_flabel_len)
            ax_hm_y2.set_yticklabels(fnames_short,
                                     va='center',
                                     size=self.args.flabel_size)
        else:
            ax_hm_y2.set_yticklabels([])
        ax_hm.set_yticks([])
        remove_splines(ax_hm)
        ax_hm.tick_params(length=0)
        ax_hm_y2.tick_params(length=0)

        ax_cm = plt.subplot(gs[3], facecolor='r', frameon=False)
        fig.colorbar(im,
                     ax_cm,
                     orientation='horizontal',
                     spacing='proportional' if self.args.sqrt_scale else
                     'uniform')

        if not self.args.no_sclustering:
            ax_den_top = plt.subplot(gs[11], facecolor='r', frameon=False)
            sph._plot_dendrogram(self.sdendrogram['icoord'],
                                 self.sdendrogram['dcoord'],
                                 self.sdendrogram['ivl'],
                                 self.ns + 1,
                                 self.nf + 1,
                                 1,
                                 'top',
                                 no_labels=True,
                                 color_list=self.sdendrogram['color_list'],
                                 ax=ax_den_top)
            ymax = max([max(a) for a in self.sdendrogram['dcoord']])
            ax_den_top.set_ylim([0, ymax])
            make_ticklabels_invisible(ax_den_top)
        if not self.args.no_fclustering:
            ax_den_right = plt.subplot(gs[22], facecolor='b', frameon=False)
            sph._plot_dendrogram(self.fdendrogram['icoord'],
                                 self.fdendrogram['dcoord'],
                                 self.fdendrogram['ivl'],
                                 self.ns + 1,
                                 self.nf + 1,
                                 1,
                                 'right',
                                 no_labels=True,
                                 color_list=self.fdendrogram['color_list'],
                                 ax=ax_den_right)
            xmax = max([max(a) for a in self.fdendrogram['dcoord']])
            # Reversed x-limits mirror the dendrogram so its leaves face the heatmap.
            ax_den_right.set_xlim([xmax, 0])
            make_ticklabels_invisible(ax_den_right)

        if not self.args.out:
            plt.show()
        else:
            fig.savefig(self.args.out, bbox_inches='tight', dpi=self.args.dpi)
            if maps:
                # ``maps`` is only non-empty when the metadata branch ran, and
                # that branch already reads ``self.fnames_meta``.
                self.make_legend(maps, self.fnames_meta, self.args.legend_file)