Пример #1
0
def matrix_base_mpl(matrix,
                    positions,
                    substitutions,
                    conservation=None,
                    secondary_structure=None,
                    wildtype_sequence=None,
                    min_value=None,
                    max_value=None,
                    ax=None,
                    colormap=plt.cm.RdBu_r,
                    colormap_conservation=plt.cm.Oranges,
                    na_color="#bbbbbb",
                    title=None,
                    position_label_size=8,
                    substitution_label_size=8,
                    show_colorbar=True,
                    colorbar_indicate_bounds=False,
                    show_wt_char=True,
                    label_filter=None,
                    secondary_structure_style=None):
    """
    Matplotlib-based mutation matrix plotting. This is the base plotting function,
    see plot_mutation_matrix() for more convenient access.

    Besides the main matrix, the plot always contains a track with the mean
    effect per position (below the matrix) and a track with the mean effect
    per substitution (to the right of the matrix). Optional tracks show
    conservation (bottom) and a secondary structure cartoon (top).

    Parameters
    ----------
    matrix : np.array(float)
        2D numpy array with values for individual single mutations
        (first axis: position, second axis: substitution)
    positions : list(int) or list(str) or list(tuple)
        List of positions along x-axis of matrix
        (length has to agree with first dimension of matrix).
        Tuple entries are rendered by joining their elements; 2-tuples
        are rendered as "<first> <wild-type symbol> <second>" if the
        wild-type symbol is displayed.
    substitutions : list(str)
        List of substitutions along y-axis of matrix
        (length has to agree with second dimension of matrix)
    conservation : list(float) or np.array(float), optional (default: None)
        Positional conservation along sequence. Values must range
        between 0 (not conserved) and 1 (fully conserved). If given,
        will plot conservation along bottom of mutation matrix.
        Missing values (NaN or None) are drawn in na_color.
    secondary_structure : str or list(str), optional (default: None)
        Secondary structure for each position along sequence. If given,
        will draw secondary structure cartoon on top of matrix.
    wildtype_sequence : str or list(str), optional (default: None)
        Sequence of wild-type symbols. If given, will indicate wild-type
        entries in matrix with a dot. Entries that are None or the empty
        string are not marked.
    min_value : float, optional (default: None)
        Threshold colormap at this minimum value. If None, defaults to
        minimum value in matrix; if max_value is also None, defaults to
        -max(abs(matrix))
    max_value : float, optional (default: None)
        Threshold colormap at this maximum value. If None, defaults to
        maximum value in matrix; if min_value is also None, defaults to
        max(abs(matrix))
    ax : Matplotlib axes object, optional (default: None)
        Draw mutation matrix on this axis. If None, new figure and axis
        will be created.
    colormap : matplotlib colormap object, optional (default: plt.cm.RdBu_r)
        Maps mutation effects to colors of matrix cells.
    colormap_conservation: matplotlib colormap object, optional (default: plt.cm.Oranges)
        Maps sequence conservation to colors of conservation vector plot.
    na_color : str, optional (default: "#bbbbbb")
        Color for missing values in matrix
    title : str, optional (default: None)
        If given, set title of plot to this value.
    position_label_size : int, optional (default: 8)
        Font size of x-axis labels.
    substitution_label_size : int, optional (default: 8)
        Font size of y-axis labels.
    show_colorbar : bool, optional (default: True)
        If True, show colorbar next to matrix. The colorbar is attached
        to the figure that owns ax.
    colorbar_indicate_bounds : bool, optional (default: False)
        If True, add greater-than/less-than signs to limits of colorbar
        to indicate that colors were thresholded at min_value/max_value
    show_wt_char : bool, optional (default: True)
        Display wild-type symbol in axis labels
    label_filter : function, optional (default: None)
        Function with one argument (the position entry) that determines
        if a certain position label will be printed
        (if label_filter(pos)==True) or not.
    secondary_structure_style : dict, optional (default: None)
        Pass on as **kwargs to evcouplings.visualize.pairs.secondary_structure_cartoon
        to determine appearance of secondary structure cartoon.

    Returns
    -------
    ax : Matplotlib axes object
        Axes on which mutation matrix was drawn

    Raises
    ------
    ValueError
        If a (non-empty) wild-type symbol does not occur in substitutions
    """
    LINEWIDTH = 0.0
    LABEL_X_OFFSET = 0.55
    LABEL_Y_OFFSET = 0.45

    def _draw_rect(x_range, y_range, linewidth):
        # unfilled frame around the area spanned by the given coordinates;
        # "ax" is looked up at call time, i.e. after it has been created below
        r = plt.Rectangle((min(x_range), min(y_range)),
                          max(x_range) - min(x_range),
                          max(y_range) - min(y_range),
                          fc='None',
                          linewidth=linewidth)
        ax.add_patch(r)

    matrix_width = matrix.shape[0]
    matrix_height = len(substitutions)

    # mask NaN entries in mutation matrix
    matrix_masked = np.ma.masked_where(np.isnan(matrix), matrix)

    # figure out maximum and minimum values for color map
    if max_value is None and min_value is None:
        # symmetric range around zero
        max_value = np.abs(matrix_masked).max()
        min_value = -max_value
    elif min_value is None:
        min_value = matrix_masked.min()
    elif max_value is None:
        max_value = matrix_masked.max()

    # set NaN color value in colormaps (work on copies so that
    # the colormap objects handed in by the caller stay untouched)
    colormap = deepcopy(colormap)
    colormap.set_bad(na_color)
    colormap_conservation = deepcopy(colormap_conservation)
    colormap_conservation.set_bad(na_color)

    # determine size of plot (depends on how much tracks
    # with information we will add)
    num_rows = (len(substitutions) + (conservation is not None) +
                (secondary_structure is not None))

    ratio = matrix_width / float(num_rows)

    # create axis, if not given
    if ax is None:
        fig = plt.figure(figsize=(ratio * 5, 5))
        ax = fig.gca()

    # make square-shaped matrix cells
    ax.set_aspect("equal", "box")

    # define matrix coordinates
    # always add +1 because coordinates are used by
    # pcolor(mesh) as beginning and end of rectangles
    x_range = np.array(range(matrix_width + 1))
    y_range = np.array(range(matrix_height + 1))
    y_range_avg = range(-2, 0)
    x_range_avg = range(matrix_width + 1, matrix_width + 3)
    y_range_cons = np.array(y_range_avg) - 1.5

    # coordinates for text labels (fixed axis)
    x_left_subs = min(x_range) - 1
    x_right_subs = max(x_range_avg) + 1

    # position labels go below the lowest track that is drawn
    if conservation is None:
        y_bottom_res = min(y_range_avg) - 0.5
    else:
        y_bottom_res = min(y_range_cons) - 0.5

    # coordinates for additional annotation
    y_ss = max(y_range) + 2

    # 1) main mutation matrix
    X, Y = np.meshgrid(x_range, y_range)
    cm = ax.pcolormesh(X,
                       Y,
                       matrix_masked.T,
                       cmap=colormap,
                       vmax=max_value,
                       vmin=min_value)
    _draw_rect(x_range, y_range, LINEWIDTH)

    # 2) mean column effect (bottom "subplot")
    mean_pos = np.mean(matrix_masked, axis=1)[:, np.newaxis]
    X_pos, Y_pos = np.meshgrid(x_range, y_range_avg)
    ax.pcolormesh(X_pos,
                  Y_pos,
                  mean_pos.T,
                  cmap=colormap,
                  vmax=max_value,
                  vmin=min_value)
    _draw_rect(x_range, y_range_avg, LINEWIDTH)

    # 3) amino acid average (right "subplot")
    mean_aa = np.mean(matrix_masked, axis=0)[:, np.newaxis]
    X_aa, Y_aa = np.meshgrid(x_range_avg, y_range)
    ax.pcolormesh(X_aa,
                  Y_aa,
                  mean_aa,
                  cmap=colormap,
                  vmax=max_value,
                  vmin=min_value)
    _draw_rect(x_range_avg, y_range, LINEWIDTH)

    # mark wildtype residues
    if wildtype_sequence is not None:
        subs_list = list(substitutions)

        for i, wt in enumerate(wildtype_sequence):
            # skip unspecified entries
            if wt is not None and wt != "":
                # dot in the center of the cell (position i, substitution wt)
                marker = plt.Circle(
                    (x_range[i] + 0.5, y_range[subs_list.index(wt)] + 0.5),
                    0.1,
                    fc='k',
                    axes=ax)
                ax.add_patch(marker)

    # put labels along both axes of matrix

    # x-axis (positions)
    for i, pos in zip(x_range, positions):
        # filter labels, if selected
        if label_filter is not None and not label_filter(pos):
            continue

        # determine what position label should be
        if show_wt_char and wildtype_sequence is not None:
            wt_symbol = wildtype_sequence[i]
            if isinstance(pos, tuple) and len(pos) == 2:
                # label will be in format segment AA pos, eg B_1 A 151
                label = "{} {} {}".format(pos[0], wt_symbol, pos[1])
            else:
                label = "{} {}".format(wt_symbol, pos)

        else:
            if isinstance(pos, tuple):
                label = " ".join(map(str, pos))
            else:
                label = str(pos)

        ax.text(i + LABEL_X_OFFSET,
                y_bottom_res,
                label,
                size=position_label_size,
                horizontalalignment='center',
                verticalalignment='top',
                rotation=90)

    # y-axis (substitutions)
    for j, subs in zip(y_range, substitutions):
        # put on lefthand side of matrix...
        ax.text(x_left_subs,
                j + LABEL_Y_OFFSET,
                subs,
                size=substitution_label_size,
                horizontalalignment='center',
                verticalalignment='center')

        # ...and on right-hand side of matrix
        ax.text(x_right_subs,
                j + LABEL_Y_OFFSET,
                subs,
                size=substitution_label_size,
                horizontalalignment='center',
                verticalalignment='center')

    # draw colorbar
    if show_colorbar:
        # attach the colorbar explicitly to the figure and axes holding the
        # matrix; plt.colorbar() goes through the *current* pyplot figure,
        # which need not be the one a caller-supplied ax belongs to
        cb = ax.figure.colorbar(cm,
                                ax=ax,
                                ticks=[min_value, max_value],
                                shrink=0.3,
                                pad=0.15 / ratio,
                                aspect=8)

        if colorbar_indicate_bounds:
            symbol_min, symbol_max = u"\u2264", u"\u2265"
        else:
            symbol_min, symbol_max = "", ""

        cb.ax.set_yticklabels([
            "{symbol} {value:>+{width}.1f}".format(symbol=s, value=v, width=0)
            for (v, s) in [(min_value, symbol_min), (max_value, symbol_max)]
        ])
        cb.ax.xaxis.set_ticks_position("none")
        cb.ax.yaxis.set_ticks_position("none")
        cb.outline.set_linewidth(0)

    # plot secondary structure cartoon
    if secondary_structure is not None:
        # if no style given for secondary structure, set default
        if secondary_structure_style is None:
            secondary_structure_style = {
                "width": 0.8,
                "line_width": 2,
                "strand_width_factor": 0.5,
                "helix_turn_length": 2,
                "min_sse_length": 2,
            }

        start, end, sse = find_secondary_structure_segments(
            secondary_structure)
        secondary_structure_cartoon(sse,
                                    sequence_start=start,
                                    sequence_end=end,
                                    center=y_ss,
                                    ax=ax,
                                    **secondary_structure_style)

    # plot conservation
    if conservation is not None:
        # coerce to float so that None entries turn into NaN and
        # are masked (drawn in na_color) instead of breaking np.isnan
        conservation = np.array(conservation, dtype=float)[:, np.newaxis]
        cons_masked = np.ma.masked_where(np.isnan(conservation), conservation)
        X_cons, Y_cons = np.meshgrid(x_range, y_range_cons)
        ax.pcolormesh(X_cons,
                      Y_cons,
                      cons_masked.T,
                      cmap=colormap_conservation,
                      vmax=1,
                      vmin=0)
        _draw_rect(x_range, y_range_cons, LINEWIDTH)

    # remove chart junk
    for line in ['top', 'bottom', 'right', 'left']:
        ax.spines[line].set_visible(False)

    ax.xaxis.set_ticks_position("none")
    ax.yaxis.set_ticks_position("none")
    plt.setp(ax.get_xticklabels(), visible=False)
    plt.setp(ax.get_yticklabels(), visible=False)

    if title is not None:
        ax.set_title(title)

    return ax
Пример #2
0
def dihedral_ranking_score(structure,
                           residues,
                           sec_struct_column="sec_struct_3state",
                           original=True):
    """
    Score a structure model by the dihedral angles found in its
    predicted alpha-helices and beta-strands.

    This function re-implements the functionality of
    make_alpha_beta_score_table.m from the original
    pipeline.

    Parameters
    ----------
    structure : evcouplings.compare.pdb.Chain
        Chain with 3D structure coordinates to evaluate
    residues : pandas.DataFrame
        Residue table with secondary structure predictions
        (columns i, A_i and secondary structure column)
    sec_struct_column : str, optional (default: sec_struct_3state)
        Column in residues dataframe that contains predicted
        secondary structure (H, E, C)
    original : bool, optional (default: True):
        Use exact implementation of 2011 ranking protocol
        (passed through to the beta dihedral computation)

    Returns
    -------
    int
        Number of alpha dihedrals computed
    float
        Alpha dihedral score (unnormalized)
    int
        Number of beta dihedrals computed
    float
        Beta dihedral score (unnormalized)
    """
    # (weight, lower bound, upper bound) bins for helix dihedrals;
    # each dihedral inside a bin contributes the bin's weight
    alpha_weights = [
        (0.2, 0.44, 0.52),
        (0.4, 0.52, 0.61),
        (0.6, 0.61, 0.70),
        (0.8, 0.70, 0.78),
        (1.0, 0.78, 0.96),
        (0.8, 0.96, 1.05),
        (0.6, 1.05, 1.13),
        (0.4, 1.13, 1.22),
        (0.2, 1.22, 1.31),
    ]

    # same layout for strand dihedrals
    beta_weights = [
        (0.2, -0.3, -0.1),
        (0.4, -0.4, -0.3),
        (0.6, -0.5, -0.4),
        (0.8, -0.6, -0.5),
        (1.0, -0.8, -0.6),
        (0.8, -0.9, -0.8),
        (0.6, -1.0, -0.9),
        (0.4, -1.1, -1.0),
        (0.2, -1.2, -1.1),
    ]

    def _weighted_count(table, bins, condition):
        # An empty table scores 0 without being queried.
        if len(table) == 0:
            return 0

        total = 0
        # "lower" and "upper" have to be locals of this function, since
        # the query expression refers to them as @lower and @upper
        for weight, lower, upper in bins:
            total += weight * len(table.query(condition))
        return total

    # Restrict the chain to its C_alpha atoms and attach the
    # coordinates of these atoms to the residue records
    ca_chain = structure.filter_atoms(atom_name="CA")
    ca_table = ca_chain.residues.merge(ca_chain.coords,
                                       left_index=True,
                                       right_on="residue_index")

    # Join with the secondary structure predictions. PDB indices are
    # strings, so a string version of the position is used as the key;
    # the left join keeps predicted positions without coordinates.
    residues = residues.copy()
    residues.loc[:, "id"] = residues["i"].astype(str)
    x = residues.merge(ca_table, on="id", how="left", suffixes=("", "_"))

    # Split the predicted secondary structure string into segments
    sec_struct_string = "".join(x[sec_struct_column])
    _, _, segments = find_secondary_structure_segments(
        sec_struct_string, offset=x["i"].min())

    helix_segments = [(first, last) for (kind, first, last) in segments
                      if kind == "H"]
    strand_segments = [(first, last) for (kind, first, last) in segments
                       if kind == "E"]

    # Only positions that have C_alpha coordinates can be used
    with_coords = x.dropna(subset=["x", "y", "z"])

    # Dihedrals within helices and strands
    d_alpha = _alpha_dihedrals(with_coords, helix_segments)
    d_beta = _beta_dihedrals(with_coords, strand_segments, original=original)

    # Count the angles falling into each bin, weighted by bin.
    # Note that the alpha bins are half-open on the left, while
    # the beta bins are half-open on the right.
    alpha_dihedral_score = _weighted_count(
        d_alpha, alpha_weights, "@lower < dihedral and dihedral <= @upper")
    beta_dihedral_score = _weighted_count(
        d_beta, beta_weights, "@lower <= dihedral and dihedral < @upper")

    return len(d_alpha), alpha_dihedral_score, len(d_beta), beta_dihedral_score