示例#1
0
文件: bands.py 项目: juijan/sisl
    def draw_bands(self, filtered_bands, spin_texture, **kwargs):
        """Draw the bands via the parent implementation, then decorate the axes.

        When ``spin_texture["show"]`` is truthy, a ``Normalize`` spanning the
        min/max of ``spin_texture["values"]`` and the requested colorscale are
        stored on ``self`` *before* delegating. This presumably lets the
        parent's drawing routines pick them up (TODO confirm). A colorbar
        built from ``self._colorbar`` is added afterwards.

        Finally, x ticks and tick labels are applied when ``filtered_bands``
        carries ``ticks`` / ``ticklabels`` attributes. The x range is set to
        the first and last k value, and the y range to the band min/max.
        """
        if spin_texture["show"]:
            spin_values = spin_texture["values"]
            self._spin_texture_norm = Normalize(spin_values.min(), spin_values.max())
            self._spin_texture_colorscale = spin_texture["colorscale"]

        super().draw_bands(
            filtered_bands=filtered_bands,
            spin_texture=spin_texture,
            **kwargs,
        )

        if spin_texture["show"]:
            # NOTE(review): self._colorbar is not set in this method; it is
            # assumed to be created during the parent's drawing -- confirm.
            self.figure.colorbar(self._colorbar)

        # Optional tick positions / labels carried by the bands object.
        for attr_name, apply_ticks in (("ticks", self.axes.set_xticks),
                                       ("ticklabels", self.axes.set_xticklabels)):
            tick_data = getattr(filtered_bands, attr_name, None)
            if tick_data is not None:
                apply_ticks(tick_data)

        first_k, last_k = filtered_bands.k.values[[0, -1]]
        self.axes.set_xlim(first_k, last_k)
        self.axes.set_ylim(filtered_bands.min(), filtered_bands.max())
示例#2
0
def create_cmap(values, colors):
    """Build a colormap whose colours are anchored at the given data values.

    ``values`` and ``colors`` are paired positionally. The values are mapped
    linearly onto [0, 1] (smallest -> 0, largest -> 1), and a
    ``LinearSegmentedColormap`` interpolates between the anchored colours.

    Returns
    -------
    (cmap, norm)
        The colormap and the ``Normalize`` instance that maps data values to
        colormap positions.
    """
    from matplotlib.pyplot import Normalize
    import matplotlib

    lowest, highest = min(values), max(values)
    norm = Normalize(lowest, highest)
    anchors = [(norm(value), color) for value, color in zip(values, colors)]
    cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", anchors)
    return cmap, norm
示例#3
0
def plot_frames_imshow(
    images,
    labels=None,
    nim=11,
    avg=50,
    interval=1,
    do1h=True,
    transpose=False,
    label_mapping=None,
):
    """Plot window-averaged channel differences of frames as an image grid.

    For each of the first ``nim`` samples, the frame axis (axis 1) is cut into
    consecutive windows of ``avg`` frames, and every ``interval``-th window is
    shown. Each panel displays (mean of channel 0) - (mean of channel 1) over
    the window, transposed. It uses a red-black-green colormap with a fixed
    [-0.1, 0.1] normalization shared by all panels, so colours are comparable.

    Parameters
    ----------
    images : array-like
        Indexed as ``images[sample, frame, channel, :, :]``. This is assumed
        to be a 5-D array with at least two channels (TODO confirm).
    labels : array-like, optional
        Panel-title labels: one-hot rows when ``do1h`` is true, otherwise
        integer categories. When None, the sample index is used.
    nim : int
        Number of samples to plot.
    avg : int
        Window length in frames; clamped to the number of frames.
    interval : int
        Step between displayed windows.
    do1h : bool
        Whether ``labels`` is one-hot encoded.
    transpose : bool
        False: windows on grid rows, samples on columns. True: swapped.
    label_mapping : mapping or sequence, optional
        Maps an integer category to a title. Titles are drawn only on the
        first window's panel of each sample.
    """
    import matplotlib.gridspec as gridspec
    import matplotlib.pyplot as plt
    from matplotlib.colors import LinearSegmentedColormap, Normalize

    cmap = LinearSegmentedColormap.from_list("name", ["red", "black", "green"])
    # One fixed normalization for every panel (loop-invariant).
    norm = Normalize(-0.1, 0.1)

    n_frames = images.shape[1]
    avg = min(avg, n_frames)
    # At least one window, even when fewer than ``avg`` frames exist.
    windows = range(0, max(n_frames // avg, 1), interval)

    plt.figure(figsize=[nim + 2, 16])
    if transpose:
        grid = gridspec.GridSpec(nim, len(windows))
    else:
        grid = gridspec.GridSpec(len(windows), nim)
    plt.subplots_adjust(left=0, bottom=0, right=1, top=0.95, wspace=0.0, hspace=0.04)

    if labels is None:
        categories = range(len(images))
    elif do1h:
        categories = labels.argmax(axis=1)
    else:
        categories = labels

    for j in range(nim):
        for e, i in enumerate(windows):
            plt.subplot(grid[j, e] if transpose else grid[e, j])
            start, stop = i * avg, i * avg + avg
            plt.imshow(
                images[j, start:stop, 0, :, :].mean(axis=0).T
                - images[j, start:stop, 1, :, :].mean(axis=0).T,
                cmap=cmap,
                norm=norm,
            )
            plt.xticks([])

            if i == 0 and label_mapping is not None:
                plt.title(label_mapping[int(categories[j])], fontsize=10)
            plt.yticks([])
 def __init__(self, iterates, cmap_name=None):
     """
     Remember an iterable together with the colour machinery used to pair
     its items with colours. Per the original description, iteration yields
     (object, color) tuples; that logic lives elsewhere in the class.

     :param iterates: some iterable whose items will be coloured
     :param cmap_name: colormap name, passed unchanged to ``get_cmap``
     """
     self.iterates = iterates
     # Normalize() is built without vmin/vmax; assuming matplotlib's
     # Normalize, it autoscales on first use -- confirm.
     self.cmap, self.norm = get_cmap(cmap_name), Normalize()
示例#5
0
def _draw_distribution_panel(graph, plot_layout, position, panel_title,
                             node_values, cmap, color_vals, min_val, max_val):
    """Draw one of the two stacked panels: edges, coloured nodes and a colorbar."""
    ax = subplot(2, 1, position)
    title(panel_title)
    # ``color_vals`` belongs to no Axes, so name the Axes to attach to;
    # recent Matplotlib raises instead of guessing.
    colorbar(color_vals, ax=ax)

    draw_networkx_edges(graph, plot_layout, alpha=.3)
    draw_networkx_nodes(graph,
                        plot_layout,
                        node_size=100,
                        edgecolors='k',
                        node_color=node_values,
                        cmap=cmap,
                        vmin=min_val,
                        vmax=max_val)
    axis('off')  # Disable axis


def plot_optimized_network(network, blocking=True, save_plot=True):
    """Plot the reference distribution and the optimized ball distribution.

    Both panels share one node layout and one grey colour scale, running
    from 0 to the largest value found in either distribution, so they are
    directly comparable.

    :param network: object providing ``network_plot`` (the graph),
        ``ref_distribution`` and ``get_ball_distribution()``
    :param blocking: passed to ``show(block=...)``
    :param save_plot: when true, save to ``../results/optimized_network.png``
    """
    _title = 'Optimized network vs. non-optimized'
    figure(_title)

    graph = network.network_plot
    # One layout for both panels so nodes sit in the same places.
    plot_layout = spring_layout(graph)

    balls = network.get_ball_distribution()

    print(network.ref_distribution)
    print(balls)

    min_val = 0
    max_val = max(np_max(network.ref_distribution), np_max(balls))

    cmap = cm.Greys
    color_vals = cm.ScalarMappable(cmap=cmap,
                                   norm=Normalize(vmin=min_val, vmax=max_val))
    # Public replacement for the old ``_A = []`` hack: a stand-alone mappable
    # needs an (empty) array before it can back a colorbar.
    color_vals.set_array([])

    _draw_distribution_panel(graph, plot_layout, 1, 'Initial network',
                             network.ref_distribution, cmap, color_vals,
                             min_val, max_val)
    _draw_distribution_panel(graph, plot_layout, 2, 'Optimized network',
                             balls, cmap, color_vals, min_val, max_val)
    draw()

    if save_plot:
        savefig('../results/optimized_network.png')
    show(block=blocking)  # Open matplotlib window
示例#6
0
def test_center_cmap():
    """Test centering of colormap."""
    from matplotlib.colors import LinearSegmentedColormap, Normalize
    from matplotlib.pyplot import get_cmap

    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9;
    # ``pyplot.get_cmap`` is available in old and new releases alike.
    cmap = center_cmap(get_cmap("RdBu"), -5, 10)

    assert isinstance(cmap, LinearSegmentedColormap)

    # get new colors for values -5 (red), 0 (white), and 10 (blue)
    new_colors = cmap(Normalize(-5, 10)([-5, 0, 10]))
    # get original colors for 0 (red), 0.5 (white), and 1 (blue)
    reference = cm.RdBu([0., 0.5, 1.])
    assert_allclose(new_colors, reference)
    # new and old colors at 0.5 must be different
    assert not np.allclose(cmap(0.5), reference[1])
示例#7
0
def plot_network(network,
                 blocking=True,
                 netx_plot=False,
                 size=fig_size,
                 weights=None,
                 file_name=None,
                 plot_edges=False,
                 alph=.05):
    """Draw a network with a Kamada-Kawai layout, optionally coloured by weights.

    :param network: a graph when ``netx_plot`` is true, otherwise an object
        whose ``network_plot`` attribute is the graph
    :param blocking: passed to ``show(block=...)``
    :param netx_plot: treat ``network`` itself as the graph
    :param size: figure size
    :param weights: optional per-node values. Nodes are coloured with the
        coolwarm colormap over ``[min(weights), max(weights)]``; zero-weight
        nodes are drawn smaller; a matching colorbar is added.
    :param file_name: when given, save the figure there (tight bounding box)
    :param plot_edges: draw edges (with transparency ``alph``)
    :param alph: edge alpha
    """
    fig = figure(figsize=size)
    rcParams.update({
        'font.size': plot_font_size,
        'mathtext.default': 'regular'
    })
    ax = fig.gca()
    ax.axis('off')  # Disable axis

    graph = network if netx_plot else network.network_plot
    plot_layout = kamada_kawai_layout(graph)
    # ``cm.get_cmap`` was removed in Matplotlib 3.9; the attribute works everywhere.
    cmap = cm.coolwarm

    sizes, node_colors = 80, 'w'
    min_val = max_val = None  # None lets the node drawing autoscale
    if weights is not None:
        sizes = [50 if weight == 0 else 80 for weight in weights]
        node_colors = weights
        # Shared range for node colours and colorbar, so the colorbar
        # actually describes the colours drawn.
        min_val, max_val = min(weights), max(weights)

    if plot_edges:
        draw_networkx_edges(graph, plot_layout, alpha=alph)
    draw_networkx_nodes(graph,
                        plot_layout,
                        node_size=sizes,
                        linewidths=.5,
                        edgecolors='k',
                        node_color=node_colors,
                        cmap=cmap,
                        vmin=min_val,
                        vmax=max_val)
    draw()

    if weights is not None:
        mappable = cm.ScalarMappable(cmap=cmap,
                                     norm=Normalize(vmin=min_val, vmax=max_val))
        mappable.set_array([])
        # The mappable belongs to no Axes; attach the colorbar explicitly.
        colorbar(mappable, ax=ax)

    if file_name is not None:
        savefig(file_name, bbox_inches='tight', pad_inches=0)
    show(block=blocking)  # Open matplotlib window
示例#8
0
def apply_color_map(colors, image):
    """
    Colour a 2-D image with a colormap built from ``colors``.

    Input:

        colors - list of colors used to build the color map
                 (through ``get_color_map``)

        image - image to apply color map to with shape (n, n)

    Output:

        color_image - RGB image with shape (n, n, 3); the alpha channel
                      returned by the colormap is dropped

    Pixel values are rescaled linearly from [0, len(colors)] to [0, 1]
    before the colormap lookup.
    """
    colormap = get_color_map(colors)
    scale = Normalize(vmin=0, vmax=len(colors))
    rgba = colormap(scale(image))
    return rgba[:, :, :3]
def scatter_xyc(points, smooth=0, div=10, ax=None, **options):
    """
    Draws a 2D graph (X,Y, color), the color is chosen based on a value *f(x,y)*
    The function requires :epkg:`matplotlib` and :epkg:`scipy`.

    @param      points      (x,y, z=f(x,y) )
    @param      smooth      applies n times a smoothing I * M (convolutional);
                            currently not used by the implementation
    @param      div         number of divisions for axis; currently not used
                            by the implementation
    @param      ax          axes to draw on; a new figure is created when None
    @param      options     others options: xlabel, ylabel, title, figsize (if ax is None)
    @return                 fig, ax (fig is None if ax was sent to the function)

    Black contour lines and filled contours (rainbow colormap, range
    ``[-max|z|, max|z|]``) are drawn from a triangulation of the points. A
    colorbar is attached to *ax*, and the points are overlaid as small blue
    dots.

    .. plot::
        :include-source:

        import random
        def generate_gauss(x, y, sigma, N=1000):
            res = []
            for i in range(N):
                u = random.gauss(0, 1)
                a = sigma * u + x
                b = sigma * random.gauss(0, 1) + y + u
                res.append((a, b))
            return res
        def f(a, b):
            return (a ** 2 + b ** 2) ** 0.5
        nuage1 = generate_gauss(0, 0, 3)
        nuage2 = generate_gauss(3, 4, 2)
        nuage = [(a, b, f(a, b)) for a, b in nuage1] + [(a, b, f(a, b)) for a, b in nuage2]
        import matplotlib.pyplot as plt
        from ensae_teaching_cs.helpers.matplotlib_helper_xyz import scatter_xyc
        fig, ax = scatter_xyc(nuage, title="example with random observations")
        plt.show()

    The error ``ValueError: Unknown projection '3d'`` is raised when the line
    ``from mpl_toolkits.mplot3d import Axes3D`` is missing.
    """
    # Imported unconditionally: the colormap below needs pyplot even when the
    # caller passes *ax* (this used to raise NameError in that case).
    import matplotlib.pyplot as plt

    if ax is None:
        fig, ax = plt.subplots(figsize=options.get('figsize', None))
    else:
        fig = None

    x = [p[0] for p in points]
    y = [p[1] for p in points]
    z = [p[2] for p in points]

    tri = Triangulation(x, y)

    # Colour range symmetric around zero.
    zmax = numpy.abs(z).max()

    # Draw on *ax* explicitly: pyplot's "current axes" need not be *ax*
    # when the caller supplies one.
    ax.tricontour(tri, z, 15, linewidths=0.5, colors='k')
    filled = ax.tricontourf(tri,
                            z,
                            15,
                            cmap=plt.cm.rainbow,
                            norm=Normalize(vmax=zmax, vmin=-zmax))
    ax.figure.colorbar(filled, ax=ax)
    ax.scatter(x, y, c='b', s=5, zorder=10)
    ax.set_xlim(min(x), max(x))
    ax.set_ylim(min(y), max(y))

    if "xlabel" in options:
        ax.set_xlabel(options["xlabel"])
    if "ylabel" in options:
        ax.set_ylabel(options["ylabel"])
    if "title" in options:
        ax.set_title(options["title"])
    return fig, ax
示例#10
0
文件: qa.py 项目: wkopp/htseq
def main():
    """Command-line entry point: quality-assessment plots for a read file.

    Parses the command line and reads the file with the matching HTSeq
    reader. It counts called bases and base-call quality scores per read
    position, separately for aligned and non-aligned reads unless
    ``--nosplit`` is given or the input is not an alignment file. The counts
    are normalized and the plots are written to a PDF file.

    Exits with status 1 when matplotlib is missing, when run without
    arguments, or when the number of positional arguments is not one.
    """
    try:
        import matplotlib
    except ImportError:
        sys.stderr.write("This script needs the 'matplotlib' library, which ")
        sys.stderr.write("was not found. Please install it.\n")
        # Nothing below can work without matplotlib; stop instead of failing
        # later with a NameError.
        sys.exit(1)
    matplotlib.use('PDF')
    from matplotlib import pyplot

    # Matplotlib <1.5 uses normalize, so this block will be deprecated
    try:
        from matplotlib.pyplot import Normalize
    except ImportError:
        from matplotlib.pyplot import normalize as Normalize

    # **** Parse command line ****

    optParser = optparse.OptionParser(
        usage="%prog [options] read_file",
        description=(
            "This script take a file with high-throughput sequencing reads " +
            "(supported formats: SAM, Solexa _export.txt, FASTQ, Solexa " +
            "_sequence.txt) and performs a simply quality assessment by " +
            "producing plots showing the distribution of called bases and " +
            "base-call quality scores by position within the reads. The " +
            "plots are output as a PDF file."
        ),
        epilog=(
            "Written by Simon Anders ([email protected]), European Molecular Biology " +
            " Laboratory (EMBL). (c) 2010. Released under the terms of the GNU General " +
            " Public License v3. Part of the 'HTSeq' framework, version %s." % HTSeq.__version__
        ),
    )
    optParser.add_option(
        "-t", "--type", type="choice", dest="type",
        choices=("sam", "bam", "solexa-export", "fastq", "solexa-fastq"),
        default="sam", help="type of read_file (one of: sam [default], bam, " +
        "solexa-export, fastq, solexa-fastq)")
    optParser.add_option(
        "-o", "--outfile", type="string", dest="outfile",
        help="output filename (default is <read_file>.pdf)")
    optParser.add_option(
        "-r", "--readlength", type="int", dest="readlen",
        help="the maximum read length (when not specified, the script guesses from the file")
    optParser.add_option(
        "-g", "--gamma", type="float", dest="gamma",
        default=0.3,
        help="the gamma factor for the contrast adjustment of the quality score plot")
    optParser.add_option(
        "-n", "--nosplit", action="store_true", dest="nosplit",
        help="do not split reads in unaligned and aligned ones")
    optParser.add_option(
        "-m", "--maxqual", type="int", dest="maxqual", default=41,
        help="the maximum quality score that appears in the data (default: 41)")

    if len(sys.argv) == 1:
        optParser.print_help()
        sys.exit(1)

    (opts, args) = optParser.parse_args()

    if len(args) != 1:
        sys.stderr.write(sys.argv[0] + ": Error: Please provide one argument (the read_file).\n")
        sys.stderr.write("  Call with '-h' to get usage information.\n")
        sys.exit(1)

    readfilename = args[0]

    if opts.type == "sam":
        readfile = HTSeq.SAM_Reader(readfilename)
        isAlnmntFile = True
    elif opts.type == "bam":
        readfile = HTSeq.BAM_Reader(readfilename)
        isAlnmntFile = True
    elif opts.type == "solexa-export":
        readfile = HTSeq.SolexaExportReader(readfilename)
        isAlnmntFile = True
    elif opts.type == "fastq":
        readfile = HTSeq.FastqReader(readfilename)
        isAlnmntFile = False
    elif opts.type == "solexa-fastq":
        readfile = HTSeq.FastqReader(readfilename, "solexa")
        isAlnmntFile = False
    else:
        # Unreachable because optparse restricts --type to the choices above.
        # (The former ``sys.error`` does not exist.)
        sys.exit("Oops.")

    # Separate aligned / non-aligned columns only for alignment input.
    twoColumns = isAlnmntFile and not opts.nosplit

    if opts.outfile is None:
        outfilename = os.path.basename(readfilename) + ".pdf"
    else:
        outfilename = opts.outfile

    # **** Get read length ****

    if opts.readlen is not None:
        readlen = opts.readlen
    else:
        # Guess from the longest of the first 10000 records.
        readlen = 0
        if isAlnmntFile:
            reads = (a.read for a in readfile)
        else:
            reads = readfile
        for r in islice(reads, 10000):
            if len(r) > readlen:
                readlen = len(r)

    max_qual = opts.maxqual
    gamma = opts.gamma

    # **** Initialize count arrays ****
    # ``numpy.int`` was only an alias for the builtin ``int`` and has been
    # removed from NumPy (1.24); ``int`` gives the same dtype.

    base_arr_U = numpy.zeros((readlen, 5), int)
    qual_arr_U = numpy.zeros((readlen, max_qual + 1), int)
    if twoColumns:
        base_arr_A = numpy.zeros((readlen, 5), int)
        qual_arr_A = numpy.zeros((readlen, max_qual + 1), int)

    # **** Main counting loop ****
    # Every record must be counted inside the loop. Previously the ``else``
    # was dedented into a for/else, so only the last record was counted.

    i = 0
    try:
        for a in readfile:
            if isAlnmntFile:
                r = a.read
            else:
                r = a
            if twoColumns and (isAlnmntFile and a.aligned):
                r.add_bases_to_count_array(base_arr_A)
                r.add_qual_to_count_array(qual_arr_A)
            else:
                r.add_bases_to_count_array(base_arr_U)
                r.add_qual_to_count_array(qual_arr_U)
            i += 1
            if (i % 200000) == 0:
                print(i, "reads processed")
    except Exception:
        # Report where in the input the failure happened, then propagate.
        sys.stderr.write("Error occured in: %s\n" %
                         readfile.get_line_number_string())
        raise
    print(i, "reads processed")

    # **** Normalize result ****

    def norm_by_pos(arr):
        # Divide each row (read position) by its own total, so the base
        # proportions at every position sum to 1.
        arr = numpy.array(arr, float)
        arr_n = (arr.T / arr.sum(1)).T
        # Cells with a zero count (including all-zero rows, which give 0/0)
        # are forced to 0.
        arr_n[arr == 0] = 0
        return arr_n

    def norm_by_start(arr):
        # Divide every cell by the total count at the first position.
        arr = numpy.array(arr, float)
        arr_n = (arr.T / arr.sum(1)[0]).T
        arr_n[arr == 0] = 0
        return arr_n

    base_arr_U_n = norm_by_pos(base_arr_U)
    qual_arr_U_n = norm_by_start(qual_arr_U)
    nreads_U = base_arr_U[0, :].sum()
    if twoColumns:
        base_arr_A_n = norm_by_pos(base_arr_A)
        qual_arr_A_n = norm_by_start(qual_arr_A)
        nreads_A = base_arr_A[0, :].sum()

    # **** Make plot ****

    def plot_bases(arr):
        # One line per base column (A, C, G, T, N) across read positions.
        xg = numpy.arange(readlen)
        pyplot.plot(xg, arr[:, 0], marker='.', color='red')
        pyplot.plot(xg, arr[:, 1], marker='.', color='darkgreen')
        pyplot.plot(xg, arr[:, 2], marker='.', color='lightgreen')
        pyplot.plot(xg, arr[:, 3], marker='.', color='orange')
        pyplot.plot(xg, arr[:, 4], marker='.', color='grey')
        pyplot.axis((0, readlen - 1, 0, 1))
        pyplot.text(readlen * .70, .9, "A", color="red")
        pyplot.text(readlen * .75, .9, "C", color="darkgreen")
        pyplot.text(readlen * .80, .9, "G", color="lightgreen")
        pyplot.text(readlen * .85, .9, "T", color="orange")
        pyplot.text(readlen * .90, .9, "N", color="grey")

    pyplot.figure()
    pyplot.subplots_adjust(top=.85)
    pyplot.suptitle(os.path.basename(readfilename), fontweight='bold')

    if twoColumns:

        pyplot.subplot(221)
        plot_bases(base_arr_U_n)
        pyplot.ylabel("proportion of base")
        pyplot.title("non-aligned reads\n%.0f%% (%.3f million)" %
                     (100. * nreads_U / (nreads_U + nreads_A), nreads_U / 1e6))

        pyplot.subplot(222)
        plot_bases(base_arr_A_n)
        pyplot.title("aligned reads\n%.0f%% (%.3f million)" %
                     (100. * nreads_A / (nreads_U + nreads_A), nreads_A / 1e6))

        # Gamma < 1 brightens low proportions in the quality heat maps.
        pyplot.subplot(223)
        pyplot.pcolor(qual_arr_U_n.T ** gamma, cmap=pyplot.cm.Greens,
                      norm=Normalize(0, 1))
        pyplot.axis((0, readlen - 1, 0, max_qual + 1))
        pyplot.xlabel("position in read")
        pyplot.ylabel("base-call quality score")

        pyplot.subplot(224)
        pyplot.pcolor(qual_arr_A_n.T ** gamma, cmap=pyplot.cm.Greens,
                      norm=Normalize(0, 1))
        pyplot.axis((0, readlen - 1, 0, max_qual + 1))
        pyplot.xlabel("position in read")

    else:

        pyplot.subplot(211)
        plot_bases(base_arr_U_n)
        pyplot.ylabel("proportion of base")
        pyplot.title("%.3f million reads" % (nreads_U / 1e6))

        pyplot.subplot(212)
        pyplot.pcolor(qual_arr_U_n.T ** gamma, cmap=pyplot.cm.Greens,
                      norm=Normalize(0, 1))
        pyplot.axis((0, readlen - 1, 0, max_qual + 1))
        pyplot.xlabel("position in read")
        pyplot.ylabel("base-call quality score")

    pyplot.savefig(outfilename)
示例#11
0
def plot(
        result,
        readfilename,
        outfile,
        max_qual,
        gamma,
        primary_only=False,
        ):
    """Render quality-assessment plots from precomputed counts into a file.

    Parameters
    ----------
    result : dict
        Precomputed results. The keys read are 'isAlnmntFile', 'readlen',
        'twoColumns', 'base_arr_U_n', 'qual_arr_U_n' and 'nreads_U'. When
        'twoColumns' is true, 'base_arr_A_n', 'qual_arr_A_n' and 'nreads_A'
        are read as well.
    readfilename : str
        Input file name; its basename is the figure title and the default
        output name.
    outfile : str or None
        Output path; defaults to ``<basename(readfilename)>.pdf``.
    max_qual : int
        The quality heat maps span y in ``[0, max_qual + 1]``.
    gamma : float
        Exponent applied to the normalized quality counts (contrast).
    primary_only : bool
        Only changes the wording of titles ("aligned reads"/"reads" versus
        "alignments").

    The PDF backend is selected while drawing, and the previous backend is
    restored afterwards. The figure is closed before returning.
    """
    if outfile is None:
        outfilename = os.path.basename(readfilename) + ".pdf"
    else:
        outfilename = outfile

    isAlnmntFile = result['isAlnmntFile']
    readlen = result['readlen']
    twoColumns = result['twoColumns']

    base_arr_U_n = result['base_arr_U_n']
    qual_arr_U_n = result['qual_arr_U_n']
    nreads_U = result['nreads_U']

    if twoColumns:
        base_arr_A_n = result['base_arr_A_n']
        qual_arr_A_n = result['qual_arr_A_n']
        nreads_A = result['nreads_A']

    def plot_bases(arr, ax):
        # One line per base column (A, C, G, T, N) across read positions.
        xg = np.arange(readlen)
        ax.plot(xg, arr[:, 0], marker='.', color='red')
        ax.plot(xg, arr[:, 1], marker='.', color='darkgreen')
        ax.plot(xg, arr[:, 2], marker='.', color='lightgreen')
        ax.plot(xg, arr[:, 3], marker='.', color='orange')
        ax.plot(xg, arr[:, 4], marker='.', color='grey')
        ax.set_xlim(0, readlen-1)
        ax.set_ylim(0, 1)
        ax.text(readlen*.70, .9, "A", color="red")
        ax.text(readlen*.75, .9, "C", color="darkgreen")
        ax.text(readlen*.80, .9, "G", color="lightgreen")
        ax.text(readlen*.85, .9, "T", color="orange")
        ax.text(readlen*.90, .9, "N", color="grey")

    def plot_quals(arr, ax):
        # Heat map of gamma-corrected quality-score proportions per position.
        ax.pcolor(
                arr.T ** gamma,
                cmap=plt.cm.Greens,
                norm=Normalize(0, 1))
        ax.set_xlim(0, readlen-1)
        ax.set_ylim(0, max_qual+1)
        ax.set_xlabel("position in read")

    cur_backend = matplotlib.get_backend()
    fig = None

    try:
        matplotlib.use('PDF')

        fig = plt.figure()
        fig.subplots_adjust(top=.85)
        fig.suptitle(os.path.basename(readfilename), fontweight='bold')

        if twoColumns:
            # Guard against 0/0 when there are no reads at all.
            nreads_total = nreads_U + nreads_A
            frac_U = 1.0 * nreads_U / nreads_total if nreads_total else 0.0
            frac_A = 1.0 * nreads_A / nreads_total if nreads_total else 0.0

            ax = fig.add_subplot(221)
            plot_bases(base_arr_U_n, ax)
            ax.set_ylabel("proportion of base")
            ax.set_title(
                    "non-aligned reads\n{:.0%} ({:.4f} million)".format(
                    frac_U,
                    1.0 * nreads_U / 1e6,
                    ))

            ax2 = fig.add_subplot(222)
            plot_bases(base_arr_A_n, ax2)
            ax2.set_title(
                    "{:}\n{:.0%} ({:.4f} million)".format(
                        'aligned reads' if primary_only else 'alignments',
                        frac_A,
                        1.0 * nreads_A / 1e6,
                    ))

            ax3 = fig.add_subplot(223)
            plot_quals(qual_arr_U_n, ax3)
            ax3.set_ylabel("base-call quality score")

            ax4 = fig.add_subplot(224)
            plot_quals(qual_arr_A_n, ax4)

        else:

            ax = fig.add_subplot(211)
            plot_bases(base_arr_U_n, ax)
            ax.set_ylabel("proportion of base")
            ax.set_title("{:.3f} million {:}".format(
                1.0 * nreads_U / 1e6,
                'reads' if (not isAlnmntFile) or primary_only else 'alignments',
                ))

            ax2 = fig.add_subplot(212)
            plot_quals(qual_arr_U_n, ax2)
            ax2.set_ylabel("base-call quality score")

        fig.savefig(outfilename)

    finally:
        # pyplot keeps every figure alive until closed. When the backend was
        # already PDF, ``use`` below does not switch backends, so without this
        # close, figures would pile up across calls.
        if fig is not None:
            plt.close(fig)
        matplotlib.use(cur_backend)
示例#12
0
    def draw_dendrogram(self, ax, pairs, values, labels, lw=20., alpha=0.4, cmap='viridis'):
        """Draw a dendrogram of successive merges onto ``ax`` and return ``ax``.

        ``pairs`` is a flat sequence: merge ``j`` joins ``pairs[2*j]`` and
        ``pairs[2*j + 1]``. Its height is ``log2(1 + values[j])``; the y axis
        is labelled 'distance', so ``values`` are presumably merge distances
        (TODO confirm).

        A merge whose two endpoints share the same positive label in
        ``labels`` is drawn in that label's colour from
        ``self.get_dendro_colors(labels)``; other merges are drawn gray.

        NOTE(review): ``lw`` and ``alpha`` are not used in this body, and
        ``cmap`` only feeds a ScalarMappable that is never used.

        Raises ImportError if matplotlib cannot be imported, and ValueError if
        any index in ``pairs`` is negative.
        """
        try:
            # NOTE(review): ``mc`` and ``Arrow`` are imported but never used here.
            from matplotlib import collections as mc
            from matplotlib.pyplot import Arrow
            from matplotlib.pyplot import Normalize
            from matplotlib.pyplot import cm
        except ImportError:
            raise ImportError('You must install the matplotlib library to plot the minimum spanning tree.')

        min_index, max_index = min(pairs), max(pairs)
        if min_index < 0:
            raise ValueError('Indices should be non-negative')

        # Number of points, assuming one pair per merge, i.e.
        # len(pairs) == 2 * (n_points - 1).
        size = int(len(pairs) / 2 + 1)

        # Label space must cover every point index; two extra slots of
        # headroom are added.
        union_size = size
        if max_index > union_size - 1:
            union_size = max_index + 1
        union_size += 2

        # Union-find parent array. Labels >= union_size + 1 are assigned to
        # newly merged clusters.
        # NOTE(review): ``sz`` is only referenced in commented-out code below.
        uf, sz = np.zeros(2 * union_size, dtype=int), np.ones(union_size)
        next_label = union_size + 1
        # Left/right neighbour links (l, r); every index starts linked to itself.
        l, r = np.arange(0, 2*union_size), np.arange(0, 2*union_size)

        # NOTE(review): duplicate of the assignment a few lines above.
        next_label = union_size + 1
        # Pass 1: replay merges to build the left-to-right leaf order and
        # record each new cluster's outermost members.
        for j in range(0, size - 1):
            a, b = pairs[2 * j], pairs[2 * j + 1]

            # Place the first cluster on the left of the second: walk to the
            # right end of a's chain and the left end of b's chain, then link.
            aa = a
            while aa != r[aa]:
                aa = r[aa]
            bb = b
            while bb != l[bb]:
                bb = l[bb]
            l[bb] = aa
            r[aa] = bb # linking


            # Walk to the left end of a's chain and the right end of b's chain.
            aa = a
            while aa != l[aa]:
                aa = l[aa]
            bb = b
            while bb != r[bb]:
                bb = r[bb]
            # Store those ends as the new cluster's left/right borders.
            l[next_label] = aa # marking the borders
            r[next_label] = bb

            # Union: both current roots now point at the new cluster label.
            aa, bb = self.fast_find(uf, a), self.fast_find(uf, b)
            uf[aa] = uf[bb] = next_label

            # i = next_label - union_size
            # a2 = (uf[a] != 0) * (aa - union_size)
            # b2 = (uf[b] != 0) * (bb - union_size)
            # na, nb = sz[a2], sz[b2]
            # sz[i] = na + nb

            next_label += 1

        # x coordinate per label, computed by a helper not shown here; 200. is
        # passed through unchanged (meaning not visible here -- confirm).
        x_arr = self.arrange_nodes_on_x_axis(uf, union_size, l, r, 200.)

        # NOTE(review): ``sm`` is never used after this point, and
        # ``set_array`` receives an int rather than an array.
        norm = len(np.unique(pairs))
        sm = cm.ScalarMappable(cmap=cmap,
                                   norm=Normalize(0, norm))
        sm.set_array(norm)

        colors = self.get_dendro_colors(labels)
        # Height (y) of each internal node, keyed by cluster label.
        heights = {}
        # Pass 2: reset the union-find and replay the merges to draw them.
        uf.fill(0)
        next_label = union_size + 1
        for j in range(0, size - 1):
            v = np.log2(1. + values[j]) # log-scaled merge height
            heights[next_label] = v

            a, b = pairs[2 * j], pairs[2 * j + 1]

            # i = next_label - union_size
            aa, bb = self.fast_find(uf, a), self.fast_find(uf, b)
            # The new node sits at the midpoint of x_arr[r[aa]] and x_arr[l[bb]].
            x_arr[next_label] = (x_arr[r[aa]] + x_arr[l[bb]])/2.
            uf[aa] = uf[bb] = next_label
            next_label += 1

            # a = (uf[a] != 0) * (aa - union_size)
            # b = (uf[b] != 0) * (bb - union_size)
            # na, nb = sz[a], sz[b]
            # sz[i] = na + nb

            # Child heights: internal nodes use their stored height; labels
            # not in ``heights`` (leaves) start at 0.
            ha, hb = 0, 0
            xa, xb = x_arr[aa], x_arr[bb]
            if aa in heights:
                ha = heights[aa]
            if bb in heights:
                hb = heights[bb]

            # Colour only merges inside a single positively labelled group.
            c = 'gray'
            if labels[a] == labels[b] and labels[a] > 0:
                c = colors[labels[a]]

            # Two vertical legs up to height v, joined by a horizontal bar.
            ax.plot([xa, xa], [ha, v], color=c)
            ax.plot([xb, xb], [hb, v], color=c)
            ax.plot([xa, xb], [v, v], color=c)

        # Keep only the left spine (y axis) visible.
        ax.set_xticks([])
        for side in ('right', 'top', 'bottom'):
            ax.spines[side].set_visible(False)
        ax.set_ylabel('distance')

        # line_collection.set_array(self._mst[:, 2].T)
        return ax