Esempio n. 1
0
def plot_line(xdata,
              ydata,
              title,
              xlbl,
              ylbl,
              out_name,
              color,
              lbl,
              stdev=None,
              do_save=True,
              black_bg=False):
    """Plot ydata against xdata as dot markers joined by a line.

    Draws on the current pyplot axes: edge-less dot markers, a connecting
    line carrying the legend label, optional vertical error bars, a gray grid,
    the title and axis labels. Optionally saves the current figure.

    Parameters:
      xdata, ydata: Data to plot, same length.
      title: Plot title; skipped when empty/None.
      xlbl, ylbl: Axis labels.
      out_name: Path passed to plt.savefig() when do_save is True.
      color: Matplotlib color for markers, line and error bars.
      lbl: Legend label attached to the line.
      stdev: Optional y error-bar sizes (scalar or per-point sequence/array).
        Error bars are drawn whenever this is not None.
      do_save: Save the current figure to out_name and print the path.
      black_bg: Apply black_background() styling to the current plot.
    """

    #####
    # Plot
    #####

    # Markers first, then the line that carries the legend label, so the
    #   legend gets a single entry.
    plt.plot(xdata, ydata, 'o', markersize=5, markeredgewidth=0, color=color)
    plt.plot(xdata, ydata, '-', linewidth=2, color=color, label=lbl)

    # Plot error bars
    #   http://matplotlib.org/1.2.1/examples/pylab_examples/errorbar_demo.html
    #   Plot error bar without line, fmt='':
    #   http://stackoverflow.com/questions/18498742/how-do-you-make-an-errorbar-plot-in-matplotlib-using-linestyle-none-in-rcparams
    # Compare against None explicitly: `if stdev:` raises ValueError for a
    #   numpy array with more than one element.
    if stdev is not None:
        plt.errorbar(xdata, ydata, fmt='', capthick=2, yerr=stdev, color=color)

    plt.grid(True, color='gray')

    if title:
        plt.title(title)
    plt.xlabel(xlbl)
    plt.ylabel(ylbl)

    if black_bg:
        black_background()

    # Save to file
    if do_save:
        plt.savefig(out_name, bbox_inches='tight')
        print('Plot saved to %s' % out_name)
def plot_hist(
        hist,
        edges,
        ax,
        ylbl,
        strict_lims=True,
        bg_color='white',
        fg_color='g',
        tick_rot=0):
    """Draw a precomputed histogram as a bar chart on ``ax``.

    Parameters:
      hist: Bin values, one per bin (n values).
      edges: Bin edges (n + 1 values). Slices are added elementwise, so this
        presumably is a numpy array (e.g. from np.histogram) -- confirm.
      ax: Matplotlib axes to draw on.
      ylbl: y-axis label.
      strict_lims: Clamp the x-axis to [first edge, last edge].
      bg_color: 'black' applies black_background(ax=ax); anything else keeps
        the default styling.
      fg_color: Bar color.
      tick_rot: Rotation in degrees for the x tick labels; 0 leaves them as is.
    """

    # MATLAB "hold on". Axes.hold() was removed in matplotlib 3.0, where
    #   holding is always on, so only call it where it still exists.
    if hasattr(ax, 'hold'):
        ax.hold(True)
    # MATLAB "grid on"
    ax.grid(True, color='gray')

    # Ref font size of labels: http://stackoverflow.com/questions/12444716/how-do-i-set-figure-title-and-axes-labels-font-size-in-matplotlib
    ax.set_ylabel(ylbl)

    if strict_lims:
        ax.set_xlim(edges[0], edges[-1])

    # Histogram centers from edges: n+1 edges give n midpoints.
    centers = (edges[:-1] + edges[1:]) * 0.5

    # Bars are half as wide as the first bin, leaving gaps between bars.
    width = (edges[1] - edges[0]) * .5

    ax.bar(centers, hist, width=width, color=fg_color, edgecolor='none')

    # Rotate x tick labels (e.g. 45-degree slanted labels).
    # tick_params(labelrotation=...) rotates the labels produced by the
    #   axis formatter, instead of replacing them with raw float tick values
    #   as set_xticklabels(ax.get_xticks()) would.
    # Ref: https://stackoverflow.com/questions/14852821/aligning-rotated-xticklabels-with-their-respective-xticks/14854007#14854007
    # axes API: https://matplotlib.org/api/axes_api.html
    if tick_rot != 0:
        ax.tick_params(axis='x', labelrotation=tick_rot)

    if bg_color == 'black':
        black_background(ax=ax)
Esempio n. 3
0
def main():
    """Plot (raw depth, visible heatmap, occluded heatmap) triplets per grasp.

    For each object from ConfigReadYAML.read_scene_paths(), each of its
    scenes, and up to N_TO_SAVE grasps per scene, this loads the cropped
    depth render and the two heatmaps. It draws them side by side in a 1x3
    figure and, unless --no-save is given, saves the whole figure plus each
    selected subplot on its own. Command-line flags control display,
    heatmap rescaling, black backgrounds and interactive stepping.

    Returns early (None) if the depth range cannot be loaded or if any image
    of a triplet is missing.
    """

    arg_parser = argparse.ArgumentParser()

    # Variable number of args http://stackoverflow.com/questions/13219910/argparse-get-undefined-number-of-arguments
    arg_parser.add_argument(
        '--display',
        action="store_true",
        help=
        'Specify for debugging one by one. Displays matplotlib plots. This will block program flow and require user interaction to close window to move on to next data sample.'
    )
    arg_parser.add_argument('--no-save',
                            action='store_true',
                            help='Specify to NOT save any visualized images.')
    arg_parser.add_argument(
        '--uinput',
        action='store_true',
        help='Specify to use keyboard interaction, useful for debugging.')
    arg_parser.add_argument(
        '--scale-heatmaps',
        action='store_true',
        help=
        'Specify the flag if it was specified to occlusion_test.cpp to generate the heatmaps. Rescale heatmaps back to min/max depth values.'
    )
    arg_parser.add_argument('--black-bg',
                            action='store_true',
                            help='Generate plots with black background.')

    args = arg_parser.parse_args()

    SCALE_HEATMAPS = args.scale_heatmaps
    BLACK_BG = args.black_bg
    DISPLAY_IMAGES = args.display
    UINPUT = args.uinput
    SAVE_IMG = not args.no_save

    # Number of grasps per scene to plot before moving on to the next scene
    N_TO_SAVE = 1

    # User adjust parameter. Save individual axes, for paper or presentation.
    # 1-based indices of the 3 subplots
    DEPTH_IDX = 1
    VIS_IDX = 2
    OCC_IDX = 3
    # NOTE: Make sure these names match with the subplot indices!
    SUBPLOT_NAMES = [get_vis_depth_fmt(), get_vis_vis_fmt(), get_vis_occ_fmt()]
    SAVE_SUBPLOTS = [DEPTH_IDX, VIS_IDX, OCC_IDX]

    renders_dir = get_renders_data_path()
    heatmaps_dir = get_heatmaps_data_path()

    heatmap_fmts = get_heatmap_blob_fmt()
    vis_fmt = heatmap_fmts[0]
    occ_fmt = heatmap_fmts[1]

    scaler = RawDepthScaling()
    success, _ = scaler.load_depth_range()
    if not success:
        return

    vis_dir = get_vis_path()
    contacts_dir = get_contacts_path()

    # Scene list (scenes.yaml): [0] object names, [2] per-object lists of
    #   paths to .pcd scene files.
    objs = ConfigReadYAML.read_scene_paths()
    obj_names = objs[0]
    scene_paths = objs[2]

    def draw_depth(depth_im):
        """Draw the raw depth image with a colorbar in the depth subplot."""
        plt.subplot(1, 3, DEPTH_IDX)
        # Alternative colormap: gray_r
        depth_obj = plt.imshow(depth_im[:, :, 0], cmap=plt.cm.jet)
        title_hdl = plt.title('Raw Depth')
        cbar = plt.colorbar(depth_obj, fraction=0.046, pad=0.01)
        if BLACK_BG:
            black_background(title_hdl=title_hdl)
            black_colorbar(cbar)

    def draw_overlay(subplot_idx, depth_im, heat_im, title):
        """Draw heat_im blended over a faded depth image in one subplot."""
        plt.subplot(1, 3, subplot_idx)
        # Plot a white background first, so that black background plots
        #   produce the same heatmap appearances as white background ones.
        # Ref plot white image https://stackoverflow.com/questions/28234416/plotting-a-white-grayscale-image-in-python-matplotlib
        plt.imshow(np.ones(depth_im.shape[:2]),
                   cmap='gray',
                   vmin=0,
                   vmax=1,
                   alpha=1.0)
        plt.imshow(depth_im[:, :, 0], cmap=plt.cm.jet, alpha=0.4)
        heat_obj = plt.imshow(heat_im[:, :, 0], cmap=plt.cm.jet, alpha=0.7)
        title_hdl = plt.title(title)
        # fraction/pad keep the colorbar flush with the image
        cbar = plt.colorbar(heat_obj, fraction=0.046, pad=0.01)
        if BLACK_BG:
            black_background(title_hdl=title_hdl)
            black_colorbar(cbar)

    terminate = False
    for o_i, obj_name in enumerate(obj_names):

        # Contacts meta file: one CSV row, one element per grasp.
        obj_meta_path = os.path.join(contacts_dir, obj_name + '_meta.csv')
        # Reset per object: an empty meta file means 0 grasps, not a
        #   NameError or the previous object's grasp count.
        n_grasps = 0
        with open(obj_meta_path, 'rb') as obj_meta_f:
            for row in csv.reader(obj_meta_f):
                n_grasps = len(row)

        skip_obj = False
        for s_i, scene_path in enumerate(scene_paths[o_i]):
            scene_base = os.path.basename(scene_path)
            scene_stem = os.path.splitext(scene_base)[0]

            for g_i in range(n_grasps):

                print(
                    'Object [%d], Scene [%d], Grasp [%d], Loading triplet files for %s'
                    % (o_i, s_i, g_i, scene_path))

                depth_name = os.path.join(renders_dir, scene_stem + 'crop.png')
                vis_name = os.path.join(heatmaps_dir,
                                        vis_fmt % (scene_stem, g_i))
                occ_name = os.path.join(heatmaps_dir,
                                        occ_fmt % (scene_stem, g_i))

                depth_im = np_from_depth(depth_name)
                vis_im = np_from_depth(vis_name)
                occ_im = np_from_depth(occ_name)

                if depth_im is None or vis_im is None or occ_im is None:
                    print(
                        '%sERROR: One or more of (depth, vis_tac, occ_tac) images does not exist. Did you forget to run rosrun depth_scene_generation postprocess_scenes? Run it to generate the cropped images. Terminating...%s'
                        % (ansi.FAIL, ansi.ENDC))
                    return

                # Calculate raw depths from the integers in image
                depth_im = scaler.scale_ints_to_depths(depth_im)
                if SCALE_HEATMAPS:
                    vis_im = scaler.scale_ints_to_depths(vis_im)
                    occ_im = scaler.scale_ints_to_depths(occ_im)
                else:
                    # Map to [0, 1]; presumably 8-bit heatmap images
                    vis_im = vis_im.astype(np.float32) / 255.0
                    occ_im = occ_im.astype(np.float32) / 255.0

                fig = plt.figure(figsize=(15, 6))
                draw_depth(depth_im)
                draw_overlay(VIS_IDX, depth_im, vis_im, 'Visible')
                draw_overlay(OCC_IDX, depth_im, occ_im, 'Occluded')
                fig.tight_layout()

                # All file output lives under SAVE_IMG, so --no-save writes
                #   nothing (individual axes included).
                if SAVE_IMG:
                    dest = os.path.join(
                        vis_dir, get_vis_heatmap_fmt() % (scene_stem, g_i))

                    if BLACK_BG:
                        plt.savefig(dest,
                                    bbox_inches='tight',
                                    facecolor=fig.get_facecolor(),
                                    edgecolor='none',
                                    transparent=True)
                    else:
                        fig.savefig(dest)
                    print('%sWritten entire plot to %s%s' %
                          (ansi.OKCYAN, dest, ansi.ENDC))

                    # Save each selected axis cleanly on its own
                    for subplot_i in SAVE_SUBPLOTS:

                        # Depth image is the same for every grasp; save once
                        if subplot_i == DEPTH_IDX and g_i != 0:
                            continue

                        ax = plt.subplot(1, 3, subplot_i)
                        ax.set_aspect(1)
                        plt.axis('off')

                        axis_dest = os.path.join(
                            vis_dir, SUBPLOT_NAMES[subplot_i - 1] % scene_stem)
                        extent = ax.get_window_extent().transformed(
                            fig.dpi_scale_trans.inverted())
                        fig.savefig(axis_dest, bbox_inches=extent)
                        print('%sWritten depth image axis to %s%s' %
                              (ansi.OKCYAN, axis_dest, ansi.ENDC))

                if DISPLAY_IMAGES:
                    plt.show()

                plt.close(fig)

                # Prompt before applying the per-scene grasp limit; otherwise
                #   (with N_TO_SAVE = 1) the limit always breaks first and the
                #   prompt, including quit, is never reached.
                if UINPUT:
                    uinput = raw_input(
                        'Press s to skip to next scene, o to skip to next objet, q to quit, or anything else to go to next grasp in this scene: '
                    ).lower()
                    if uinput == 's':
                        break
                    elif uinput == 'o':
                        skip_obj = True
                        break
                    elif uinput == 'q':
                        terminate = True
                        break

                if g_i >= N_TO_SAVE - 1:
                    break

            # Break out of s_i
            if skip_obj or terminate:
                break

        # Break out of o_i
        if terminate:
            break
Esempio n. 4
0
def main():
    """Visualize spherical pose generation and save a single quiver plot.

    Generates poses with test_ordered_poses() or test_rand_poses() (per the
    ORDERED / TOP_ONLY switches below). Each quaternion is turned into the
    direction it rotates the x-axis to. Then:
      - if plot_all: shows a 3-panel comparison of quaternion-derived
        directions vs. the generated positions;
      - always: draws the directions as arrows pointing toward the origin and
        saves that figure as <out_name>_single[_black].eps and .png, then
        shows it.
    """

    ORDERED = False
    TOP_ONLY = True

    BLACK_BG = True

    plot_all = False

    if ORDERED:
        pos, quats = test_ordered_poses(topOnly=TOP_ONLY)
        out_name = 'test_spherical_pose_generation'
    else:
        n_rand_pts = 100
        pos, quats = test_rand_poses(n_rand_pts, topOnly=TOP_ONLY)
        out_name = 'test_spherical_pose_generation_rand'
    out_name += '_top' if TOP_ONLY else '_full'

    # One quaternion per column of quats
    n_pts = quats.shape[1]

    #####
    # Use quaternion rotation only: column i of UVW (3 x n) is the x-axis
    #   rotated by quaternion i, i.e. the first column of its rotation matrix.
    UVW = np.zeros((3, n_pts))
    for i in range(n_pts):
        UVW[:, i] = quaternion_matrix(quats[:, i])[0:3, 0]

    def set_equal_3d_aspect(ax):
        """Request equal x/y/z scaling on a 3D axes across matplotlib versions."""
        # Axes3D.set_aspect(1) raises on newer matplotlib (NotImplementedError
        #   from 3.1, ValueError from 3.6). set_box_aspect (matplotlib >= 3.3)
        #   is the supported replacement; fall back on older versions.
        try:
            ax.set_box_aspect((1, 1, 1))
        except AttributeError:
            try:
                ax.set_aspect(1)
            except NotImplementedError:
                pass

    def set_unit_cube(ax, title):
        """Title a 3D axes and show the [-1, 1]^3 cube with equal scaling."""
        ax.set_title(title)
        set_equal_3d_aspect(ax)
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(-1, 1)

    # quiver(x, y, z, u, v, w): the first three arrays are the arrow start
    #   points, the last three the vectors. pivot='tail' puts the arrow tail at
    #   the first xyzs, so arrows start on the sphere and point at the origin.
    # Ref https://matplotlib.org/examples/mplot3d/quiver3d_demo.html
    if plot_all:

        fig = plt.figure(figsize=(15, 6))

        # Ref 3D plot https://matplotlib.org/mpl_toolkits/mplot3d/tutorial.html
        ax = fig.add_subplot(131, projection='3d')
        ax.quiver(UVW[0, :],
                  UVW[1, :],
                  UVW[2, :],
                  -UVW[0, :],
                  -UVW[1, :],
                  -UVW[2, :],
                  color='orange',
                  length=0.4,
                  arrow_length_ratio=0.5,
                  pivot='tail')
        ax.scatter(UVW[0, :], UVW[1, :], UVW[2, :], c='orange')
        set_unit_cube(ax, 'By quaternion only')

        #####
        # Use position only, plot positions on sphere
        ax = fig.add_subplot(132, projection='3d')
        ax.scatter(pos[0, :], pos[1, :], pos[2, :], c='red')
        set_unit_cube(ax, 'By position only')

        #####
        # Overlap position and quaternion, to see if they overlap exactly
        # Result: They do NOT overlap exactly. So the best option to generate both
        #   position and quaternion is using the quaternion way, where quats variable
        #   gives the quaternions, UVW gives the positions.
        ax = fig.add_subplot(133, projection='3d')

        # Quaternion
        ax.scatter(UVW[0, :], UVW[1, :], UVW[2, :], c='orange', alpha=0.5)
        # Position
        ax.scatter(pos[0, :], pos[1, :], pos[2, :], c='red', alpha=0.5)

        # Text label, for easier debugging of mismatches between positions and quats
        for i in range(pos.shape[1]):
            ax.text(UVW[0, i], UVW[1, i], UVW[2, i], str(i), color='orange')
            ax.text(pos[0, i], pos[1, i], pos[2, i], str(i), color='red')

        set_unit_cube(ax, 'Overlay quaternion and position')

        fig.tight_layout()
        # To keep this figure: fig.savefig(out_name + '.png') / ('.eps')

        plt.show()

    #####
    # Plot an individual image for thesis writing

    fig2 = plt.figure()
    ax = fig2.add_subplot(111, projection='3d')

    cm_name = custom_colormap_neon()
    arrow_color = mpl_color(1, 8, colormap_name=cm_name)

    ax.quiver(UVW[0, :],
              UVW[1, :],
              UVW[2, :],
              -UVW[0, :],
              -UVW[1, :],
              -UVW[2, :],
              color=arrow_color,
              length=0.4,
              arrow_length_ratio=0.5,
              pivot='tail')

    # 1, 4 go together well, dark blue, light green
    # 1, 3 okay too, dark blue, cyan
    # 1, 5 good too, dark blue, yellow
    dot_color = mpl_color(4, 8, colormap_name=cm_name)
    ax.scatter(UVW[0, :], UVW[1, :], UVW[2, :], c=dot_color)

    set_unit_cube(ax, 'Ordered poses' if ORDERED else 'Random poses')

    fig2.tight_layout()
    if BLACK_BG:
        black_3d_background(ax)
        black_background(ax)

    # Only tag the file name as black when the black background was applied
    single_out_base = out_name + ('_single_black' if BLACK_BG else '_single')

    for ext in ('.eps', '.png'):
        single_out_name = single_out_base + ext
        fig2.savefig(single_out_name,
                     bbox_inches='tight',
                     facecolor=fig2.get_facecolor(),
                     edgecolor='none',
                     transparent=True)
        print('Written plot to %s' % single_out_name)

    plt.show()
def plot_bars(obj_ids, obj_errs, obj_names_ordered, lbls_by_value, out_name):
    """Plot per-object-class errors as bars on a black background and save.

    Draws one bar per object ID through plot_line(style='bar') and labels the
    x ticks with obj_names_ordered. Overlays a dashed white line at the mean
    error and, when lbls_by_value is non-empty, a dashed red line at the
    smallest label count's share of all labels. The y-axis is fixed to
    [0, 1]. The figure is saved to out_name and then shown.
    """
    truetype()

    neon = custom_colormap_neon()
    # Entry 2 of 8 is light sky blue; entry 7 (neon orange) also works
    bar_color = mpl_color(2, 8, colormap_name=neon)

    _, title_hdl = plot_line(obj_ids,
                             obj_errs,
                             'Per-Object-Class Errors',
                             'Object',
                             'Error',
                             out_name='',
                             color=bar_color,
                             lbl='',
                             style='bar',
                             dots=False,
                             grid=True,
                             do_save=False,
                             do_show=False,
                             return_title_hdl=True)
    black_background(title_hdl=title_hdl)

    # Tick only at the object IDs, not at the padding columns on either side
    mpl_diagonal_xticks(plt.gca(), obj_ids, obj_names_ordered, rot_degs=0)

    # Dashed horizontal reference line at the mean per-class error
    # API https://matplotlib.org/api/_as_gen/matplotlib.axes.Axes.axhline.html
    mean_err = np.mean(obj_errs)
    plt.gca().axhline(y=mean_err,
                      linewidth=1,
                      color='w',
                      linestyle='--',
                      label='Mean per-class')
    mean_txt = plt.text(0, mean_err + 0.01, '%.3f' % mean_err, color='w')

    if len(lbls_by_value) > 0:
        label_total = float(np.sum(lbls_by_value))
        smallest_share = np.min(lbls_by_value) / label_total
        plt.gca().axhline(y=smallest_share,
                          linewidth=1,
                          color='r',
                          linestyle='--',
                          label='Labels portion')
        plt.text(0, smallest_share + 0.01, '%.3f' % smallest_share, color='r')

        # A positive gap (mean error below the smallest share) is good
        gap = smallest_share - mean_err
        mean_txt.set_text('%s (%.3f lower)' % (mean_txt.get_text(), gap))

    black_legend(plt.legend())

    # Errors are proportions, so show the full [0, 1] range
    plt.gca().set_ylim([0.0, 1.0])

    plt.savefig(out_name,
                bbox_inches='tight',
                facecolor=plt.gcf().get_facecolor(),
                edgecolor='none',
                transparent=True)
    print('%sWritten plot to %s%s' % (ansi.OKCYAN, out_name, ansi.ENDC))

    plt.show()
Esempio n. 6
0
def draw_confusion_matrix_per_sample(unsorted_idx,
                                     distance_matrix,
                                     ticks=None,
                                     img_name='',
                                     title_prefix='',
                                     draw_title=True,
                                     draw_xylbls=True,
                                     draw_xticks=True,
                                     draw_yticks=True,
                                     fontsize=10,
                                     bg_color='white',
                                     cmap_name='jet_r'):
    """Draw a per-sample distance matrix with each row reordered by sample ID.

    Row i of distance_matrix is permuted by np.argsort(unsorted_idx[i, :]).
    The function prints summary statistics and the reordered matrix, draws it
    with matshow plus a colorbar, optionally saves it, then calls plt.show().

    Parameters:
      unsorted_idx: 2D array; row i holds the sample IDs matching the columns
        of row i of distance_matrix.
      distance_matrix: 2D array of distances with as many columns as
        unsorted_idx.
      ticks: Optional class-name tick labels passed to draw_classname_ticks.
        When empty/None, numeric ticks are placed every 20 units.
      img_name: If non-empty, save the figure to this path.
      title_prefix: Text prepended to the 'Confusion matrix' title.
      draw_title, draw_xylbls, draw_xticks, draw_yticks: Toggle plot parts.
      fontsize: Font size for class-name ticks and x tick labels.
      bg_color: 'black' applies black_background() and black_colorbar().
      cmap_name: Colormap for this plot only. Default 'jet_r' reverses jet so
        small distances are hot and far distances are cold.
    """
    if ticks is None:
        ticks = []
    has_ticks = len(ticks) > 0

    # Matrix is n x n
    nSamples = np.shape(unsorted_idx)[0]

    dists_sorted = np.zeros(np.shape(distance_matrix))
    for i in range(nSamples):
        # Indices that sort the ith row of sample IDs, applied to the ith
        #   row of distances.
        # Ref: http://docs.scipy.org/doc/numpy/reference/generated/numpy.argsort.html
        order = np.argsort(unsorted_idx[i, :])
        dists_sorted[i, :] = distance_matrix[i, :][order]

    print('')
    print('Max distance: %f. Min distance: %f. Mean distance: %f' %
          (np.max(dists_sorted), np.min(dists_sorted), np.mean(dists_sorted)))

    print('Confusion matrix:')
    print(dists_sorted)

    # Using a figure makes colorbar same length as main plot!
    fig = plt.figure()
    # Pass the colormap to matshow rather than calling plt.set_cmap(), which
    #   also changes the global default colormap for every later plot.
    # Ref: http://stackoverflow.com/questions/3279560/invert-colormap-in-matplotlib
    plt.matshow(dists_sorted, fignum=fig.number, cmap=cmap_name)
    if draw_title:
        if has_ticks:
            # Leave room above for the class-name tick labels
            plt.title(title_prefix + 'Confusion matrix', y=1.3)
        else:
            plt.title(title_prefix + 'Confusion matrix')
    colorbar = plt.colorbar()

    if draw_xylbls:
        plt.ylabel('Objects')
        plt.xlabel('Objects')

    ax = plt.gca()

    # Display class names on the axes ticks
    if draw_xticks or draw_yticks:
        if has_ticks:
            draw_classname_ticks(dists_sorted,
                                 ticks,
                                 draw_x=draw_xticks,
                                 draw_y=draw_yticks,
                                 fontsize=fontsize)

        # Otherwise numerical ticks, every 20 units
        else:
            units_per_tick = 20.0

            if draw_xticks:
                plt.xticks(np.arange(0, nSamples, units_per_tick))

            if draw_yticks:
                plt.yticks(np.arange(0, nSamples, units_per_tick))

    # Ref: http://stackoverflow.com/questions/2176424/hiding-axis-text-in-matplotlib-plots
    if not draw_xticks:
        ax.get_xaxis().set_ticks([])

    if not draw_yticks:
        ax.get_yaxis().set_ticks([])

    # Honour the fontsize argument for x tick labels. label1 is the
    #   bottom/left label; Tick.label is a deprecated alias removed in newer
    #   matplotlib.
    # Ref: http://stackoverflow.com/questions/6390393/matplotlib-make-tick-labels-font-size-smaller
    for tick in ax.xaxis.get_major_ticks():
        tick.label1.set_fontsize(fontsize)

    # Set background color
    if bg_color == 'black':
        black_background()
        black_colorbar(colorbar)

    if img_name:
        # Ref savefig() black background: https://stackoverflow.com/questions/4804005/matplotlib-figure-facecolor-background-color
        plt.savefig(img_name,
                    bbox_inches='tight',
                    facecolor=fig.get_facecolor(),
                    edgecolor='none',
                    transparent=True)
        print('Plot saved to %s' % img_name)

    plt.show()
Esempio n. 7
0
def draw_confusion_matrix(true_lbls,
                          predicted_lbls,
                          ticks=None,
                          img_name='',
                          title_prefix='',
                          draw_title=True,
                          raw=False,
                          bg_color='white'):
    """Compute, print and draw the confusion matrix of true vs. predicted labels.

    The y-axis is the true label and the x-axis the predicted label. Unless
    ``raw`` is True, each row is divided by its total, so entries are the
    fraction of each true class predicted as each class. Optionally saves the
    figure, then calls plt.show().

    Parameters:
      true_lbls, predicted_lbls: Label sequences passed to sklearn's
        confusion_matrix(), in (true, predicted) order.
      ticks: Optional class-name tick labels passed to draw_classname_ticks.
      img_name: If non-empty, save the figure to this path.
      title_prefix: Text prepended to the 'Confusion matrix' title.
      draw_title: Draw the title.
      raw: Plot raw counts instead of row-normalized fractions.
      bg_color: 'black' applies black_background() and black_colorbar().
    """
    if ticks is None:
        ticks = []

    # Keep (true, predicted) order: rows are true labels, columns predicted.
    #   Swapping these would require swapping the x/y axis labels below.
    # Ref: http://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
    cm = confusion_matrix(true_lbls, predicted_lbls)
    print('Confusion matrix:')
    print(cm)

    # Do percentages, instead of raw numbers that are returned by default
    if not raw:
        # Row sums = number of true samples per class, as a float column
        #   vector so the division broadcasts across each row.
        nTtl = np.sum(cm, axis=1).astype(np.float32).reshape((-1, 1))

        # A class that only ever appears as a prediction has a zero row sum;
        #   leave that row at 0 instead of producing NaN and a divide-by-zero
        #   RuntimeWarning.
        cm = np.divide(cm,
                       nTtl,
                       out=np.zeros(cm.shape, dtype=np.float64),
                       where=nTtl != 0)

    # Using a figure makes colorbar same length as main plot!
    fig = plt.figure()
    # Pass the colormap to matshow rather than calling plt.set_cmap(), which
    #   also changes the global default colormap for every later plot.
    plt.matshow(cm, fignum=fig.number, cmap='jet')
    colorbar = plt.colorbar()
    if draw_title:
        # Ref move title up: http://stackoverflow.com/questions/12750355/python-matplotlib-figure-title-overlaps-axes-label-when-using-twiny
        plt.title(title_prefix + 'Confusion matrix', y=1.2)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

    # Display class names on the axes ticks
    if len(ticks) > 0:
        draw_classname_ticks(cm, ticks)

    # Set background color
    if bg_color == 'black':
        black_background()
        black_colorbar(colorbar)

    if img_name:
        # Ref savefig() black background: https://stackoverflow.com/questions/4804005/matplotlib-figure-facecolor-background-color
        plt.savefig(img_name,
                    bbox_inches='tight',
                    facecolor=fig.get_facecolor(),
                    edgecolor='none',
                    transparent=True)
        print('Plot saved to %s' % img_name)

    plt.show()
def save_individual_subplots(figs,
                             ylbls,
                             ylbl_to_save,
                             nRows,
                             nCols,
                             scale_xmin=1.1,
                             scale_x=1.2,
                             scale_y=1.15,
                             ndims=3,
                             show=True,
                             bg_color='white',
                             tick_rot=0):
    """Save every subplot of the figure labelled ylbl_to_save to its own .eps.

    The first index f_i with ylbls[f_i] == ylbl_to_save selects figs[f_i].
    Each of its axes is written to
    <get_img_path('hists')>/triangle_<ndims>Dhist_<ylbl>_<NN>.eps.
    Does nothing when no label matches.

    Each axes' window extent is expanded, more so for the left column
    (scale_x) and bottom row (scale_y), which need room for the y and x
    labels. The expanded box is then shifted so its upper-right corner matches
    the minimally expanded box.

    Parameters:
      figs: Matplotlib figures, parallel to ylbls.
      ylbls: Label of each figure.
      ylbl_to_save: Label of the figure to export.
      nRows, nCols: Subplot grid shape; axes are assumed to be in row-major
        order.
      scale_xmin: Minimum horizontal expansion factor for every subplot.
      scale_x, scale_y: Expansion factors for left-column / bottom-row
        subplots.
      ndims: Number used in the output file name.
      show, tick_rot: Accepted for interface compatibility; not used here.
      bg_color: Face color for savefig; 'black' also applies black_background.
    """
    # Left column (extra margin on the left): indices 0, nCols, 2*nCols, ...
    #   The stride is nCols so non-square grids pick the right subplots.
    x_expand_idx = set(i * nCols for i in range(nRows))
    # Bottom row (extra margin at the bottom): (nRows-1)*nCols + 0..nCols-1
    y_expand_idx = set((nRows - 1) * nCols + i for i in range(nCols))

    # Constants for scaling bbox, tuned for ticks and labels with fontsize 6
    SCALE_XMIN = scale_xmin
    SCALE_YMIN = 1.1
    SCALE_X = scale_x
    SCALE_Y = scale_y

    # Find the first figure that matches the desired ylbl
    f_i = next((i for i, lbl in enumerate(ylbls) if lbl == ylbl_to_save), None)
    if f_i is None:
        return

    fig = figs[f_i]
    imgpath = get_img_path('hists')
    imgname = os.path.join(imgpath,
                           'triangle_' + str(ndims) + 'Dhist_' + ylbls[f_i])

    for sp_i, ax in enumerate(fig.get_axes()):
        curr_imgname = imgname + ('_%02d' % sp_i) + '.eps'

        # Ref save a subplot (axis) in a figure to file:
        #   http://stackoverflow.com/questions/4325733/save-a-subplot-in-matplotlib
        extent = ax.get_window_extent().transformed(
            fig.dpi_scale_trans.inverted())

        # Minimal expansion applied to every subplot; a factor of 1 does not
        #   show the whole plot.
        base_box = extent.expanded(SCALE_XMIN, SCALE_YMIN)
        # Upper-right corner to pin the final box to
        #   http://matplotlib.org/devel/transformations.html
        xmax_orig = base_box.xmax
        ymax_orig = base_box.ymax

        in_left_col = sp_i in x_expand_idx
        in_bottom_row = sp_i in y_expand_idx
        if in_left_col and in_bottom_row:
            # Lower-left subplot needs extra room in both x and y
            bbox_inches = extent.expanded(SCALE_X, SCALE_Y)
        elif in_left_col:
            # Same x as above, so the last row is as wide as previous rows
            bbox_inches = extent.expanded(SCALE_X, SCALE_YMIN)
        elif in_bottom_row:
            # Extra y for the bottom x label, extra x for the left-most tick
            bbox_inches = extent.expanded(SCALE_XMIN, SCALE_Y)
        else:
            bbox_inches = base_box

        # expanded() grows symmetrically about the center; shift the box so
        #   the content stays at the upper-right of the expanded area.
        bbox_inches = bbox_inches.translated(xmax_orig - bbox_inches.xmax,
                                             ymax_orig - bbox_inches.ymax)

        if bg_color == 'black':
            black_background(ax)

        # Ref savefig() black background: https://stackoverflow.com/questions/4804005/matplotlib-figure-facecolor-background-color
        fig.savefig(curr_imgname,
                    bbox_inches=bbox_inches,
                    facecolor=bg_color,
                    edgecolor='none',
                    transparent=True)

        print('Individual axes %d in dimension %s saved to %s' %
              (sp_i, ylbl_to_save, curr_imgname))