Пример #1
0
def plot_Cnorm(cnorm,
               labels,
               Ptrue=[0.1, 0.5],
               ax=None,
               title=None,
               fontsize=12):
    r"""Draw a matrix of Cnorm values as an annotated heatmap.

    Every cell is rendered with the `Blues` colormap and overlaid with its
    value (two decimals). The mean of all cells is appended to the title.

    Arguments:
      cnorm : array indexed as `cnorm[i, j]`, where `i` runs over `Ptrue`
        and `j` over `labels`. It is cast to float32 before plotting.
      labels : sequence of names used as x-tick labels (one per column).
      Ptrue : a scalar or a tuple/list/ndarray of prior values, one per row
        of `cnorm`. The default list is never mutated (it is re-bound to a
        fresh list of floats below).
      ax : forwarded to `to_axis(ax, is_3D=False)`, which returns the axes
        that is drawn on.
      title : optional prefix of the title; `None` shows only the mean.
      fontsize : base font size for ticks, labels and cell text; the title
        uses `fontsize + 2`.

    Returns:
      The axes object the heatmap was drawn on.

    Raises:
      ValueError: if `len(Ptrue)` differs from `cnorm.shape[0]`.
    """
    from matplotlib import pyplot as plt
    cmap = plt.cm.Blues
    cnorm = cnorm.astype('float32')
    if not isinstance(Ptrue, (tuple, list, np.ndarray)):
        Ptrue = (Ptrue, )
    Ptrue = [float(i) for i in Ptrue]
    if len(Ptrue) != cnorm.shape[0]:
        raise ValueError(
            "`Cnorm` was calculated for %d Ptrue values, but given only "
            "%d values for `Ptrue`: %s" %
            (cnorm.shape[0], len(Ptrue), str(Ptrue)))
    ax = to_axis(ax, is_3D=False)
    ax.imshow(cnorm, interpolation='nearest', cmap=cmap)
    ax.set_xticks(np.arange(len(labels)))
    ax.set_yticks(np.arange(len(Ptrue)))
    ax.set_xticklabels(labels, rotation=-57, fontsize=fontsize)
    ax.set_yticklabels([str(i) for i in Ptrue], fontsize=fontsize)
    ax.set_ylabel('Ptrue', fontsize=fontsize)
    ax.set_xlabel('Predicted label', fontsize=fontsize)
    # Write the value of each cell at its centre. The text is attached to
    # `ax` itself (not to pyplot's "current" axes), so the annotation lands
    # on the right subplot even when `ax` is not the active one.
    text_style = dict(weight='normal',
                      color='red',
                      fontsize=fontsize,
                      verticalalignment="center",
                      horizontalalignment="center")
    for i, j in itertools.product(range(len(Ptrue)), range(len(labels))):
        ax.text(j, i, '%.2f' % cnorm[i, j], **text_style)
    # Grid lines would be drawn across the cells of the image.
    ax.grid(False)
    mean_cnorm = np.mean(cnorm)
    if title is None:
        title = "Cnorm: %.6f" % mean_cnorm
    else:
        title = "%s (Cnorm: %.6f)" % (str(title), mean_cnorm)
    ax.set_title(title, fontsize=fontsize + 2, weight='semibold')
    return ax
Пример #2
0
def plot_distance_heatmap(X,
                          labels,
                          lognorm=True,
                          colormap='hot',
                          ax=None,
                          legend_enable=True,
                          legend_loc='upper center',
                          legend_ncol=3,
                          legend_colspace=0.2,
                          fontsize=10,
                          cbar=True,
                          title=None):
    r"""Plot the pairwise distance matrix of `X`, with samples ordered by label.

    The rows of `X` are length-normalized, sorted by their label, and the
    matrix of pairwise distances is rendered as an image. A strip along the
    top and the left edge shows the label of every row/column.

    Arguments:
      X : (n_samples, n_features). Coordination for scatter points
      labels : (n_samples,). List of classes index or name. If the first
        label is a non-integral number the labels are treated as continuous
        values (only the first element is inspected).
      lognorm : if true, `log1p` is applied to the distances.
      colormap : name of the colormap used for the distance image.
      ax : forwarded to `to_axis`, which returns the axes that is drawn on.
      legend_enable, legend_loc, legend_ncol, legend_colspace : legend
        switch and layout options passed to `ax.legend`.
      fontsize : font size of the title; the legend uses `fontsize - 1`.
      cbar : if true, `odin.visual.plot_colorbar` is called with the
        minimum and maximum of the plotted distances.
      title : optional title of the axes.

    Returns:
      The axes object the heatmap was drawn on.
    """
    import seaborn as sns
    from matplotlib import pyplot as plt
    from matplotlib.lines import Line2D
    from odin import backend as K

    # prepare data
    X = K.length_norm(X, axis=-1, epsilon=np.finfo(X.dtype).eps)
    ax = to_axis(ax)
    n_samples, _ = X.shape
    # processing labels
    labels = np.array(labels).ravel()
    assert labels.shape[0] == n_samples, "labels must be 1-D array."
    is_continuous = isinstance(labels[0],
                               Number) and int(labels[0]) != labels[0]
    if is_continuous:
        # Continuous labels are rescaled to [0, 1], i.e. the position inside
        # the diverging colormap; the two ends are named '-1' and '+1' in
        # the legend.
        min_val = np.min(labels)
        max_val = np.max(labels)
        span = max_val - min_val
        # constant labels would give 0 / 0; place them at the lower end
        labels = (labels - min_val) / span if span > 0 else \
            np.zeros(shape=(n_samples,))
        n_labels = 2
        labels_name = {'-1': 0, '+1': 1}
    else:
        labels_name = {name: i for i, name in enumerate(np.unique(labels))}
        n_labels = len(labels_name)
        labels = np.array([labels_name[name] for name in labels])
    # ====== sorting label and X ====== #
    # A single stable argsort gives the same order as sorting the
    # (label, row) pairs by label, without sorting twice.
    order = np.argsort(labels, kind='stable')
    order_X = np.asarray(X)[order]
    order_label = labels[order][:, np.newaxis]  # column vector (n_samples, 1)
    distance = sp.spatial.distance_matrix(order_X, order_X)
    if bool(lognorm):
        distance = np.log1p(distance)
    # Lift the zero entries (e.g. the diagonal) to the smallest non-zero
    # distance so they do not stretch the color range.
    non_zero = distance[distance != 0]
    if non_zero.size > 0:
        distance = np.clip(distance,
                           a_min=np.min(non_zero),
                           a_max=np.max(distance))
    min_dist = np.min(distance)
    max_dist = np.max(distance)
    # ====== convert data to image ====== #
    # A colormap maps floats in [0, 1]; rescale the distances to that
    # interval so the image spans the same [min, max] range that is given
    # to the colorbar at the end of this function.
    cm = plt.get_cmap(colormap)
    dist_span = max_dist - min_dist
    if dist_span > 0:
        distance_img = cm((distance - min_dist) / dist_span)
    else:  # all distances are equal
        distance_img = cm(np.zeros_like(distance))
    # labels colormap
    width = max(int(0.032 * n_samples), 8)
    if n_labels == 2:
        cm = plt.get_cmap('bwr')
        # The colormap must be called with floats: integers are interpreted
        # as indices into its lookup table, which would map the classes 0
        # and 1 to two almost identical colors.
        bar_values = order_label.astype('float64')
        horz_bar = np.repeat(cm(bar_values.T), repeats=width, axis=0)
        vert_bar = np.repeat(cm(bar_values), repeats=width, axis=1)
        all_colors = np.array((cm(0.0), cm(1.0)))
    else:  # use seaborn color palette here is better
        palette = [rgb + (1., )
                   for rgb in sns.color_palette(n_colors=n_labels)]
        c = np.stack([palette[i] for i in order_label.ravel()])
        horz_bar = np.repeat(np.expand_dims(c, 0), repeats=width, axis=0)
        vert_bar = np.repeat(np.expand_dims(c, 1), repeats=width, axis=1)
        all_colors = palette
    # image: distance matrix with the label strips on the top and left edge
    final_img = np.zeros(shape=(n_samples + width, n_samples + width,
                                distance_img.shape[2]),
                         dtype=distance_img.dtype)
    final_img[width:, width:] = distance_img
    final_img[:width, width:] = horz_bar
    final_img[width:, :width] = vert_bar
    assert np.sum(final_img[:width, :width]) == 0, \
    "Something wrong with my spacial coordination when writing this code!"
    # ====== plotting ====== #
    ax.imshow(final_img)
    ax.axis('off')
    # ====== legend ====== #
    if bool(legend_enable):
        legend_elements = [
            Line2D([0], [0],
                   marker='o',
                   color=color,
                   label=name,
                   linewidth=0,
                   linestyle=None,
                   markerfacecolor=color,
                   markersize=fontsize // 2)
            for color, name in zip(all_colors, labels_name.keys())
        ]
        ax.legend(handles=legend_elements,
                  markerscale=1.,
                  scatterpoints=1,
                  scatteryoffsets=[0.375, 0.5, 0.3125],
                  loc=legend_loc,
                  bbox_to_anchor=(0.5, -0.01),
                  ncol=int(legend_ncol),
                  columnspacing=float(legend_colspace),
                  labelspacing=0.,
                  fontsize=fontsize - 1,
                  handletextpad=0.1)
    # ====== final configurations ====== #
    if title is not None:
        ax.set_title(str(title), fontsize=fontsize)
    if cbar:
        from odin.visual import plot_colorbar
        plot_colorbar(colormap,
                      vmin=min_dist,
                      vmax=max_dist,
                      ax=ax,
                      orientation='vertical')
    return ax
Пример #3
0
def plot_scatter_layers(x_y_val,
                        ax=None,
                        layer_name=None,
                        layer_color=None,
                        layer_marker=None,
                        size=4.0,
                        z_ratio=4,
                        elev=None,
                        azim=88,
                        ticks_off=True,
                        grid=True,
                        surface=True,
                        wireframe=False,
                        wireframe_resolution=10,
                        colorbar=False,
                        colorbar_horizontal=False,
                        legend_loc='upper center',
                        legend_ncol=3,
                        legend_colspace=0.4,
                        fontsize=8,
                        title=None):
  r""" Stack several 2-D scatter plots as flat layers of one 3-D axes.

  Each layer is drawn at its own constant height; the heights are evenly
  spaced between `min / z_ratio` and `max / z_ratio`, where `min` and `max`
  are taken over the x and y coordinates of all layers.

  Parameter
  ---------
  x_y_val: sequence of (x, y, val) triples (at least two)
    the coordinates of one layer and the value used to color each point
  ax:
    forwarded to `to_axis(ax, is_3D=True)`, which returns the drawing axes
  layer_name, layer_color, layer_marker, size:
    per-layer legend name, colormap name, marker and marker size; each is
    normalized to one entry per layer by `check_arg_length`
  z_ratio: float (default: 4)
    the amount of compression that layer in z_axis will be closer
    to each others compared to (x, y) axes
  elev, azim:
    view angles given to `ax.view_init`; `None` keeps the current value
  ticks_off: bool
    remove the tick labels of all three axes
  grid: bool
    passed to `ax.grid`
  surface, wireframe: bool
    draw a translucent plane and/or a wireframe under each layer
  wireframe_resolution: int
    number of mesh points per axis of that plane
  colorbar, colorbar_horizontal: bool
    add one colorbar per layer, and its orientation
  legend_loc, legend_ncol, legend_colspace, fontsize:
    layout options of the legend (only layers with a non-empty name are
    listed)
  title:
    optional title of the axes

  Return
  ------
  the axes object that was drawn on

  Raise
  -----
  ValueError if `z_ratio` is zero
  """
  from matplotlib import pyplot as plt
  assert len(x_y_val) > 1, "Use `plot_scatter_heatmap` to plot only 1 layer"
  if z_ratio == 0:
    raise ValueError("`z_ratio` must be non-zero")
  max_z = -np.inf
  min_z = np.inf
  for x, y, val in x_y_val:
    assert len(x) == len(y) == len(val), "Number of samples mismatch"
    max_z = max(max_z, np.max(x), np.max(y))
    min_z = min(min_z, np.min(x), np.min(y))
  ax = to_axis(ax, is_3D=True)
  num_classes = len(x_y_val)
  # ====== preparing ====== #
  # name
  layer_name = check_arg_length(dat=layer_name,
                                n=num_classes,
                                dtype=string_types,
                                default='',
                                converter=lambda x: str(x))
  # colormap
  layer_color = check_arg_length(dat=layer_color,
                                 n=num_classes,
                                 dtype=string_types,
                                 default='Blues',
                                 converter=lambda x: plt.get_cmap(str(x)))
  # class marker
  layer_marker = check_arg_length(dat=layer_marker,
                                  n=num_classes,
                                  dtype=string_types,
                                  default='o',
                                  converter=lambda x: str(x))
  # size
  size = check_arg_length(dat=size,
                          n=num_classes,
                          dtype=Number,
                          default=4.0,
                          converter=lambda x: float(x))
  # ====== plotting each class ====== #
  # Lower layers are more transparent; the heights honor `z_ratio`.
  alphas = np.linspace(0.05, 0.4, num_classes)
  z_levels = np.linspace(min_z / z_ratio, max_z / z_ratio, num_classes)
  # (name, scatter artist, colormap) of every layer shown in the legend
  legend_entries = []
  for idx, (alpha, z_level) in enumerate(zip(alphas, z_levels)):
    x, y, val = x_y_val[idx]
    cmap = layer_color[idx]
    points = ax.scatter(x,
                        y,
                        np.full(shape=(len(x),), fill_value=z_level),
                        c=val,
                        s=size[idx],
                        marker=layer_marker[idx],
                        cmap=cmap)
    # ploting surface and wireframe
    if surface or wireframe:
      grid_x, grid_y = np.meshgrid(
        np.linspace(np.min(x), np.max(x), wireframe_resolution),
        np.linspace(np.min(y), np.max(y), wireframe_resolution))
      grid_z = np.full_like(grid_x, fill_value=z_level)
      if surface:
        ax.plot_surface(X=grid_x,
                        Y=grid_y,
                        Z=grid_z,
                        color=cmap(0.5),
                        edgecolor='none',
                        alpha=alpha)
      if wireframe:
        ax.plot_wireframe(X=grid_x,
                          Y=grid_y,
                          Z=grid_z,
                          linewidth=0.8,
                          color=cmap(0.8),
                          alpha=alpha + 0.1)
    # legend
    name = layer_name[idx]
    if len(name) > 0:
      legend_entries.append((name, points, cmap))
    # colorbar
    if colorbar:
      cba = plt.colorbar(
        points,
        ax=ax,
        shrink=0.5,
        pad=0.01,
        orientation='horizontal' if colorbar_horizontal else 'vertical')
      if len(name) > 0:
        cba.set_label(name, fontsize=fontsize)
  # ====== plot the legend ====== #
  if len(legend_entries) > 0:
    legend = ax.legend([entry[1] for entry in legend_entries],
                       [entry[0] for entry in legend_entries],
                       markerscale=1.5,
                       scatterpoints=1,
                       scatteryoffsets=[0.375, 0.5, 0.3125],
                       loc=legend_loc,
                       bbox_to_anchor=(0.5, -0.01),
                       ncol=int(legend_ncol),
                       columnspacing=float(legend_colspace),
                       labelspacing=0.,
                       fontsize=fontsize,
                       handletextpad=0.1)
    # `legendHandles` was renamed to `legend_handles` in newer Matplotlib
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
      handles = legend.legendHandles
    # Color each handle with the colormap of the layer it represents; only
    # named layers own a handle, so pair them through `legend_entries`.
    for handle, (_, _, cmap) in zip(handles, legend_entries):
      handle.set_color(cmap(.8))
  # ====== some configuration ====== #
  if ticks_off:
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_zticklabels([])
  ax.grid(grid)
  if title is not None:
    ax.set_title(str(title))
  if elev is not None or azim is not None:
    ax.view_init(elev=ax.elev if elev is None else elev,
                 azim=ax.azim if azim is None else azim)
  return ax
Пример #4
0
def plot_heatmap(data,
                 cmap="Blues",
                 ax=None,
                 xticklabels=None,
                 yticklabels=None,
                 xlabel=None,
                 ylabel=None,
                 cbar_title=None,
                 cbar=False,
                 fontsize=12,
                 gridline=0,
                 hide_spines=True,
                 annotation=None,
                 text_colors=dict(diag="black",
                                  minrow=None,
                                  mincol=None,
                                  maxrow=None,
                                  maxcol=None,
                                  other="black"),
                 title=None):
    r""" Showing heatmap matrix

    Arguments:
      data : 2-D array (it is indexed as `data[i, j]` and `data.shape` is
        read) shown with `imshow`.
      cmap : colormap given to `imshow`.
      ax : forwarded to `to_axis(ax, is_3D=False)`, which returns the axes
        that is drawn on.
      xticklabels, yticklabels : tick labels; if only one of them is given
        the other one is generated as "X#i" / "Y#i". If both are `None`
        the ticks are turned off.
      xlabel, ylabel : optional axis labels.
      cbar, cbar_title : add a colorbar, and its optional label.
      fontsize : base font size (axis labels +1, title +2).
      gridline : if positive, line width of a white grid between the cells.
      hide_spines : hide the border of the axes.
      annotation : `True` writes every value (format '%.2g') into its cell;
        an array-like with the same shape as `data` writes its entries
        instead; `None` or `False` disables the annotation.
      text_colors : mapping with the optional keys 'diag', 'other',
        'minrow', 'maxrow', 'mincol' and 'maxcol' giving the text color of
        the annotation. Column rules override row rules, which override
        'diag'/'other'. The default mapping is only read, never modified.
      title : optional title of the axes.

    Returns:
      The axes object the heatmap was drawn on.
    """
    ax = to_axis(ax, is_3D=False)
    ax.grid(False)
    fig = ax.get_figure()
    n_rows, n_cols = data.shape[0], data.shape[1]
    # prepare labels
    if xticklabels is None and yticklabels is not None:
        xticklabels = ["X#%d" % i for i in range(n_cols)]
    if yticklabels is None and xticklabels is not None:
        yticklabels = ["Y#%d" % i for i in range(n_rows)]
    # Plot the heatmap
    im = ax.imshow(data,
                   interpolation='nearest',
                   cmap=cmap,
                   aspect='equal',
                   origin='upper')
    # Create colorbar on the figure that owns `ax` (which is not
    # necessarily pyplot's current figure).
    if cbar:
        cb = fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02)
        if cbar_title is not None:
            cb.ax.set_ylabel(cbar_title,
                             rotation=-90,
                             va="bottom",
                             fontsize=fontsize)
    ## major ticks
    if xticklabels is not None and yticklabels is not None:
        ax.set_xticks(np.arange(n_cols))
        ax.set_xticklabels(xticklabels, fontsize=fontsize)
        ax.set_yticks(np.arange(n_rows))
        ax.set_yticklabels(list(yticklabels), fontsize=fontsize)
        # Let the horizontal axes labeling appear on top.
        ax.tick_params(top=True,
                       bottom=False,
                       labeltop=True,
                       labelbottom=False)
        # Rotate the tick labels and set their alignment.
        for tick_label in ax.get_xticklabels():
            tick_label.set(rotation=-30, ha="right", rotation_mode="anchor")
    else:  # turn-off all ticks
        ax.tick_params(top=False,
                       bottom=False,
                       labeltop=False,
                       labelbottom=False)
    ## axis label
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=fontsize + 1)
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=fontsize + 1)
    ## Turn spines off
    if hide_spines:
        for spine in ax.spines.values():
            spine.set_visible(False)
    ## minor ticks and create white grid.
    # (if no minor ticks, the image will be cut-off)
    ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    if gridline > 0:
        ax.grid(which="minor", color="w", linestyle='-', linewidth=gridline)
        ax.tick_params(which="minor", bottom=False, left=False)
    # set the title
    if title is not None:
        ax.set_title(str(title), fontsize=fontsize + 2, weight='semibold')
    # prepare the annotation
    if annotation is not None and annotation is not False:
        if annotation is True:
            annotation = np.array([['%.2g' % x for x in row] for row in data])
        else:  # accept nested lists as well as arrays
            annotation = np.asarray(annotation)
        assert annotation.shape == data.shape
        if text_colors is None:
            text_colors = {}
        diag_color = text_colors.get('diag', 'black')
        other_color = text_colors.get('other', 'black')
        minrow = text_colors.get('minrow', None)
        maxrow = text_colors.get('maxrow', None)
        mincol = text_colors.get('mincol', None)
        maxcol = text_colors.get('maxcol', None)
        # The extremes of every row and column do not depend on the cell,
        # so compute them once instead of once per cell.
        row_min = [min(row) for row in data]
        row_max = [max(row) for row in data]
        col_min = [min(data[:, j]) for j in range(n_cols)]
        col_max = [max(data[:, j]) for j in range(n_cols)]
        # Create a `Text` for each "pixel"; its color depends on the data.
        for i in range(n_rows):
            for j in range(n_cols):
                value = data[i, j]
                # basics text config: the diagonal is emphasized
                on_diagonal = i == j
                color = diag_color if on_diagonal else other_color
                # min, max of row
                if minrow is not None and value == row_min[i]:
                    color = minrow
                elif maxrow is not None and value == row_max[i]:
                    color = maxrow
                # min, max of column (takes precedence over the row rule)
                if mincol is not None and value == col_min[j]:
                    color = mincol
                elif maxcol is not None and value == col_max[j]:
                    color = maxcol
                # show text
                ax.text(j,
                        i,
                        annotation[i, j],
                        color=color,
                        weight='bold' if on_diagonal else 'normal',
                        horizontalalignment="center",
                        verticalalignment="center",
                        fontsize=fontsize)
    return ax
Пример #5
0
def _plot_scatter_points(*, x, y, z, val, color, marker, size, size_range,
                         alpha, max_n_points, cbar, cbar_horizontal,
                         cbar_nticks, cbar_ticks_rotation, cbar_title,
                         cbar_fontsize, legend_enable, legend_loc, legend_ncol,
                         legend_colspace, elev, azim, ticks_off, grid, fontsize,
                         centroids, xlabel, ylabel, title, ax, **kwargs):
  r""" Generator that prepares a scatter plot and lets the caller draw it.

  The points are grouped by their (color, marker, size) style. For every
  group the generator yields the tuple

    `(ax, artist, x_, y_, z_, style)`

  where `x_`, `y_`, `z_` are the coordinates of the group (`z_` is `None`
  in 2-D mode), `style` is the list `[color, marker, size]` (in colormap
  mode the color is replaced by the RGBA values of the points) and `artist`
  is a list shared by all iterations. The consumer must draw the group on
  `ax` and append the created artist to `artist` before asking for the next
  item, otherwise an `AssertionError` is raised.

  After the last group, the colorbar, legend, ticks, grid, axis labels,
  title and 3-D view angles are configured (only if one artist was
  registered per group).

  All arguments are keyword-only. `val` switches on the colormap mode when
  it is not `None`; in that mode `color` must be a colormap (name or
  instance). `kwargs` may hold `text_marker`, which is forwarded to
  `_validate_color_marker_size_legend`. `alpha` is not read here.
  """
  from matplotlib import pyplot as plt
  import matplotlib as mpl
  # keep the marker as its original text
  text_marker = kwargs.get('text_marker', False)
  x, y, z = _parse_scatterXYZ(x, y, z)
  assert len(x) == len(y), "Number of samples mismatch"
  if z is not None:
    assert len(y) == len(z)
  is_3D_mode = z is not None
  ax = to_axis(ax, is_3D_mode)
  ### check the colormap
  is_colormap = val is not None
  if is_colormap:
    from matplotlib.colors import LinearSegmentedColormap
    vmin = np.min(val)
    vmax = np.max(val)
    color_normalizer = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    assert isinstance(color, (string_types, LinearSegmentedColormap)), \
      "`colormap` can be string or instance of matplotlib Colormap, " + \
      "but given: %s" % type(color)
  else:
    vmin, vmax, color_normalizer = None, None, None
    # 'bwr' is the name of a colormap, not of a color; without values to
    # map, plain blue is used instead
    if isinstance(color, string_types) and color == 'bwr':
      color = 'b'
  ### perform downsample and select the styles
  # NOTE(review): `val` is not given to `_downsample_scatter_points` but is
  # indexed below with the positions of the returned points - confirm that
  # both stay aligned when points are actually dropped.
  max_n_points, x, y, z, color, marker, size = _downsample_scatter_points(
    x, y, z, max_n_points, color, marker, size)
  color, marker, size, legend = _validate_color_marker_size_legend(
    max_n_points,
    color,
    marker,
    size,
    text_marker=text_marker,
    is_colormap=is_colormap,
    size_range=size_range)
  ### centroid style
  centroid_style = dict(horizontalalignment='center',
                        verticalalignment='center',
                        fontsize=fontsize + 2,
                        weight="bold",
                        bbox=dict(boxstyle="circle",
                                  facecolor="black",
                                  alpha=0.48,
                                  pad=0.,
                                  edgecolor='none'))
  ### plotting
  artist = []
  legend_name = []
  for style, name in legend.items():
    style = list(style)
    # get the right set of data points
    selected = [
      i for i, (c, m, s) in enumerate(zip(color, marker, size))
      if c == style[0] and m == style[1] and s == style[2]
    ]
    x_ = [x[i] for i in selected]
    y_ = [y[i] for i in selected]
    # 2D or 3D plot
    z_ = [z[i] for i in selected] if is_3D_mode else None
    # colormap or normal color
    if is_colormap:
      # `plt.get_cmap` accepts a name as well as a colormap instance
      cm = plt.get_cmap(style[0])
      val_ = color_normalizer([val[i] for i in selected])
      style[0] = cm(val_)
    # yield for plotting
    n_art = len(artist)
    yield ax, artist, x_, y_, z_, style
    # check new axis added
    assert len(artist) > n_art, \
      "Forgot adding new art object created by plotting"
    # check if ploting centroid
    if centroids:
      if is_3D_mode:
        ax.text(np.mean(x_),
                np.mean(y_),
                np.mean(z_),
                s=name[0],
                color=style[0],
                **centroid_style)
      else:
        ax.text(np.mean(x_),
                np.mean(y_),
                s=name[0],
                color=style[0],
                **centroid_style)
    # make the shortest name: drop empty and repeated parts, keep the order
    short_name = []
    for part in name:
      if len(part) > 0 and part not in short_name:
        short_name.append(part)
    name = ', '.join(short_name)
    if len(name) > 0:
      legend_name.append(name)
  ### at the end of the iteration, axis configuration
  if len(artist) == len(legend):
    ## colorbar (only enable when colormap is provided)
    if is_colormap and cbar:
      # `cm` is the colormap of the last plotted group
      mappable = plt.cm.ScalarMappable(norm=color_normalizer, cmap=cm)
      mappable.set_clim(vmin, vmax)
      cba = plt.colorbar(
        mappable,
        ax=ax,
        shrink=0.99,
        pad=0.01,
        orientation='horizontal' if cbar_horizontal else 'vertical')
      if isinstance(cbar_nticks, Number):
        # a number: that many evenly spaced ticks labeled with their value
        cbar_range = np.linspace(vmin, vmax, num=int(cbar_nticks))
        cbar_nticks = [f'{i:.2g}' for i in cbar_range]
      elif isinstance(cbar_nticks, (tuple, list, np.ndarray)):
        # a sequence: its items are the labels of evenly spaced ticks
        cbar_range = np.linspace(vmin, vmax, num=len(cbar_nticks))
        cbar_nticks = [str(i) for i in cbar_nticks]
      else:
        raise ValueError(f"No support for cbar_nticks='{cbar_nticks}'")
      cba.set_ticks(cbar_range)
      cba.set_ticklabels(cbar_nticks)
      if cbar_title is not None:
        if cbar_horizontal:  # horizontal colorbar
          cba.ax.set_xlabel(str(cbar_title), fontsize=cbar_fontsize)
        else:  # vertical colorbar
          cba.ax.set_ylabel(str(cbar_title), fontsize=cbar_fontsize)
      cba.ax.tick_params(labelsize=cbar_fontsize,
                         labelrotation=cbar_ticks_rotation)
    ## plot the legend
    if len(legend_name) > 0 and bool(legend_enable):
      markerscale = 1.5
      if isinstance(artist[0], mpl.text.Text):  # text plot special case
        # Replace every text artist by a tiny scatter point of the same
        # color at the same position, and use that point as legend handle;
        # the point is enlarged in the legend through `markerscale`.
        markerscale = 25
        for i, art in enumerate(list(artist)):
          pos = [art._x, art._y]
          if is_3D_mode:
            pos.append(art._z)
          artist[i] = ax.scatter(*pos, c=art.get_color(), s=0.1)
      # sort the legends
      legend_name, artist = zip(
        *sorted(zip(legend_name, artist), key=lambda t: t[0]))
      legend_kw = {}
      if legend_loc is not None:
        legend_kw['loc'] = legend_loc
      if legend_ncol is not None:
        legend_kw['ncol'] = legend_ncol
      ax.legend(artist, legend_name,
                labelspacing=0.,
                handletextpad=0.1,
                markerscale=markerscale,
                scatterpoints=1,
                columnspacing=float(legend_colspace),
                fontsize=fontsize,
                **legend_kw)
    ## tick configuration
    if ticks_off:
      ax.set_xticklabels([])
      ax.set_yticklabels([])
      if is_3D_mode:
        ax.set_zticklabels([])
    if grid:
      ax.set_axisbelow(True)
      ax.grid(grid, which='both', axis='both', linewidth=0.8, alpha=0.5)
    if xlabel is not None:
      ax.set_xlabel(str(xlabel), fontsize=fontsize - 1)
    if ylabel is not None:
      ax.set_ylabel(str(ylabel), fontsize=fontsize - 1)
    if title is not None:
      ax.set_title(str(title), fontsize=fontsize, fontweight='regular')
    if is_3D_mode and (elev is not None or azim is not None):
      ax.view_init(elev=ax.elev if elev is None else elev,
                   azim=ax.azim if azim is None else azim)