def EstablishAxes(fig, args):
  """ Create a single axis on the figure object.

  The axis is inset from the figure edges by fixed margins expressed in
  inches (left 1.1, right 1.3, bottom 0.9, top 0.4) and converted to
  figure fractions with ``args.width`` / ``args.height``. The resulting
  fractions are stored back on ``args`` as ``axLeft``, ``axRight``,
  ``axWidth``, ``axBottom``, ``axTop`` and ``axHeight``.

  Major ticks are removed on both axes, the left and bottom spines are
  offset outward by 10 points, the right and top spines are hidden, and
  ticks are drawn only at the bottom and left.

  Args:
    fig: a matplotlib figure object
    args: an argparse arguments object; must provide ``width`` and
      ``height`` (presumably the figure size in inches, since the margin
      constants below are in inches -- confirm with the caller).

  Returns:
    ax: a matplotlib axis object
  Raises:
    ValueError: If an unknown spine location is passed.
  """
  # Margins in inches: left 1.1, right 1.3, bottom 0.9, top 0.4.
  args.axLeft = 1.1 / args.width
  args.axRight = 1.0 - (1.3 / args.width)
  args.axWidth = args.axRight - args.axLeft
  args.axBottom = 0.9 / args.height
  args.axTop = 1.0 - (0.4 / args.height)
  args.axHeight = args.axTop - args.axBottom
  ax = fig.add_axes([args.axLeft, args.axBottom,
                     args.axWidth, args.axHeight])
  ax.yaxis.set_major_locator(pylab.NullLocator())
  ax.xaxis.set_major_locator(pylab.NullLocator())
  # ``spines`` is a dict-like mapping; ``iteritems`` only existed in Python 2.
  for loc, spine in ax.spines.items():
    if loc in ['left', 'bottom']:
      spine.set_position(('outward', 10))
    elif loc in ['right', 'top']:
      spine.set_color('none')
    else:
      raise ValueError('unknown spine location: %s' % loc)
  # Ticks only where a spine is drawn.
  ax.xaxis.set_ticks_position('bottom')
  ax.yaxis.set_ticks_position('left')
  return ax
# Example No. 2 (0)
def establishAxes(fig, categories, options, data):
    """Create a full-figure background axis plus one column axis per category.

    A background axis covering the whole figure is stored on
    ``data.backgroundAx``. The horizontal band between ``options.axLeft``
    and ``options.axRight`` is then split into equally wide columns,
    separated by a fixed margin. There is one column per category key,
    placed left to right in sorted key order. Every axis has its major
    tick locators removed and its frame hidden.

    Side effects:
        ``options``: sets ``axLeft``, ``axRight``, ``width``, ``axBottom``,
        ``axTop``, ``axHeight`` and ``axDictSortedOrder``.
        ``data``: sets ``backgroundAx`` and ``axDict``.

    Args:
        fig: matplotlib figure that receives the axes.
        categories: mapping whose sortable keys name the columns (only the
            keys are used).
        options: mutable options object that receives the layout values.
        data: mutable container that receives the created axes.

    Returns:
        dict mapping each category key to its axis.
    """
    axDict = {}
    data.backgroundAx = fig.add_axes([0.0, 0.0, 1.0, 1.0])
    data.backgroundAx.yaxis.set_major_locator(pylab.NullLocator())
    data.backgroundAx.xaxis.set_major_locator(pylab.NullLocator())
    # Same as plt.box(on=False) for this axis, but does not depend on ``fig``
    # being pyplot's current figure.
    data.backgroundAx.set_frame_on(False)
    options.axLeft = 0.01
    options.axRight = 0.99
    # Stored as ``width`` (not ``axWidth``); kept for callers that read it.
    options.width = options.axRight - options.axLeft
    options.axBottom = 0.1
    options.axTop = 0.85
    options.axHeight = options.axTop - options.axBottom
    margin = 0.017
    # dict.keys() has no .sort() in Python 3; sorted() works on any mapping.
    sortedOrder = sorted(categories)
    options.axDictSortedOrder = sortedOrder
    nCols = len(sortedOrder)
    # Guard against an empty mapping (no columns -> nothing to size).
    width = (options.width - (nCols - 1) * margin) / nCols if nCols else 0.0
    xpos = options.axLeft
    for c in sortedOrder:
        axDict[c] = fig.add_axes(
            [xpos, options.axBottom, width, options.axHeight])
        axDict[c].yaxis.set_major_locator(pylab.NullLocator())
        axDict[c].xaxis.set_major_locator(pylab.NullLocator())
        axDict[c].set_frame_on(False)
        xpos += width + margin
    data.axDict = axDict
    return axDict
# Example No. 3 (0)
def establish_axes(fig, width, height, border=True, has_legend=True):
    """Add one axis to ``fig`` using margins given in figure-size units.

    Each margin constant is divided by ``width`` or ``height`` to turn it
    into a figure fraction. The constants are presumably inches, i.e.
    ``width``/``height`` are the figure size in inches (confirm with the
    caller).

    * left margin: 1.1
    * right edge: an anchor minus 1.8 (with legend) or 1.15 (without).
      The anchor is 1.0 when ``border is True``; otherwise it is 1.1, which
      moves the right edge 0.1 of the figure width further right.
    * bottom margin: 1.4; top edge: 0.90 minus 0.4

    Major ticks are removed. The left and bottom spines are offset outward
    by 10 points and the right and top spines are hidden. Ticks are drawn
    only at the bottom and left.

    Args:
        fig: matplotlib figure to add the axis to.
        width: figure width used to scale the horizontal margins.
        height: figure height used to scale the vertical margins.
        border: compared by identity, so only the literal ``True`` selects
            the 1.0 right anchor (e.g. ``1`` behaves like ``False``).
        has_legend: compared by identity, so only the literal ``True``
            reserves the wider 1.8 right margin.

    Returns:
        The new matplotlib axis.

    Raises:
        ValueError: if the axis reports a spine other than
            left/right/top/bottom.
    """
    ax_left = 1.1 / width
    right_anchor = 1.0 if border is True else 1.1
    right_margin = 1.8 if has_legend is True else 1.15
    ax_right = right_anchor - (right_margin / width)
    ax_width = ax_right - ax_left
    ax_bottom = 1.4 / height
    ax_top = 0.90 - (0.4 / height)
    ax_height = ax_top - ax_bottom
    ax = fig.add_axes([ax_left, ax_bottom, ax_width, ax_height])
    ax.yaxis.set_major_locator(pylab.NullLocator())
    ax.xaxis.set_major_locator(pylab.NullLocator())
    # ``spines`` is a dict-like mapping; ``iteritems`` only existed in Python 2.
    for loc, spine in ax.spines.items():
        if loc in ['left', 'bottom']:
            spine.set_position(('outward', 10))
        elif loc in ['right', 'top']:
            spine.set_color('none')
        else:
            raise ValueError('unknown spine location: %s' % loc)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    return ax
# Example No. 4 (0)
def hideaxis(pos=None):
    """Remove the major ticks of the current axes, optionally fitting its limits.

    Args:
        pos: optional mapping whose values are ``(x, y)`` pairs (for example
            node positions). When it is non-empty, the current axes' limits
            are set to the bounding box of those points, padded by 5 data
            units on every side.
    """
    if pos:
        # Materialize the values view: some pandas versions reject a dict
        # view as DataFrame data.
        df = pd.DataFrame(list(pos.values()), columns=['x', 'y'])
        plt.xlim([df['x'].min() - 5, df['x'].max() + 5])
        plt.ylim([df['y'].min() - 5, df['y'].max() + 5])
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
# Example No. 5 (0)
def setAxisLimits(axDict, options, data):
    """Fix the data limits of every panel and strip their major ticks.

    The footer axis ``data.footerAx`` is touched only when one of the
    stack-fill options is set. It then gets y in [0, 1.01] and x in [0, 1].

    For each chromosome ``chrom`` in ``data.chrNames``:

    * each annotation panel ``axDict[chrom + annot]`` gets y in
      [0, ``data.annotationCeilings[idx]``]. Here ``annot`` comes from
      ``data.annotationOrder`` and ``idx`` is its position there.
    * each MAF panel ``axDict[chrom + maf]`` that exists gets y in
      [0, ``data.axCeilings[maf]``]. Here ``maf`` comes from
      ``data.orderedMafs``; missing panels are skipped.
    * the chromosome panel ``axDict[chrom]`` gets y in [0, 1.01].

    All chromosome panels span x in [0, ``data.chrLengthsByChrom[chrom]``].
    """
    def stripTicks(panel):
        panel.xaxis.set_major_locator(pylab.NullLocator())
        panel.yaxis.set_major_locator(pylab.NullLocator())

    stackFill = (options.stackFillBlocks or options.stackFillContigPaths
                 or options.stackFillContigs or options.stackFillScaffPaths)
    if stackFill:
        footer = data.footerAx
        footer.set_ylim(0.0, 1.01)
        footer.set_xlim(0.0, 1.0)
        stripTicks(footer)

    for chrom in data.chrNames:
        for idx, annot in enumerate(data.annotationOrder):
            panel = axDict[chrom + annot]
            panel.set_ylim(0.0, data.annotationCeilings[idx])
            panel.set_xlim(0.0, data.chrLengthsByChrom[chrom])
            stripTicks(panel)
        for maf in data.orderedMafs:
            key = chrom + maf
            if key not in axDict:
                continue
            panel = axDict[key]
            panel.set_ylim(0.0, data.axCeilings[maf])
            panel.set_xlim(0.0, data.chrLengthsByChrom[chrom])
            stripTicks(panel)
        chromPanel = axDict[chrom]
        chromPanel.set_ylim(0.0, 1.01)
        chromPanel.set_xlim(0.0, data.chrLengthsByChrom[chrom])
        stripTicks(chromPanel)
# Example No. 6 (0)
def establishAxis(options, data):
    """Create the single main axis on ``data.fig`` and record its layout.

    Layout values (figure fractions) are written to ``options``:

    * ``axLeft`` 0.05, ``axWidth`` 0.9, ``axTop`` 0.95, ``axHeight`` 0.9;
    * ``axRight`` and ``axBottom``, derived from the values above;
    * ``margins`` 0.015.

    The new axis is stored on ``data.ax`` with its major tick locators
    removed. Its frame is hidden unless ``options.frames`` is truthy.
    """
    options.axLeft = 0.05
    options.axWidth = 0.9
    options.axTop = 0.95
    options.axHeight = 0.9
    options.axRight = options.axLeft + options.axWidth
    options.axBottom = options.axTop - options.axHeight
    options.margins = 0.015
    data.ax = data.fig.add_axes(
        [options.axLeft, options.axBottom, options.axWidth, options.axHeight])
    data.ax.yaxis.set_major_locator(pylab.NullLocator())
    data.ax.xaxis.set_major_locator(pylab.NullLocator())
    if not options.frames:
        # Act on the axis directly: plt.box() hits pyplot's *current* axes,
        # which is data.ax only when data.fig is the current figure.
        data.ax.set_frame_on(False)
# Example No. 7 (0)
def plot_true_alt2(p, q):
    """Draw three-segment bars for a true and an alternate two-rep model.

    Each model is a horizontal bar 0.1 high. It is split into segments of
    width (1-r)**2, 2*r*(1-r) and r**2, coloured blue, green and orange.
    ``r`` is ``p`` for the true model (bar at y=0.6) and ``q`` for the
    alternate model (bar at y=0.2).

    The patches are added to the current axes and the y major ticks are
    removed. An x label and two captions are added, and the global
    matplotlib font size is set to 18.
    """
    def segments(r):
        # (left edge, width, colour) for each of the three segments.
        first = (1 - r)**2
        second = 2 * r * (1 - r)
        return [(0., first, 'blue'),
                (first, second, "green"),
                (first + second, r**2, 'orange')]

    axes = plt.gca()
    for r, y in ((p, .6), (q, .2)):
        for x0, w, colour in segments(r):
            axes.add_patch(
                matplotlib.patches.Rectangle((x0, y), w, .1, facecolor=colour))
    axes.yaxis.set_major_locator(plt.NullLocator())
    plt.xlabel("Probability")
    axes.text(.05, .725, "True Model (2 Reps): p=" + str(p), fontsize=18)
    axes.text(.05, .325, "Alternate Model (2 Reps): p=" + str(q), fontsize=18)
    matplotlib.rcParams.update({'font.size': 18})
# Example No. 8 (0)
def plot_nchannels(p):
    """Draw bars comparing one-channel and two-channel outcomes for ``p``.

    * y=0.2: a three-segment bar with widths (1-p)**2, 2*p*(1-p) and p**2
      (blue, green, orange).
    * y=0.6: a two-segment bar with widths 1-p and p (blue, green).

    The patches are added to the current axes and the y major ticks are
    removed. An x label and two captions are added, and the global
    matplotlib font size is set to 18.
    """
    first = (1 - p)**2
    second = 2 * p * (1 - p)
    twoChannels = [(0., first, 'blue'),
                   (first, second, "green"),
                   (first + second, p**2, 'orange')]
    oneChannel = [(0., (1 - p), 'blue'),
                  ((1 - p), p, "green")]

    axes = plt.gca()
    for y, segs in ((.2, twoChannels), (.6, oneChannel)):
        for x0, w, colour in segs:
            axes.add_patch(
                matplotlib.patches.Rectangle((x0, y), w, .1, facecolor=colour))
    axes.yaxis.set_major_locator(plt.NullLocator())
    plt.xlabel("Probability")
    axes.text(.05, .725, "One Channel: p=" + str(p), fontsize=18)
    axes.text(.05, .325, "Two Channels: p=" + str(p), fontsize=18)
    matplotlib.rcParams.update({'font.size': 18})
def save_figure(date, folder):
    """Save the current pyplot figure as the next numbered PNG in ``folder/date``.

    Files are named ``<folder>/<date>/<NNNNN>.png``, where ``NNNNN`` is the
    module-level counter ``g_imagecount`` zero-padded to five digits. The
    counter is incremented after each save, and the directory is created if
    needed.

    Before saving, the current axes is switched off, subplot spacing and
    margins are removed, and the major tick locators are cleared. The image
    is saved with a tight bounding box and no padding.
    """
    global g_imagecount
    directory = "{}/{}".format(folder, date)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(directory, exist_ok=True)
    filename = "{}/{}.png".format(directory, str(g_imagecount).zfill(5))

    plt.gca().set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.savefig(filename, bbox_inches='tight', pad_inches=0)
    g_imagecount += 1
# Example No. 10 (0)
def xgetps10arcmin(ra, dec, size1, otname):
    """Fetch a PS1 colour cutout around (ra, dec) and render an annotated PNG.

    The cutout from ``getcolorim`` (filters "grz") is first saved as
    ``<otname>_ps1_0.jpg``. If that file exists, it is drawn on a 4x4-inch,
    50-dpi figure with no ticks or padding, and the following are added:

    * a hollow white circle at pixel (1200, 1200) (presumably the image
      centre -- depends on ``size1``);
    * the object name, the RA/Dec strings from ``ra2hour``/``dec2hour``,
      and a "10*10 arcmin" caption, all in white.

    The figure is saved as ``<otname>_ps1.png`` and closed.

    Returns:
        The PNG filename, or ``None`` (after printing a message) when the
        JPEG cutout is not found.
    """
    ps1img = "%s_ps1_0.jpg" % (otname)
    cim = getcolorim(ra, dec, size=size1, filters="grz")
    cim.save(ps1img)

    if not os.access(ps1img, os.F_OK):
        print("no ps1 image ")
        return None

    plt.figure(figsize=(4, 4), dpi=50)
    img_arr = plt.imread(ps1img)
    plt.imshow(img_arr)
    # No ticks and no padding, so the image fills the figure.
    plt.xticks([])
    plt.yticks([])
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
    plt.margins(0, 0)

    # Hollow marker: c='' is rejected by current matplotlib; an unfilled
    # face is requested with facecolors='none'.
    plt.scatter(1200, 1200, marker="o", facecolors='none', edgecolors='w',
                s=1000)

    radec = "RA=%s, DEC=%s" % (ra2hour(ra), dec2hour(dec))
    plt.text(800, 200, otname, color="w")
    plt.text(400, 300, radec, color="w")
    plt.text(1600, 2200, "10*10 arcmin", color="w")

    pngfilename = "%s_ps1.png" % (otname)
    plt.savefig(pngfilename, dpi=50)
    # Release the figure; repeated calls would otherwise accumulate figures.
    plt.close()
    return pngfilename
# Example No. 11 (0)
def createAxes(left, top, width, height, options, data):
    """Create the main, "crazy" and blow-up axes for one panel.

    ``left``, ``top``, ``width`` and ``height`` are fractions of the plotting
    area described by ``options.axLeft/axBottom/axWidth/axHeight``; they are
    converted to figure fractions here. The panel is then split vertically
    (fractions of the panel height, measured from its bottom):

    * any mode other than ``'contigPaths'``: main 0.00-0.65, crazy
      0.68-0.72, blow-up 0.75-1.00;
    * ``options.mode == 'contigPaths'``: main 0.00-0.72, no crazy axis,
      blow-up 0.75-1.00.

    Each axis gets its major tick locators removed. Its frame is hidden
    unless ``options.frames`` is truthy.

    Returns:
        ``(axMain, axCrazy, axBlowUp)``; ``axCrazy`` is ``None`` in
        ``'contigPaths'`` mode.
    """
    # Transform panel coordinates into figure coordinates.
    figLeft = options.axLeft + left * options.axWidth
    figTop = options.axBottom + options.axHeight * top
    figWidth = width * options.axWidth
    figHeight = height * options.axHeight
    figBottom = figTop - figHeight

    def addBlankAxis(bottom, heightFraction):
        # Add a tick-less axis spanning the panel width; frame per options.
        ax = data.fig.add_axes(
            [figLeft, bottom, figWidth, figHeight * heightFraction])
        ax.yaxis.set_major_locator(pylab.NullLocator())
        ax.xaxis.set_major_locator(pylab.NullLocator())
        if not options.frames:
            # Act on this axis directly: plt.box() would hit pyplot's current
            # axes, which is this one only when data.fig is the current figure.
            ax.set_frame_on(False)
        return ax

    if options.mode != 'contigPaths':
        axMain = addBlankAxis(figBottom, 0.65)
        axCrazy = addBlankAxis(figBottom + figHeight * 0.68, 0.04)
    else:
        axMain = addBlankAxis(figBottom, 0.72)
        axCrazy = None
    axBlowUp = addBlankAxis(figBottom + figHeight * 0.75, 0.25)
    return (axMain, axCrazy, axBlowUp)
# Example No. 12 (0)
def hinton(matrix, max_weight=None, ax=None):
    """Draw Hinton diagram for visualizing a weight matrix.

    Each entry ``w`` at index ``(i, j)`` becomes a square centred on
    ``(i, j)``; note the row index goes on x and the column index on y. The
    square has side ``sqrt(|w| / max_weight)`` and is white for positive
    entries and black otherwise, on a gray background. The y axis is
    inverted and major ticks are removed.

    Args:
        matrix: 2-D array-like of weights.
        max_weight: weight that maps to a full-size square. When falsy
            (``None`` or 0) it defaults to the smallest power of two that
            is >= max |w|, or 1 for an all-zero matrix.
        ax: axes to draw on; defaults to ``plt.gca()``.
    """
    ax = ax if ax is not None else plt.gca()

    if not max_weight:
        largest = np.abs(matrix).max()
        # log2 is exact for powers of two; log(x)/log(2) can round slightly
        # up and make ceil() jump to the next power. An all-zero matrix
        # would give 2**-inf == 0 and divide by zero below.
        max_weight = 2 ** np.ceil(np.log2(largest)) if largest > 0 else 1.0

    ax.patch.set_facecolor('gray')
    ax.set_aspect('equal', 'box')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in np.ndenumerate(matrix):
        color = 'white' if w > 0 else 'black'
        size = np.sqrt(np.abs(w) / max_weight)
        rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
                             facecolor=color, edgecolor=color)
        ax.add_patch(rect)

    ax.autoscale_view()
    ax.invert_yaxis()
# Example No. 13 (0)
def setAxisLimits(axMain, axCrazy, axBlowUp, xData, options, data):
    """Apply log-x limits and tick placement to a panel's three axes.

    Every configured axis uses a log x scale from 1 to ``xData[-1]``:

    * ``axMain`` spans y in [0, 1];
    * ``axCrazy`` (only in modes 'blocks', 'contigs', 'scaffolds') spans y
      in [0, 1.02] with x tick labels hidden;
    * ``axBlowUp`` zooms to y in [0.9, 1] with x tick labels hidden.

    Ticks are kept only on the bottom and left of every axis that is not
    ``None``. When ``options.SMM`` is set, these locators are cleared:

    * the x locators of ``axMain``;
    * the minor x locator of ``axBlowUp``;
    * the major x/y and minor x locators of ``axCrazy``, outside
      'contigPaths' mode.
    """
    axMain.set_xscale('log')
    axMain.set_xlim(1, xData[-1])
    axMain.set_ylim(0.0, 1.0)

    if options.mode in ('blocks', 'contigs', 'scaffolds'):
        axCrazy.set_ylim(0.0, 1.02)
        axCrazy.set_xscale('log')
        axCrazy.set_xlim(1, xData[-1])
        axCrazy.xaxis.set_ticklabels([])

    axBlowUp.set_xscale('log')
    axBlowUp.set_xlim(1, xData[-1])
    axBlowUp.set_ylim(0.9, 1.0)
    axBlowUp.xaxis.set_ticklabels([])

    # Draw ticks only on the bottom and left spines.
    for panel in (axMain, axCrazy, axBlowUp):
        if panel is None:
            continue
        panel.xaxis.set_ticks_position('bottom')
        panel.yaxis.set_ticks_position('left')

    if not options.SMM:
        return
    if options.mode != 'contigPaths':
        axCrazy.yaxis.set_major_locator(pylab.NullLocator())
        axCrazy.xaxis.set_major_locator(pylab.NullLocator())
        axCrazy.xaxis.set_minor_locator(pylab.NullLocator())
    axMain.xaxis.set_major_locator(pylab.NullLocator())
    axMain.xaxis.set_minor_locator(pylab.NullLocator())
    axBlowUp.xaxis.set_minor_locator(pylab.NullLocator())
# Example No. 14 (0)
def plot_simil_mat_with_labels(simil_mat, y, inner_class_ordering='mean_shift_clusters', brightness=1.0,
                               figsize=(10, 10)):
    """Plot a similarity matrix with samples grouped by label.

    Rows and columns are reordered so that samples sharing a label in ``y``
    are contiguous, with labels in descending order. Within each label group
    the order is given by ``inner_class_ordering``:

    * ``'sum_simil'``: descending row sum of the (brightness-adjusted)
      similarities;
    * ``'mean_shift_clusters'``: descending MeanShift cluster label, fitted
      on the group's own sub-matrix.

    Each group is outlined with a blue square. Its label is used as the tick
    label at the group's centre on both axes.

    Args:
        simil_mat: square numpy array of pairwise similarities.
        y: per-sample labels, same length as ``simil_mat``.
        inner_class_ordering: ``'sum_simil'`` or ``'mean_shift_clusters'``.
        brightness: similarities are raised to ``1 / brightness`` first.
        figsize: size of the new figure.

    Returns:
        numpy array of the labels in plotted order.

    Raises:
        ValueError: for an unknown ``inner_class_ordering``.
    """
    simil_mat = simil_mat ** (1 / float(brightness))
    # An ndarray makes ``y == y_val`` an element-wise mask; a plain list
    # would compare as a single bool.
    y = np.asarray(y)
    d = pd.DataFrame(simil_mat)

    if inner_class_ordering == 'sum_simil':
        # Summed before the label column is added, so non-numeric labels
        # cannot break the sum.
        order = d.sum(axis=1).values
    elif inner_class_ordering == 'mean_shift_clusters':
        order = np.full(len(d), np.nan)
        for y_val in np.unique(y):
            lidx = y == y_val
            clus = MeanShift().fit(simil_mat[lidx][:, lidx])
            # Write into a plain array: chained ``d['order'].iloc[...] =``
            # may modify a copy and be silently lost.
            order[lidx] = clus.labels_
    else:
        raise ValueError("Unknown inner_class_ordering")

    d['y'] = y
    d['order'] = order
    # DataFrame.sort was removed from pandas; sort_values is its replacement.
    d = d.sort_values(['y', 'order'], ascending=False)
    y_vals = d['y']

    permi = d.index.values
    w = simil_mat[permi][:, permi]

    plt.figure(figsize=figsize)
    ax = plt.gca()
    ax.matshow(w, cmap='gray_r')
    # Positional form: the ``b=`` keyword was renamed ``visible=`` in Matplotlib.
    ax.grid(False)
    ax.set_aspect('equal', 'box')
    mids = []
    unik_y_vals = np.unique(y_vals)
    for y_val in unik_y_vals:
        idx = np.where(y_vals == y_val)[0]
        pt = idx[0] - 0.5
        s = idx[-1] - idx[0] + 1
        mids.append(pt + s / 2)
        ax.add_patch(
            patches.Rectangle(xy=(pt, pt), width=s, height=s, fill=False, linewidth=2, color='blue', alpha=0.5))
    # One tick per group centre on both axes (replaces any default locator).
    ax.set_yticks(mids)
    ax.set_yticklabels(unik_y_vals)
    ax.set_xticks(mids)
    ax.set_xticklabels(unik_y_vals, rotation=90)

    # Series.as_matrix was removed from pandas; .values is equivalent.
    return y_vals.values
# Example No. 15 (0)
def establishAxes(fig, options, data):
    """Create three stacked facet axes named 'def', 'exc' and 'sum' on ``fig``.

    The plotting area is recorded on ``options`` in figure fractions: left
    0.09, right 0.97, bottom 0.08, top 0.96. It is split vertically into
    three equal facets separated by a 0.11 margin, with 'def' at the bottom
    and 'sum' at the top.

    Each facet is styled as follows:

    * major tick locators are removed;
    * left/bottom spines are offset outward by 10 points, right/top spines
      are hidden;
    * x ticks go at the bottom; y ticks go on the left, or on both sides
      when ``options.log`` is truthy.

    The dict is also stored on ``data.axDict``.

    Returns:
        dict mapping 'def', 'exc' and 'sum' to their axes.

    Raises:
        ValueError: if an axis reports a spine other than
            left/right/top/bottom.
    """
    axDict = {}
    options.axLeft = 0.09
    options.axRight = 0.97
    options.axWidth = options.axRight - options.axLeft
    options.axBottom = 0.08
    options.axTop = 0.96
    options.axHeight = options.axTop - options.axBottom
    margin = 0.11
    facetHeight = (options.axHeight - 2.0 * margin) / 3.0
    yPos = 0.0
    for ax in ['def', 'exc', 'sum']:
        axDict[ax] = fig.add_axes([
            options.axLeft, options.axBottom + yPos, options.axWidth,
            facetHeight
        ])
        axDict[ax].yaxis.set_major_locator(pylab.NullLocator())
        axDict[ax].xaxis.set_major_locator(pylab.NullLocator())
        yPos += facetHeight + margin
    for ax in axDict:
        # ``spines`` is a dict-like mapping; ``iteritems`` only existed in
        # Python 2.
        for loc, spine in axDict[ax].spines.items():
            if loc in ['left', 'bottom']:
                spine.set_position(('outward', 10))  # outward by 10 points
            elif loc in ['right', 'top']:
                spine.set_color('none')  # don't draw spine
            else:
                raise ValueError('unknown spine location: %s' % loc)
        # turn off ticks where there is no spine
        axDict[ax].xaxis.set_ticks_position('bottom')
        if options.log:
            axDict[ax].yaxis.set_ticks_position('both')
        else:
            axDict[ax].yaxis.set_ticks_position('left')
    data.axDict = axDict
    return axDict
# Example No. 16 (0)
def plot_true_alt(p, q):
    """Draw two-segment bars for a true and an alternate model.

    Each model is a horizontal bar 0.1 high, split into a blue segment of
    width 1-r followed by an orange segment of width r. ``r`` is ``p`` for
    the true model (bar at y=0.6) and ``q`` for the alternate model (bar at
    y=0.2).

    The patches are added to the current axes and the y major ticks are
    removed. An x label and two captions are added, and the global
    matplotlib font size is set to 18.
    """
    axes = plt.gca()
    for r, y in ((p, .6), (q, .2)):
        axes.add_patch(
            matplotlib.patches.Rectangle((0., y), 1 - r, .1, facecolor='blue'))
        axes.add_patch(
            matplotlib.patches.Rectangle((1 - r, y), r, .1, facecolor="orange"))
    axes.yaxis.set_major_locator(plt.NullLocator())
    plt.xlabel("Probability")
    axes.text(.05, .725, "True Model: p=" + str(p), fontsize=18)
    axes.text(.05, .325, "Alternate Model: p=" + str(q), fontsize=18)
    matplotlib.rcParams.update({'font.size': 18})
# Example No. 17 (0)
def plot_heatmap_shap(attributions,  list_images, img_input, blend_original_image):
    """Render one SHAP heat map per ranked output, optionally blended with the input.

    For output ``k``, ``attributions[0][k][0]`` is averaged over its last
    axis. It is drawn with the 'seismic' colormap on the symmetric range
    [-max|a|, +max|a|], without axes or padding, and saved to
    ``list_images[k]``.

    When ``blend_original_image`` is true, these extra files are written:

    * ``deepshap_original.jpg`` (next to ``list_images[0]``), written once:
      ``img_input[0]`` de-normalised with ``input_norm_reverse`` and resized
      to 384x384.
    * ``deepshap_<k>.jpg``: the saved heat map resized to 384x384.
    * ``deepshap_blend_<k>.jpg``: the original and the heat map blended
      0.65/0.35.
    * ``deepshap_<k>.gif``: the three images above as frames, at ``GIF_FPS``
      frames per second.

    The per-output files go next to ``list_images[k]``, and
    ``list_images[k]`` is replaced by the GIF path.

    ``list_images`` is modified in place; nothing is returned.
    """
    n_outputs = len(attributions[0])

    if blend_original_image:
        from LIBS.ImgPreprocess.my_image_norm import input_norm_reverse
        img_original = np.uint8(input_norm_reverse(img_input[0]))
        import cv2
        img_original = cv2.resize(img_original, (384, 384))
        img_original_file = os.path.join(os.path.dirname(list_images[0]), 'deepshap_original.jpg')
        cv2.imwrite(img_original_file, img_original)

    for k in range(n_outputs):
        # First sample of output k, averaged over the trailing axis
        # (channels, per the original shape note (1, 299, 299, 3)).
        heat = np.mean(attributions[0][k][0], -1)
        # The 100th percentile is the maximum |value|; used as a symmetric limit.
        limit = np.percentile(np.abs(heat), 100)

        plt.axis('off')
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.imshow(heat, interpolation='none', cmap='seismic', vmin=-limit, vmax=limit)
        plt.savefig(list_images[k], bbox_inches='tight', pad_inches=0)
        plt.close()

        if not blend_original_image:
            continue

        out_dir = os.path.dirname(list_images[k])
        img_heatmap = cv2.resize(cv2.imread(list_images[k]), (384, 384))
        img_heatmap_file = os.path.join(out_dir, 'deepshap_{0}.jpg'.format(k))
        cv2.imwrite(img_heatmap_file, img_heatmap)

        blended = cv2.addWeighted(img_original, 0.65, img_heatmap, 0.35, 0)
        img_blend_file = os.path.join(out_dir, 'deepshap_blend_{0}.jpg'.format(k))
        cv2.imwrite(img_blend_file, blended)

        # GIF cycling original -> heat map -> blend.
        import imageio
        frames = [imageio.imread(path)
                  for path in (img_original_file, img_heatmap_file, img_blend_file)]
        img_file_gif = os.path.join(out_dir, 'deepshap_{0}.gif'.format(k))
        imageio.mimsave(img_file_gif, frames, fps=GIF_FPS)
        list_images[k] = img_file_gif
# Example No. 18 (0)
# NOTE(review): this snippet starts mid-script. ``vmina``, ``cell1``,
# ``cell2``, ``cell3`` and the ``np``/``plt``/``unwrap`` imports are defined
# earlier and are not visible here.
# Amplitude panels share one gray colour scale; its upper bound is max |cell1|.
vmaxa = np.max(np.abs(cell1))
ampkw = {"cmap": plt.get_cmap("gray"), "vmin": vmina, "vmax": vmaxa}

# phase range
# Phases of the three (presumably complex) fields, unwrapped with
# ``unwrap.unwrap``. All phase panels share the coolwarm range of cell1's phase.
cell1p = unwrap.unwrap(np.angle(cell1))
cell2p = unwrap.unwrap(np.angle(cell2))
cell3p = unwrap.unwrap(np.angle(cell3))
vminp = np.min(cell1p)
vmaxp = np.max(cell1p)
phakw = {"cmap": plt.get_cmap("coolwarm"), "vmin": vminp, "vmax": vmaxp}

# plots
# 2x3 grid flattened to axes[0..5]: top row amplitudes, bottom row phases.
fig, axes = plt.subplots(2, 3, figsize=(8, 4.5))
axes = axes.flatten()
for ax in axes:
    # No ticks on any panel.
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

# titles
axes[0].set_title("focused backward")
axes[1].set_title("original image")
axes[2].set_title("focused forward")

# data
# Columns: cell3 (focused backward), cell1 (original), cell2 (focused forward).
# ``mapamp``/``mappha`` keep one image handle per row, presumably for
# colorbars added after this snippet.
mapamp = axes[0].imshow(np.abs(cell3), **ampkw)
axes[1].imshow(np.abs(cell1), **ampkw)
axes[2].imshow(np.abs(cell2), **ampkw)
mappha = axes[3].imshow(cell3p, **phakw)
axes[4].imshow(cell1p, **phakw)
axes[5].imshow(cell2p, **phakw)
# Example No. 19 (0)
def AXIS(axis,
         xlabel=None,
         ylabel=None,
         remove_xticks=False,
         remove_yticks=False,
         remove_ticks_all=False,
         tickscolor="k"):
    """Apply a uniform style to a matplotlib axes.

    The following are applied:

    * 10-pt tick labels, with the y labels rotated vertically;
    * all four spines and the tick marks coloured ``tickscolor``;
    * minor ticks turned on, with inward major (6.5 long) and minor
      (3.5 long) ticks drawn on all four sides;
    * major tick labels in black;
    * at most about 3 major tick bins per axis;
    * a light blue-gray background.

    Args:
        axis: the matplotlib axes to style.
        xlabel: optional x-axis label text.
        ylabel: optional y-axis label text.
        remove_xticks: hide the x major tick labels (tick marks stay).
        remove_yticks: hide the y major tick labels (tick marks stay).
        remove_ticks_all: remove the major ticks and their labels on both
            axes.
        tickscolor: colour for the spines and tick marks.
    """
    plt.setp(axis.get_yticklabels(), rotation='vertical', fontsize=10)
    plt.setp(axis.get_xticklabels(), fontsize=10)

    for side in ('bottom', 'top', 'left', 'right'):
        axis.spines[side].set_color(tickscolor)

    axis.minorticks_on()
    # Inward ticks on all four sides; major ticks longer than minor ones.
    for which, length in (('major', 6.5), ('minor', 3.5)):
        axis.tick_params('both',
                         length=length,
                         width=0.7,
                         which=which,
                         direction='in',
                         color=tickscolor,
                         bottom=True,
                         top=True,
                         left=True,
                         right=True)

    # ``colors`` sets both tick and label colour to black; the following call
    # restores the tick-mark colour, leaving the major labels black.
    axis.tick_params(axis='x', colors='k', pad=1)
    axis.tick_params(axis='y', colors='k', pad=1)
    axis.tick_params(axis='both', direction='in', color=tickscolor)

    # Assumes MaxNLocator-based major locators (the default AutoLocator is one).
    axis.xaxis.major.locator.set_params(nbins=3)
    axis.yaxis.major.locator.set_params(nbins=3)

    # Each label goes to its own axis.
    if xlabel is not None:
        axis.set_xlabel(str(xlabel), fontsize=10)
    if ylabel is not None:
        axis.set_ylabel(str(ylabel), fontsize=10)

    # Remove x, y tick labels (tick marks remain).
    if remove_xticks:
        axis.xaxis.set_major_formatter(plt.NullFormatter())
    if remove_yticks:
        axis.yaxis.set_major_formatter(plt.NullFormatter())

    # Remove both the major ticks and their labels.
    if remove_ticks_all:
        axis.yaxis.set_major_locator(plt.NullLocator())
        axis.xaxis.set_major_locator(plt.NullLocator())

    axis.set_facecolor('#e8ebf2')
import argparse
import os

import numpy as np
import pandas as pd

# Render the traces stored in temp.npy as a borderless image saved to
# img/<num>_<label>.jpg: red/blue filled bands on top, a green bar chart below.
parser = argparse.ArgumentParser()
parser.add_argument('--label', type=str)
parser.add_argument('--num', type=str)

args = parser.parse_args()

# temp.npy holds four sequences (a, b, c, d), presumably as an object array
# (hence allow_pickle). NOTE(review): only load trusted files -- unpickling
# can execute code.
temp = np.load('temp.npy', allow_pickle=True)
a, b, c, d = temp
# np.float was only an alias of the builtin float and was removed in NumPy 1.24.
a, b, c = np.array(a, dtype=float), np.array(b, dtype=float), np.array(
    c, dtype=float)

plt.subplot(2, 1, 1)
# x positions follow the data length (previously hard-coded to 400 samples).
plt.fill_between(np.arange(len(a)), a, b, color='r')
plt.fill_between(np.arange(len(c)), c, b, color='b')
plt.axis('off')

plt.subplot(2, 1, 2)
plt.bar(np.arange(len(d)), d, color='g')
plt.axis('off')

# Strip ticks and all figure padding so only the plotted content is saved.
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
plt.margins(0, 0)
plt.axis('off')

# Portable path: 'img\{}' relied on an invalid escape sequence and produced a
# literal backslash in the file name on non-Windows systems.
plt.savefig(os.path.join('img', '{}_{}.jpg'.format(args.num, args.label)))
# Example No. 21 (0)
    def shap_deep_explainer(self,
                            model_no,
                            num_reference,
                            img_input,
                            ranked_outputs=1,
                            norm_reverse=True,
                            blend_original_image=False,
                            gif_fps=1,
                            base_dir_save='/tmp/DeepExplain',
                            check_additivity=False):
        """Explain ``img_input`` with SHAP DeepExplainer and save heat-map images.

        The background set is processed in ``ceil(num_reference / batch_size)``
        chunks to limit GPU memory. There is one pre-built explainer per chunk
        in ``self.list_e[model_no]``, and the per-chunk SHAP values of every
        ranked output are averaged.

        For each ranked output a heat map of the first sample's attributions
        is saved as ``<base_dir_save>/<uuid1>/Shap_Deep_Explainer<class>.jpg``.

        With ``blend_original_image``, more files are written to that
        directory and combined into ``deepshap_<k>.gif``:

        * the (optionally de-normalised) input;
        * the heat map resized to the input size;
        * a 0.65/0.35 blend of the two.

        The GIF path then replaces the JPEG path in the returned list.

        Args:
            model_no: key into ``self.dicts_models`` and ``self.list_e``.
            num_reference: number of background samples. Together with the
                model's ``batch_size`` it sets how many chunk explainers are
                used; must be positive.
            img_input: model input batch; only the first sample is drawn.
            ranked_outputs: number of top-ranked classes to explain.
            norm_reverse: de-normalise the input with ``input_norm_reverse``
                before saving it (blend mode only).
            blend_original_image: also write original/blend images and a GIF.
            gif_fps: frame rate of the GIF.
            base_dir_save: root directory for the output files.
            check_additivity: forwarded to ``shap_values``.

        Returns:
            ``(list_classes, list_images)``: the predicted class index and
            the saved image (or GIF) path for each ranked output.

        Raises:
            ValueError: if no chunk would be processed (``num_reference``
                not positive).
        """
        # region mini-batch because of GPU memory limitation
        batch_size = self.dicts_models[model_no]['batch_size']
        split_times = math.ceil(num_reference / batch_size)
        if split_times < 1:
            raise ValueError('num_reference must be positive')

        list_shap_values = []
        for i in range(split_times):
            # With ranked_outputs, shap_values returns
            # [0]: one array per ranked output, e.g. (1, 299, 299, 3)
            # [1]: predicted class indices
            shap_values_tmp = self.list_e[model_no][i].shap_values(
                img_input,
                ranked_outputs=ranked_outputs,
                check_additivity=check_additivity)
            list_shap_values.append(copy.deepcopy(shap_values_tmp))

        # Copy the result container once and average each output over the
        # chunks. sum() builds a new array, so neither the per-chunk values
        # nor already-averaged outputs are modified in place.
        # NOTE(review): chunks are weighted equally even if the last one holds
        # fewer references -- confirm this is intended.
        shap_values_results = copy.deepcopy(list_shap_values[0])
        for i in range(ranked_outputs):
            total = sum(batch[0][i] for batch in list_shap_values)
            shap_values_results[0][i] = total / split_times
        # endregion

        # region save files
        str_uuid = str(uuid.uuid1())
        list_classes = []
        list_images = []
        for i in range(ranked_outputs):
            predict_class = int(
                shap_values_results[1][0][i])  # numpy int64 -> int
            list_classes.append(predict_class)

            save_filename = os.path.join(
                base_dir_save, str_uuid,
                'Shap_Deep_Explainer{}.jpg'.format(predict_class))
            os.makedirs(os.path.dirname(save_filename), exist_ok=True)
            list_images.append(save_filename)

        pred_class_num = len(shap_values_results[0])

        if blend_original_image:
            if norm_reverse:
                img_original = np.uint8(input_norm_reverse(img_input[0]))
            else:
                img_original = np.uint8(img_input[0])
            img_original_file = os.path.join(os.path.dirname(list_images[0]),
                                             'deepshap_original.jpg')
            cv2.imwrite(img_original_file, img_original)

        for i in range(pred_class_num):
            # First sample of output i, averaged over the trailing axis.
            data = np.mean(shap_values_results[0][i][0], -1)
            # Symmetric limits keep zero attribution at the centre of the
            # diverging 'seismic' colormap.
            abs_max = np.max(np.abs(data))

            plt.axis('off')
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.subplots_adjust(top=1,
                                bottom=0,
                                right=1,
                                left=0,
                                hspace=0,
                                wspace=0)
            plt.margins(0, 0)
            plt.imshow(data,
                       interpolation='none',
                       cmap='seismic',
                       vmin=-abs_max,
                       vmax=abs_max)
            plt.savefig(list_images[i], bbox_inches='tight', pad_inches=0)
            plt.close()

            if not blend_original_image:
                continue

            out_dir = os.path.dirname(list_images[i])
            img_heatmap = cv2.imread(list_images[i])
            # Match the input size; cv2.resize takes (width, height).
            # Assumes img_original is (H, W) or (H, W, C).
            tmp_height, tmp_width = img_original.shape[:2]
            img_heatmap = cv2.resize(img_heatmap, (tmp_width, tmp_height))
            img_heatmap_file = os.path.join(out_dir,
                                            'deepshap_{0}.jpg'.format(i))
            cv2.imwrite(img_heatmap_file, img_heatmap)

            dst = cv2.addWeighted(img_original, 0.65, img_heatmap, 0.35, 0)
            img_blend_file = os.path.join(out_dir,
                                          'deepshap_blend_{0}.jpg'.format(i))
            cv2.imwrite(img_blend_file, dst)

            # region create gif
            import imageio
            gif_images = [
                imageio.imread(path)
                for path in (img_original_file, img_heatmap_file,
                             img_blend_file)
            ]
            img_file_gif = os.path.join(out_dir, 'deepshap_{0}.gif'.format(i))
            imageio.mimsave(img_file_gif, gif_images, fps=gif_fps)
            list_images[i] = img_file_gif
            # endregion
        # endregion

        return list_classes, list_images
# Example No. 22 (0)
def flat_corr_matrix(df,
                     pdf=None,
                     tight=False,
                     labels=None,
                     label_size=None,
                     size=12,
                     n_labels=3,
                     fontsize='auto',
                     draw_cbar=False,
                     tick_label_rotation=45,
                     formatter='%.2e',
                     label_rotation=45,
                     cmap='PiYG',
                     hist_diagonal=False):
    """ Draws a flat correlation matrix of df

    Every ordered pair of columns (i, j) of ``df`` gets its own subplot in an
    n_vars x n_vars grid, drawn with ``flat_correlation`` (column j against
    column i). The bottom row and first column carry the variable labels.

    Args:
        df: pd.DataFrame whose columns are plotted against each other.
        pdf: optional object with a ``savefig(fig)`` method (e.g. a matplotlib
            ``PdfPages``). If given, the finished figure is saved to it and
            then closed.
        tight: if True, call ``tight_layout`` on the figure before labelling.
        labels: one label per column of ``df``; defaults to ``df.columns``.
        label_size: if not None, tick labels are drawn on the outer axes via
            ``set_flat_labels`` with this label size.
        size: figure width and height in inches.
        n_labels: number of tick labels passed to ``set_flat_labels``.
        fontsize: font size of the outer axis labels and colorbar. 'auto'
            linearly interpolates from 22 (0 variables) down to 10
            (10 or more variables).
        draw_cbar: if True, add a colorbar to the right of the grid.
        tick_label_rotation: rotation of the outer tick labels. On the bottom
            row a value of 0 is replaced by 90.
        formatter: tick label format string passed to ``set_flat_labels``.
        label_rotation: rotation of the outer axis labels.
        cmap: colormap passed to ``flat_correlation``.
        hist_diagonal: if True, the diagonal subplots show a gray histogram
            of column i instead of the column's self-correlation. Defaults to
            False, which matches the previous behavior.

    Returns:
        None.
    """

    assert isinstance(
        df, pd.DataFrame), 'Argument of wrong type! Needs pd.DataFrame'

    n_vars = np.shape(df)[1]

    # Compare with `==`: `is` on a string literal only works by interning accident.
    if isinstance(fontsize, str) and fontsize == 'auto':
        fontsize = np.interp(n_vars, (0, 10), (22, 10))

    if labels is None:
        labels = df.columns
    else:
        assert len(labels) == len(
            df.columns
        ), "Numbers of labels not matching the numbers of coulums in the df"

    # Last mappable drawn; used for the optional colorbar.
    im = None
    # squeeze=False keeps `axes` a 2D array even for n_vars == 1, where
    # plt.subplots would otherwise return a single (non-iterable) Axes.
    fig, axes = plt.subplots(nrows=n_vars,
                             ncols=n_vars,
                             figsize=(size, size),
                             squeeze=False)

    # Plotting the matrix, iterate over the columns in 2D
    for i, row in enumerate(axes):
        for j, ax in enumerate(row):
            if hist_diagonal and i == j:
                ax.hist(df.iloc[:, i].values, label='data', color='gray')
                ax.set_yticklabels([])
            else:
                im = flat_correlation(df.iloc[:, j],
                                      df.iloc[:, i],
                                      ax=ax,
                                      draw_labels=False,
                                      get_im=True,
                                      cmap=cmap)
            ax.xaxis.set_major_locator(plt.NullLocator())
            ax.yaxis.set_major_locator(plt.NullLocator())

    if tight:
        fig.tight_layout()

    # Common outer label: x labels on the bottom row, y labels on column 0.
    for i, row in enumerate(axes):
        for j, ax in enumerate(row):
            if i == n_vars - 1:
                if label_size is not None:
                    set_flat_labels(ax,
                                    df.iloc[:, j],
                                    axis=1,
                                    n_labels=n_labels,
                                    labelsize=label_size,
                                    rotation=90 if tick_label_rotation == 0
                                    else tick_label_rotation,
                                    formatter=formatter)

                ax.set_xlabel(labels[j],
                              fontsize=fontsize,
                              rotation=label_rotation,
                              ha='right',
                              va='top')
            if j == 0:
                if label_size is not None:
                    set_flat_labels(ax,
                                    df.iloc[:, i],
                                    axis=0,
                                    n_labels=n_labels,
                                    labelsize=label_size,
                                    rotation=tick_label_rotation,
                                    formatter=formatter)
                ax.set_ylabel(labels[i],
                              fontsize=fontsize,
                              rotation=label_rotation,
                              ha='right',
                              va='bottom')

    # The colorbar must be added before the figure is saved/closed; otherwise
    # it never reaches the pdf and is attached to a closed figure.
    # Skipped when nothing mappable was drawn (im is None).
    if draw_cbar and im is not None:
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        cbar = fig.colorbar(im, cax=cbar_ax)
        cbar.ax.set_ylabel(r'$\sigma$',
                           rotation=0,
                           fontsize=fontsize * 1.2,
                           va='center')
        cbar.ax.tick_params(labelsize=fontsize)

    if pdf is not None:
        pdf.savefig(fig)
        plt.close(fig)