Example #1
def load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False, second=3):
    """The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with
    6000 images per class. There are 50000 training images and 10000 test images.

    The dataset is divided into five training batches and one test batch, each with
    10000 images. The test batch contains exactly 1000 randomly-selected images from
    each class. The training batches contain the remaining images in random order,
    but some training batches may contain more images from one class than another.
    Between them, the training batches contain exactly 5000 images from each class.

    Parameters
    ----------
    shape : tuple
        The shape of the images: e.g. (-1, 3, 32, 32), (-1, 32, 32, 3), (-1, 32*32*3)
    plotable : True, False
        Whether to plot some image examples.
    second : int
        If ``plotable`` is True, ``second`` is the display time.

    Examples
    --------
    >>> X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=True)

    Note
    ------
    CIFAR-10 images can only be displayed without color distortion when cast to uint8.
    >>> X_train = np.asarray(X_train, dtype=np.uint8)
    >>> plt.ion()
    >>> fig = plt.figure(1232)
    >>> count = 1
    >>> for row in range(10):
    >>>     for col in range(10):
    >>>         a = fig.add_subplot(10, 10, count)
    >>>         plt.imshow(X_train[count-1], interpolation='nearest')
    >>>         plt.gca().xaxis.set_major_locator(plt.NullLocator())    # hide ticks
    >>>         plt.gca().yaxis.set_major_locator(plt.NullLocator())
    >>>         count = count + 1
    >>> plt.draw()
    >>> plt.pause(3)

    References
    ----------
    `CIFAR website <https://www.cs.toronto.edu/~kriz/cifar.html>`_

    `Code download link <https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz>`_

    `Code references <https://teratail.com/questions/28932>`_
    """
    import os
    import sys
    import pickle
    import numpy as np

    # We first define a download function, supporting both Python 2 and 3.
    filename = 'cifar-10-python.tar.gz'
    if sys.version_info[0] == 2:
        from urllib import urlretrieve
    else:
        from urllib.request import urlretrieve

    def download(filename, source='https://www.cs.toronto.edu/~kriz/'):
        print("Downloading %s" % filename)
        urlretrieve(source + filename, filename)

    # After downloading the cifar-10-python.tar.gz, we need to unzip it.
    import tarfile

    def un_tar(file_name):
        print("Extracting %s" % file_name)
        tar = tarfile.open(file_name)
        names = tar.getnames()
        # if os.path.isdir(file_name + "_files"):
        #     pass
        # else:
        #     os.mkdir(file_name + "_files")
        for name in names:
            tar.extract(name)  #, file_name.split('.')[0])
        tar.close()
        print("Extracted to %s" % names[0])

    if not os.path.exists('cifar-10-batches-py'):
        download(filename)
        un_tar(filename)

    def unpickle(file):
        fp = open(file, 'rb')
        if sys.version_info.major == 2:
            data = pickle.load(fp)
        elif sys.version_info.major == 3:
            data = pickle.load(fp, encoding='latin-1')
        fp.close()
        return data

    X_train = None
    y_train = []

    path = ''  # you can set a dir to the data here.

    for i in range(1, 6):
        data_dic = unpickle(path +
                            "cifar-10-batches-py/data_batch_{}".format(i))
        if i == 1:
            X_train = data_dic['data']
        else:
            X_train = np.vstack((X_train, data_dic['data']))
        y_train += data_dic['labels']

    test_data_dic = unpickle(path + "cifar-10-batches-py/test_batch")
    X_test = test_data_dic['data']
    y_test = np.array(test_data_dic['labels'])

    if shape == (-1, 3, 32, 32):
        X_test = X_test.reshape(shape)
        X_train = X_train.reshape(shape)
        # X_train = np.transpose(X_train, (0, 1, 3, 2))
    elif shape == (-1, 32, 32, 3):
        X_test = X_test.reshape(shape, order='F')
        X_train = X_train.reshape(shape, order='F')
        X_test = np.transpose(X_test, (0, 2, 1, 3))
        X_train = np.transpose(X_train, (0, 2, 1, 3))
    else:
        X_test = X_test.reshape(shape)
        X_train = X_train.reshape(shape)

    y_train = np.array(y_train)

    if plotable:
        print('\nCIFAR-10')
        import matplotlib.pyplot as plt
        fig = plt.figure(1)

        print('Shape of a training image: X_train[0]', X_train[0].shape)

        plt.ion()  # interactive mode
        count = 1
        for row in range(10):
            for col in range(10):
                a = fig.add_subplot(10, 10, count)
                if shape == (-1, 3, 32, 32):
                    # plt.imshow(X_train[count-1], interpolation='nearest')
                    plt.imshow(np.transpose(X_train[count - 1], (1, 2, 0)),
                               interpolation='nearest')
                    # plt.imshow(np.transpose(X_train[count-1], (2, 1, 0)), interpolation='nearest')
                elif shape == (-1, 32, 32, 3):
                    plt.imshow(X_train[count - 1], interpolation='nearest')
                    # plt.imshow(np.transpose(X_train[count-1], (1, 0, 2)), interpolation='nearest')
                else:
                    raise Exception(
                        "Do not support the given 'shape' to plot the image examples"
                    )
                plt.gca().xaxis.set_major_locator(
                    plt.NullLocator())  # hide ticks
                plt.gca().yaxis.set_major_locator(plt.NullLocator())
                count = count + 1
        plt.draw()  # interactive mode
        plt.pause(3)  # interactive mode

        print("X_train:", X_train.shape)
        print("y_train:", y_train.shape)
        print("X_test:", X_test.shape)
        print("y_test:", y_test.shape)

    X_train = np.asarray(X_train, dtype=np.float32)
    X_test = np.asarray(X_test, dtype=np.float32)
    y_train = np.asarray(y_train, dtype=np.int32)
    y_test = np.asarray(y_test, dtype=np.int32)

    return X_train, y_train, X_test, y_test
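A short usage sketch (illustrative only): the two supported layouts differ only in channel ordering, so loading once in NHWC and transposing gives NCHW.

import numpy as np

# Hypothetical usage; the dataset is downloaded on the first call.
X_nhwc, y_train, X_test, y_test = load_cifar10_dataset(shape=(-1, 32, 32, 3))
X_nchw = np.transpose(X_nhwc, (0, 3, 1, 2))   # NHWC -> NCHW
print(X_nhwc.shape, X_nchw.shape)             # (50000, 32, 32, 3) (50000, 3, 32, 32)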
Example #2
def labelit(img, mask_part, labelbase, rois):
    """
    人工给每部分重新打标签
    :param img: 图片
    :param mask_part: 标签
    :return: 重打过后的mask
    """
    n_part = len(mask_part)
    colors = random_colors(n_part)
    minrow = rois[0]
    mincol = rois[1]
    maxrow = rois[2]
    maxcol = rois[3]
    newmask = np.zeros(mask_part[0].shape)
    newmask = newmask.astype(np.int32)
    print("按从左到右顺序输入标签:")

    for i in range(n_part):
        mask = mask_part[i]
        temp = copy.copy(img)
        onepart = apply_mask(temp, mask, colors[i])
        onepart = onepart[minrow:maxrow, mincol:maxcol, ]

        # print(i,onepart.shape)
        ax = plt.subplot(1, n_part, i + 1)
        # print(ax)
        plt.imshow(onepart)
        ax.set_title('part')
        plt.xticks([]), plt.yticks([])
        plt.tight_layout()
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.margins(0, 0)
    plt.show()

    labels = str(input())
    if labels[0] == '.':
        return []
    # if labels[0]=='r' or labels[0]=='R':
    #     for i in range(n_part):
    #         mask = mask_part[i]
    #         temp = copy.copy(img)
    #         onepart = apply_mask(temp, mask, colors[i])
    #         onepart = onepart[minrow:maxrow, mincol:maxcol, ]
    #         cv2.imwrite(str(i)+'.png',onepart)
    #     for i in range(n_part):
    #         onepart=cv2.imread(str(i)+'.png')
    #         ax = plt.subplot(1, n_part, i + 1)
    #         plt.imshow(onepart)
    #         ax.set_title('part')
    #         plt.xticks([]), plt.yticks([])
    #         plt.tight_layout()
    #     plt.gca().xaxis.set_major_locator(plt.NullLocator())
    #     plt.gca().yaxis.set_major_locator(plt.NullLocator())
    #     plt.margins(0, 0)
    #     plt.show()
    for i in range(n_part):
        if int(labels[i]) == 0:
            newmask = np.where(mask_part[i] > 0, np.zeros(newmask.shape),
                               newmask)  # note: this line may be redundant
        else:
            newmask = np.where(
                mask_part[i] > 0,
                np.ones(newmask.shape, dtype=int) *
                (labelbase * 10 + int(labels[i])), newmask)

    # label one part at a time
    # for mask in mask_part:
    #     plt.figure(figsize=(10, 10))
    #     print("输入这部分的标签:")
    #     temp=copy.copy(img)
    #     onepart=apply_mask(temp,mask,colors[0])
    #     plt.subplot(1,1,1)
    #     plt.imshow(onepart)
    #     plt.title("输入这部分的标签")
    #     plt.xticks([]), plt.yticks([])
    #     # plt.show()
    #     plt.ion()
    #     plt.pause(2)
    #     plt.close()
    #     newlabel=int(input())
    #     newmask=np.where(mask>0,newlabel,newmask)
    return newmask
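The relabelling above reduces to one np.where per part: pixels inside a part's mask become labelbase * 10 + label, and label 0 leaves them cleared. A tiny standalone sketch with toy masks and made-up values:

import numpy as np

# Two 2x2 toy part masks with labelbase=3; typing the labels "12" would
# relabel the covered pixels to 31 and 32 respectively.
mask_part = [np.array([[1, 1], [0, 0]]), np.array([[0, 0], [1, 1]])]
labelbase, labels = 3, "12"
newmask = np.zeros((2, 2), dtype=int)
for i, mask in enumerate(mask_part):
    if int(labels[i]) != 0:
        newmask = np.where(mask > 0, labelbase * 10 + int(labels[i]), newmask)
print(newmask)  # [[31 31] [32 32]]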
def show_array_broadcasting(a, b, filename=None):
    """Visualize broadcasting of arrays"""

    c = a + b

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    data = a
    ax = axes[0]
    ax.patch.set_facecolor('black')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
    colors = ['#1199ff', '#ee3311', '#66ff22']
    for (m, n), w in np.ndenumerate(data):
        size = 0.97
        color = '#1199ff'
        rect = plt.Rectangle([n - size / 2, m - size / 2],
                             size, size,
                             facecolor=color,
                             edgecolor=color)
        ax.add_patch(rect)
        ax.text(m, n, "%d" % data[n, m], ha='center', va='center', fontsize=12)
    ax.text(2.8, 1, "+", ha='center', va='center', fontsize=22)
    ax.autoscale_view()
    ax.invert_yaxis()

    data = np.zeros_like(a) + b
    ax = axes[1]
    ax.patch.set_facecolor('black')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
    colors = ['#1199ff', '#ee3311', '#66ff22']
    for (m, n), w in np.ndenumerate(data):
        size = 0.97
        color = '#eeeeee'
        rect = plt.Rectangle([n - size / 2, m - size / 2],
                             size, size,
                             facecolor=color,
                             edgecolor=color)
        ax.add_patch(rect)
        if (np.argmax(b.T.shape) == 0 and m == 0) or (np.argmax(b.T.shape) == 1 and n == 0):
            color = '#1199ff'
            #size = 0.8
            rect = plt.Rectangle([n - size / 2, m - size / 2],
                                 size, size,
                                 facecolor=color,
                                 edgecolor=color)
            ax.add_patch(rect)
        ax.text(m, n, "%d" % data[n, m], ha='center', va='center', fontsize=12)
    ax.text(2.8, 1, "=", ha='center', va='center', fontsize=22)
    ax.autoscale_view()
    ax.invert_yaxis()

    data = c
    ax = axes[2]
    ax.patch.set_facecolor('black')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
    colors = ['#1199ff', '#ee3311', '#66ff22']
    for (m, n), w in np.ndenumerate(data):
        size = 0.97
        color = '#1199ff' if w > 0 else '#eeeeee'
        color = '#eeeeee'
        rect = plt.Rectangle([n - size / 2, m - size / 2],
                             size, size,
                             facecolor=color,
                             edgecolor=color)
        ax.add_patch(rect)
        color = '#1199ff'
        #size = 0.8
        rect = plt.Rectangle([n - size / 2, m - size / 2],
                             size, size,
                             facecolor=color,
                             edgecolor=color)
        ax.add_patch(rect)
        ax.text(m, n, "%d" % data[n, m], ha='center', va='center', fontsize=12)
    ax.autoscale_view()
    ax.invert_yaxis()

    # fig.tight_layout()

    if filename:
        fig.savefig(filename + ".png", dpi=200)
        fig.savefig(filename + ".svg")
        fig.savefig(filename + ".pdf")
Example #4
def plot_marginal_pdf(data_folder,
                      C_limits,
                      mirror_limits,
                      num_bin_joint,
                      params_names,
                      plot_folder,
                      smooth=''):

    N_params = len(C_limits)
    max_value, max_value2 = 0.0, 0.0
    data = dict()
    for i in range(N_params):
        for j in range(N_params):
            if i < j:
                data[str(i) + str(j)] = np.loadtxt(
                    os.path.join(data_folder, f'marginal_{smooth}{i}{j}'))
                max_value = max(max_value, np.max(data[str(i) + str(j)]))
            if i > j and smooth:
                data[str(i) + str(j)] = np.loadtxt(
                    os.path.join(data_folder,
                                 'conditional_smooth{}{}'.format(i, j)))
                norm = np.sum(data[str(i) + str(j)])
                # print('norm = ', norm, np.max(data[str(i) + str(j)]))
                data[str(i) + str(j)] /= norm
                max_value2 = max(max_value2, np.max(data[str(i) + str(j)]))
            # print(max_value, max_value2)

    # max_value = int(max_value)
    # cmap = plt.cm.jet  # define the colormap
    # cmaplist = [cmap(i) for i in range(cmap.N)]  # extract all colors from the .jet map
    # # cmaplist[0] = 'black'   # force the first color entry to be black
    # cmaplist[0] = 'white' # force the first color entry to be white
    # cmap = cmap.from_list('Custom cmap', cmaplist, max_value)

    ###################################################################################################
    cmap2 = plt.cm.inferno  # define the colormap
    cmap = plt.cm.BuPu  # define the colormap
    cmaplist = [cmap(i) for i in reversed(range(cmap.N))]  # extract all colors from the map
    cmaplist2 = [cmap2(i) for i in range(cmap2.N)]  # extract all colors from the map
    gamma1 = 1
    gamma2 = 1
    ###################################################################################################
    # cmap2 = plt.cm.Greys   # define the colormap
    # cmap = plt.cm.Greys  # define the colormap
    # cmaplist = [cmap(i) for i in (range(cmap.N))]  # extract all colors from the map
    # cmaplist2 = [cmap2(i) for i in (range(cmap2.N))]  # extract all colors from the map
    # gamma1 = 1
    # gamma2 = 1
    ###################################################################################################
    cmaplist[0] = 'black'  # force the first color entry to be black
    cmaplist2[0] = 'black'  # force the first color entry to be black
    # cmaplist[0] = 'white' # force the first color entry to be white
    # cmaplist2[0] = 'white'  # force the first color entry to be white
    cmap = colors.LinearSegmentedColormap.from_list('Custom cmap', cmaplist)
    cmap2 = colors.LinearSegmentedColormap.from_list('Custom cmap', cmaplist2)

    width, height = fig_size(double_column)
    # confidence = np.loadtxt(os.path.join(data_folder, 'confidence_75'))
    print(width, height)
    fig = plt.figure(figsize=(width, height))
    for i in range(N_params):
        for j in range(N_params):
            if i == j:
                data_marg = np.loadtxt(
                    os.path.join(data_folder, f'marginal_{smooth}{i}'))
                ax = plt.subplot2grid((N_params, N_params), (i, i))
                ax.plot(data_marg[0], data_marg[1])
                if smooth:
                    c_final_smooth = np.loadtxt(
                        os.path.join(data_folder,
                                     f'C_final_{smooth}{num_bin_joint}'))
                    # ax.axvline(confidence[i, 0], linestyle='--', color='b', label=r'$75\%$ interval')
                    # ax.axvline(confidence[i, 1], linestyle='--', color='b')
                    if len(c_final_smooth.shape) == 1:
                        ax.axvline(c_final_smooth[i],
                                   linestyle='--',
                                   color='r',
                                   label='max of joint pdf')
                    elif len(c_final_smooth) < 4:
                        for C in c_final_smooth:
                            ax.axvline(C[i],
                                       linestyle='--',
                                       color='r',
                                       label='joint max')
                ax.axis(xmin=C_limits[i, 0], xmax=C_limits[i, 1], ymin=0)

                # if i == 0:
                #     ax.xaxis.set_major_locator(ticker.MultipleLocator(0.5))
                #     ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.1))
                # if i == 1:
                #     ax.xaxis.set_major_locator(ticker.MultipleLocator(0.5))
                #     ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.1))
                # if i == 2:
                #     ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.1))
                # if i == 3:
                #     ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.1))

                ax.yaxis.set_major_formatter(plt.NullFormatter())
                ax.yaxis.set_major_locator(plt.NullLocator())
                if i != N_params - 1:
                    ax.xaxis.set_major_formatter(plt.NullFormatter())
                else:
                    ax.set_xlabel(params_names[i], labelpad=2)
                ax.tick_params(axis='both', which='minor', direction='in')
                ax.tick_params(axis='both', which='major', pad=0.8)

                # if i == 0:
                #     ax.legend(bbox_to_anchor=(3, -2.75), fancybox=True)
                #     textstr = '\n'.join((
                #         r'$C_1=%.3f$' % (c_final_smooth[0],),
                #         r'$C_2=%.3f$' % (c_final_smooth[1],),
                #         r'$C_{\epsilon1}=%.3f$' % (c_final_smooth[2],),
                #         r'$C_{\epsilon2}=%.3f$' % (c_final_smooth[3],)))
                #     ax.text(0.15, -1.6, textstr, transform=ax.transAxes, fontsize=12,
                #             verticalalignment='top', linespacing=1.5)
            elif i < j:
                ax = plt.subplot2grid((N_params, N_params), (i, j))
                ax.axis(xmin=C_limits[j, 0],
                        xmax=C_limits[j, 1],
                        ymin=C_limits[i, 0],
                        ymax=C_limits[i, 1])
                # ax.xaxis.set_minor_locator(ticker.MultipleLocator(0.1))
                # ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.1))
                # ax.xaxis.set_major_locator(ticker.MultipleLocator(0.02))
                # ax.yaxis.set_major_locator(ticker.MultipleLocator(0.02))
                if j == N_params - 1:
                    ax.yaxis.tick_right()
                    ax.yaxis.set_label_position("right")
                    ax.set_ylabel(params_names[i], labelpad=2)
                else:
                    ax.yaxis.set_major_formatter(plt.NullFormatter())

                ax.xaxis.set_major_formatter(plt.NullFormatter())
                ax.yaxis.set_tick_params(direction='in')

                ax.tick_params(axis='both', which='minor', direction='in')
                ax.tick_params(axis='both', which='major', pad=0.8)
                ext = (mirror_limits[j, 0], mirror_limits[j, 1],
                       mirror_limits[i, 0], mirror_limits[i, 1])

                im = ax.imshow(data[str(i) + str(j)],
                               origin='lower',
                               cmap=cmap,
                               aspect='auto',
                               extent=ext,
                               vmin=0,
                               vmax=max_value)
            elif i > j and smooth:
                ax = plt.subplot2grid((N_params, N_params), (i, j))
                ax.axis(xmin=C_limits[j, 0],
                        xmax=C_limits[j, 1],
                        ymin=C_limits[i, 0],
                        ymax=C_limits[i, 1])
                ext = (mirror_limits[j, 0], mirror_limits[j, 1],
                       mirror_limits[i, 0], mirror_limits[i, 1])
                ax.tick_params(axis='both', which='major', pad=2)
                # ax.xaxis.set_major_locator(ticker.MultipleLocator(ticks[j]))
                # ax.yaxis.set_major_locator(ticker.MultipleLocator(ticks[i]))
                # if True and j != 2:
                #     ax.text(0.02, 0.07, '(' + string.ascii_lowercase[i * N_params + j] + ')',
                #             transform=ax.transAxes, size=10, weight='black')
                # else:
                #     ax.text(0.02, 0.85, '('+string.ascii_lowercase[i*N_params+j]+')',
                #         transform=ax.transAxes, size=10, weight='black')
                if j == 0:
                    ax.set_ylabel(params_names[i])
                else:
                    ax.yaxis.set_major_formatter(plt.NullFormatter())
                if i != (N_params - 1):
                    ax.xaxis.set_major_formatter(plt.NullFormatter())
                else:
                    ax.set_xlabel(params_names[j], labelpad=2)

                im_cond = ax.imshow(data[str(i) + str(j)],
                                    origin='lower',
                                    cmap=cmap2,
                                    aspect='auto',
                                    extent=ext,
                                    norm=colors.PowerNorm(gamma=gamma2),
                                    vmax=max_value2)
                # ax.axvline(c_final_smooth[j], linestyle='--', color='r')
                # ax.axhline(c_final_smooth[i], linestyle='--', color='r')
    # cax = plt.axes([0.05, 0.1, 0.01, 0.26])
    # plt.colorbar(im, cax=cax)   #, ticks=np.arange(max_value+1))

    fig.subplots_adjust(left=0.20,
                        right=0.80,
                        wspace=0.1,
                        hspace=0.1,
                        bottom=0.1,
                        top=0.98)
    fig.savefig(os.path.join(plot_folder, f'marginal_{smooth}'))
    plt.close('all')
def format_func(x, loc):
    # Reconstructed head: the opening of this tick-formatter helper was truncated in the snippet.
    if x == 0:
        return '0'
    elif x == -1:
        return '-h'
    else:
        return '%ih' % x
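The kernel-comparison loop below also needs a 2x3 grid of axes, some 1-D source samples and the scikit-learn import; a minimal setup sketch with illustrative values:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KernelDensity

# Assumed setup: grid of axes, source samples, and evaluation grid for the KDE.
fig, ax = plt.subplots(2, 3, sharex=True, sharey=True, figsize=(10, 5))
fig.subplots_adjust(hspace=0.05, wspace=0.05)
x_src = np.random.RandomState(42).normal(0, 1, size=(100, 1))
x_plot = np.linspace(-3.5, 3.5, 1000)[:, np.newaxis]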


for i, kernel in enumerate(
    ['gaussian', 'tophat', 'epanechnikov', 'exponential', 'linear', 'cosine']):
    axi = ax.ravel()[i]
    log_dens = KernelDensity(kernel=kernel).fit(x_src).score_samples(x_plot)
    axi.fill(x_plot[:, 0], np.exp(log_dens), '-k', fc='#AAAAFF')
    axi.text(-2.6, 0.95, kernel)

    axi.xaxis.set_major_formatter(plt.FuncFormatter(format_func))
    axi.xaxis.set_major_locator(plt.MultipleLocator(1))
    axi.yaxis.set_major_locator(plt.NullLocator())

    axi.set_ylim(0, 1.05)
    axi.set_xlim(-2.9, 2.9)

ax[0, 1].set_title('Available Kernels')

plt.show()

# ---------------------------------------------------------------
# Plot a 1D density example
N = 100
np.random.seed(1)
x = np.concatenate((np.random.normal(0, 1, int(0.3 * N)),
                    np.random.normal(5, 1, int(0.7 * N))))[:, np.newaxis]
x_plot = np.linspace(-5, 10, 1000)[:, np.newaxis]
Example #6
def fig_lfp_scaling(fig,
                    params,
                    bottom=0.55,
                    top=0.95,
                    channels=[0, 3, 7, 11, 13],
                    T=[800., 1000.],
                    Df=None,
                    mlab=True,
                    NFFT=256,
                    noverlap=128,
                    window=plt.mlab.window_hanning,
                    letters='ABCD',
                    lag=20,
                    show_titles=True,
                    show_xlabels=True):

    fname_fullscale = os.path.join(params.savefolder, 'LFPsum.h5')
    fname_downscaled = os.path.join(params.savefolder, 'populations',
                                    'subsamples', 'LFPsum_10_0.h5')

    # ana_params.set_PLOS_2column_fig_style(ratio=0.5)

    gs = gridspec.GridSpec(len(channels), 8, bottom=bottom, top=top)

    # fig = plt.figure()
    # fig.subplots_adjust(left=0.075, right=0.95, bottom=0.075, wspace=0.8, hspace=0.1)

    scaling_factor = np.sqrt(10)

    ##################################
    ###  LFP traces                ###
    ##################################

    ax = fig.add_subplot(gs[:, :3])

    phlp.annotate_subplot(ax,
                          ncols=8 / 3.,
                          nrows=1,
                          letter=letters[0],
                          linear_offset=0.065)
    plot_signal_sum(
        ax,
        params,
        fname=os.path.join(params.savefolder, 'LFPsum.h5'),
        unit='mV',
        scaling_factor=1.,
        scalebar=True,
        vlimround=None,
        T=T,
        ylim=[-1600, 50],
        color='k',
        label=r'$\Phi$',
        rasterized=False,
        zorder=1,
    )
    plot_signal_sum(
        ax,
        params,
        fname=os.path.join(params.savefolder, 'populations', 'subsamples',
                           'LFPsum_10_0.h5'),
        unit='mV',
        scaling_factor=scaling_factor,
        scalebar=False,
        vlimround=None,
        T=T,
        ylim=[-1600, 50],
        color='gray' if analysis_params.bw else analysis_params.colorP,
        label=r'$\hat{\Phi}^{\prime}$',
        rasterized=False,
        lw=1,
        zorder=0)

    if show_titles:
        ax.set_title('LFP & low-density predictor')
    if show_xlabels:
        ax.set_xlabel('$t$ (ms)', labelpad=0.)
    else:
        ax.set_xlabel('')

    #################################
    ### Correlations              ###
    #################################

    ax = fig.add_subplot(gs[:, 3])
    phlp.annotate_subplot(ax,
                          ncols=8,
                          nrows=1,
                          letter=letters[1],
                          linear_offset=0.065)
    phlp.remove_axis_junk(ax)

    datas = []
    files = [
        os.path.join(params.savefolder, 'LFPsum.h5'),
        os.path.join(params.savefolder, 'populations', 'subsamples',
                     'LFPsum_10_0.h5')
    ]
    for fil in files:
        f = h5py.File(fil, 'r')
        datas.append(f['data'][:, 200:])
        f.close()

    zvec = np.r_[params.electrodeParams['z']]
    cc = np.zeros(len(zvec))
    for ch in np.arange(len(zvec)):
        x0 = datas[0][ch]
        x0 -= x0.mean()
        x1 = datas[1][ch]
        x1 -= x1.mean()
        cc[ch] = np.corrcoef(x0, x1)[1, 0]

    ax.barh(zvec, cc, height=80, align='center', color='0.5', linewidth=0.5)

    # superimpose the chance level, obtained by mixing one input vector N times
    # while keeping the other fixed. We show boxes drawn left to right where
    # these denote mean +/- two standard deviations.
    N = 1000
    method = 'randphase'  #or 'permute'
    chance = np.zeros((cc.size, N))
    for ch in np.arange(len(zvec)):
        x1 = datas[1][ch]
        x1 -= x1.mean()
        if method == 'randphase':
            x0 = datas[0][ch]
            x0 -= x0.mean()
            X00 = np.fft.fft(x0)
        for n in range(N):
            if method == 'permute':
                x0 = np.random.permutation(datas[0][ch])
            elif method == 'randphase':
                X0 = np.copy(X00)
                #random phase information such that spectra is preserved
                theta = np.random.uniform(0, 2 * np.pi, size=X0.size // 2)
                #half-sided real and imaginary component
                real = abs(X0[1:X0.size // 2 + 1]) * np.cos(theta)
                imag = abs(X0[1:X0.size // 2 + 1]) * np.sin(theta)

                #account for the antisymmetric phase values
                X0.imag[1:imag.size + 1] = imag
                X0.imag[imag.size + 1:] = -imag[::-1]
                X0.real[1:real.size + 1] = real
                X0.real[real.size + 1:] = real[::-1]
                x0 = np.fft.ifft(X0).real

            chance[ch, n] = np.corrcoef(x0, x1)[1, 0]

    # p-values, compute the fraction of chance correlations > cc at each channel
    p = []
    for i, x in enumerate(cc):
        p += [(chance[i, ] >= x).sum() / float(N)]

    print('p-values:', p)

    #compute the 99% percentile of the chance data
    right = np.percentile(chance, 99, axis=-1)

    ax.plot(right, zvec, ':', color='k', lw=1.)
    ax.set_ylim([-1550, 50])
    ax.set_yticklabels([])
    ax.set_yticks(zvec)
    ax.set_xlim([0., 1.])
    ax.set_xticks([0.0, 0.5, 1])
    ax.yaxis.tick_left()

    if show_titles:
        ax.set_title('corr.\ncoef.')
    if show_xlabels:
        ax.set_xlabel('$cc$ (-)', labelpad=0.)

    ##################################
    ###  Single channel PSDs       ###
    ##################################

    freqs, PSD_fullscale = calc_signal_power(params,
                                             fname=fname_fullscale,
                                             transient=200,
                                             Df=Df,
                                             mlab=mlab,
                                             NFFT=NFFT,
                                             noverlap=noverlap,
                                             window=window)
    freqs, PSD_downscaled = calc_signal_power(params,
                                              fname=fname_downscaled,
                                              transient=200,
                                              Df=Df,
                                              mlab=mlab,
                                              NFFT=NFFT,
                                              noverlap=noverlap,
                                              window=window)
    inds = freqs >= 1  # keep frequencies of 1 Hz and above

    for i, ch in enumerate(channels):

        ax = fig.add_subplot(gs[i, 4:6])
        if i == 0:
            phlp.annotate_subplot(ax,
                                  ncols=8 / 2.,
                                  nrows=len(channels),
                                  letter=letters[2],
                                  linear_offset=0.065)
        phlp.remove_axis_junk(ax)
        ax.loglog(
            freqs[inds],
            PSD_fullscale[ch][inds],
            color='k',
            label=r'$\gamma=1.0$',
            zorder=1,
        )
        ax.loglog(
            freqs[inds],
            PSD_downscaled[ch][inds] * scaling_factor**2,
            lw=1,
            color='gray' if analysis_params.bw else analysis_params.colorP,
            label=r'$\gamma=0.1, \zeta=\sqrt{10}$',
            zorder=0,
        )
        ax.loglog(freqs[inds],
                  PSD_downscaled[ch][inds] * scaling_factor**4,
                  lw=1,
                  color='0.75',
                  label=r'$\gamma=0.1, \zeta=10$',
                  zorder=0)
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')
        ax.text(0.8,
                0.9,
                'ch. %i' % (ch + 1),
                horizontalalignment='left',
                verticalalignment='center',
                fontsize=6,
                transform=ax.transAxes)
        ax.yaxis.set_minor_locator(plt.NullLocator())
        if i < len(channels) - 1:
            #ax.set_xticks([])
            ax.set_xticklabels([])
        ax.tick_params(axis='y', which='minor', bottom='off')
        ax.set_xlim([4E0, 4E2])
        ax.set_ylim([3E-8, 1E-4])
        if i == 0:
            ax.tick_params(axis='y', which='major', pad=0)
            ax.set_ylabel('(mV$^2$/Hz)', labelpad=0.)
            if show_titles:
                ax.set_title('power spectra')
        #ax.set_yticks([1E-9,1E-7,1E-5])
        if i > 0:
            ax.set_yticklabels([])

    if show_xlabels:
        ax.set_xlabel(r'$f$ (Hz)', labelpad=0.)

    ##################################
    ###  PSD ratios                ###
    ##################################

    ax = fig.add_subplot(gs[:, 6:8])
    phlp.annotate_subplot(ax,
                          ncols=8. / 2,
                          nrows=1,
                          letter=letters[3],
                          linear_offset=0.065)

    PSD_ratio = PSD_fullscale / (PSD_downscaled * scaling_factor**2)
    zvec = np.r_[params.electrodeParams['z']]
    zvec = np.r_[zvec, zvec[-1] + np.diff(zvec)[-1]]
    inds = freqs >= 1  # keep frequencies of 1 Hz and above

    im = ax.pcolormesh(freqs[inds],
                       zvec + 40,
                       PSD_ratio[:, inds],
                       rasterized=False,
                       cmap=plt.get_cmap('gray_r', 18)
                       if analysis_params.bw else plt.cm.get_cmap('Reds', 18),
                       vmin=1E0,
                       vmax=1.E1)
    ax.set_xlim([4E0, 4E2])
    ax.set_xscale('log')
    ax.set_yticks(zvec)
    yticklabels = ['ch. %i' % i for i in np.arange(len(zvec)) + 1]
    ax.set_yticklabels(yticklabels)
    plt.axis('tight')
    cb = phlp.colorbar(fig,
                       ax,
                       im,
                       width=0.05,
                       height=0.5,
                       hoffset=-0.05,
                       voffset=0.0)
    cb.set_label('(-)', labelpad=0.)
    phlp.remove_axis_junk(ax)

    if show_titles:
        ax.set_title('power ratio')
    if show_xlabels:
        ax.set_xlabel(r'$f$ (Hz)', labelpad=0.)

    return fig
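The chance level in the correlation panel above comes from phase-randomized surrogates: each surrogate keeps the amplitude spectrum of the fixed signal but draws new random phases. A compact standalone sketch of the same idea, using rfft/irfft instead of the manual half-spectrum bookkeeping in the code above (names and data are illustrative):

import numpy as np

def random_phase_surrogate(x, rng=None):
    """Surrogate with the same amplitude spectrum as x but uniformly random phases."""
    rng = np.random.default_rng() if rng is None else rng
    X = np.fft.rfft(x - x.mean())
    phases = rng.uniform(0, 2 * np.pi, size=X.size)
    phases[0] = 0.0  # keep the DC component real
    return np.fft.irfft(np.abs(X) * np.exp(1j * phases), n=x.size)

# Chance-level correlation between a fixed signal x1 and surrogates of x0.
x0, x1 = np.random.randn(1000), np.random.randn(1000)
chance = [np.corrcoef(random_phase_surrogate(x0), x1)[0, 1] for _ in range(100)]
print(np.percentile(chance, 99))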
Example #7
    def make_plot(self, groups, out_fn, quantitative_proteomic_data=False):
        '''
        plot the interaction matrix
        @param groups is the list of groups of domains, eg,
                      [["protA_1-10","prot1A_11-100"],["protB"]....]
                      it will plot a space between different groups
        @param out_fn name of the plot file
        @param quantitative_proteomic_data plot the quantitative proteomic data
        '''
        import numpy as np
        import matplotlib.pyplot as plt
        from matplotlib import cm

        ax = plt.gca()
        ax.set_aspect('equal', 'box')
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())

        largespace = 0.6
        smallspace = 0.5
        squaredistance = 1.0
        squaresize = 0.99
        domain_xlocations = {}
        domain_ylocations = {}

        xoffset = squaredistance
        yoffset = squaredistance
        xlabels = []
        ylabels = []
        for group in groups:
            xoffset += largespace
            yoffset += largespace
            for subgroup in group:
                xoffset += smallspace
                yoffset += smallspace
                for domain in subgroup:
                    domain_xlocations[domain] = xoffset
                    domain_ylocations[domain] = yoffset
                    #rect = plt.Rectangle([xoffset- squaresize / 2, yoffset - squaresize / 2], squaresize, squaresize,
                    #                     facecolor=(1,1,1), edgecolor=(0.1,0.1,0.1))

                    #ax.add_patch(rect)
                    #ax.text(xoffset , yoffset ,domain,horizontalalignment='left',verticalalignment='center',rotation=-45.0)
                    xoffset += squaredistance
                    yoffset += squaredistance

        for edge, count in self.edges.items():

            if quantitative_proteomic_data:
                #normalize
                maxqpd = max(self.quantitative_proteomic_data.values())
                minqpd = min(self.quantitative_proteomic_data.values())
                if edge in self.quantitative_proteomic_data:
                    value = self.quantitative_proteomic_data[edge]
                elif (edge[1], edge[0]) in self.quantitative_proteomic_data:
                    value = self.quantitative_proteomic_data[(edge[1],
                                                              edge[0])]
                else:
                    value = 0.0
                print(minqpd, maxqpd)
                density = (1.0 - (value - minqpd) / (maxqpd - minqpd))
            else:
                density = (1.0 - float(count) / self.num_rmf)
            color = (density, density, 1.0)
            x = domain_xlocations[edge[0]]
            y = domain_ylocations[edge[1]]
            if x > y:
                xtmp = y
                ytmp = x
                x = xtmp
                y = ytmp
            rect = plt.Rectangle([x - squaresize / 2, y - squaresize / 2],
                                 squaresize,
                                 squaresize,
                                 facecolor=color,
                                 edgecolor='Gray',
                                 linewidth=0.1)
            ax.add_patch(rect)
            rect = plt.Rectangle([y - squaresize / 2, x - squaresize / 2],
                                 squaresize,
                                 squaresize,
                                 facecolor=color,
                                 edgecolor='Gray',
                                 linewidth=0.1)
            ax.add_patch(rect)

        ax.autoscale_view()
        plt.savefig(out_fn)
        plt.show()
        exit()
def plot_aligned_img(x1, y1, w1, x2, y2, w2, fig_name, intersection, id1, id2,
                     save_dir, bg=True):
    '''
    :param x1: x list
    :param y1: y list
    :param w1: line width
    :param x2:
    :param y2:
    :param w2:
    :param fig_name: figure name to save
    :param intersection: (x,y) of intersection
    :param id1: id of the intersection point in x1
    :param id2:
    :param save_dir: directory to save figure
    :param bg: whether to plot the map background
    :return: avg_theta; the figure is saved to save_dir
    '''
    d = 8  # fig size
    r = 40  # range of x and y
    al = 8  # arrow length
    dpi = 100
    if w1 is None:
        w1 = 0.1
    if w2 is None:
        w2 = 0.1
    if id1 is None:
        id1 = find_id(intersection, x1, y1)
    if id2 is None:
        id2 = find_id(intersection, x2, y2)
    fig, axes = plt.subplots(1, 1, figsize=(d, d), dpi=dpi)
    if bg:
        lanelet_map_file = "D:/Downloads/INTERACTION-Dataset-DR-v1_0/maps/DR_USA_Roundabout_EP.osm"
        map_vis_without_lanelet.draw_map_without_lanelet(lanelet_map_file, axes, 0, 0)
    else:
        # set bg to black
        axes.patch.set_facecolor("k")
    circle = patches.Circle(intersection, (w1 + w2) * 6 * rate * d / r, color='r', zorder=3)
    axes.add_patch(circle)
    # calculate the k as the tangent
    if id1+1 >= len(x1):
        delta_y1, delta_x1 = y1[id1] - y1[id1 - 1], x1[id1] - x1[id1 - 1]
    elif id1-1 < 0:
        delta_y1, delta_x1 = y1[id1 + 1] - y1[id1], x1[id1 + 1] - x1[id1]
    else:
        delta_y1, delta_x1 = y1[id1 + 1] - y1[id1 - 1], x1[id1 + 1] - x1[id1 - 1]
    theta1 = math.atan2(delta_y1, delta_x1)
    # convert from -pi~pi to 0~2pi
    if theta1 < 0:
        theta1 += 2*math.pi
    if id2+1 >= len(x2):
        delta_y2, delta_x2 = (y2[id2] - y2[id2 - 1]), (x2[id2] - x2[id2 - 1])
    elif id2-1 < 0:
        delta_y2, delta_x2 = (y2[id2 + 1] - y2[id2]), (x2[id2 + 1] - x2[id2])
    else:
        delta_y2, delta_x2 = (y2[id2 + 1] - y2[id2 - 1]), (x2[id2 + 1] - x2[id2 - 1])
    theta2 = math.atan2(delta_y2, delta_x2)
    # convert from -pi~pi to 0~2pi
    if theta2 < 0:
        theta2 += 2*math.pi
    if bg:
        # before rotation
        plt.plot(x1, y1, linewidth=w1 * 72 * rate * d // r, color='b')
        plt.plot(x2, y2, linewidth=w2 * 72 * rate * d // r, color='g')
        delta_xy1 = (delta_x1 ** 2 + delta_y1 ** 2) ** 0.5
        ar_x1, ar_y1 = delta_x1 / delta_xy1, delta_y1 / delta_xy1
        axes.arrow(x1[id1], y1[id1], al * ar_x1, al * ar_y1, zorder=30, color='purple', width=0.2, head_width=0.6)
        delta_xy2 = (delta_x2 ** 2 + delta_y2 ** 2) ** 0.5
        ar_x2, ar_y2 = delta_x2 / delta_xy2, delta_y2 / delta_xy2
        axes.arrow(x2[id2], y2[id2], al * ar_x2, al * ar_y2, zorder=30, color='yellow', width=0.2, head_width=0.6)

    # theta of angle bisector
    avg_theta = (theta1+theta2)/2
    theta1_rot = theta1 - avg_theta
    k1_rot = math.tan(theta1_rot)
    theta2_rot = theta2 - avg_theta
    k2_rot = math.tan(theta2_rot)

    # draw the arrow whose length is al
    # rotate according to angle bisector to align
    # x1_rot, y1_rot = counterclockwise_rotate(x1, y1, intersection, -avg_theta)
    # x2_rot, y2_rot = counterclockwise_rotate(x2, y2, intersection, -avg_theta)
    # plt.plot(x1_rot, y1_rot, linewidth=w1 * 72 * rate * d // r, color='b')
    # plt.plot(x2_rot, y2_rot, linewidth=w2 * 72 * rate * d // r, color='g')
    if bg:
        pass
        # axes.arrow(x1[id1], y1[id1], al/(k1_rot**2+1)**0.5, al * k1_rot/(k1_rot**2+1)**0.5, zorder=4,
        #            color='purple', width=0.2, head_width=0.6)
        # axes.arrow(x2[id2], y2[id2], al/(k2_rot**2+1)**0.5, al * k2_rot/(k2_rot**2+1)**0.5, zorder=5,
        #            color='yellow', width=0.2, head_width=0.6)

    # set x y range
    plt.xlim(intersection[0]-r//2, intersection[0]+r//2)
    plt.ylim(intersection[1]-r//2, intersection[1]+r//2)

    # remove the white frame
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
    plt.margins(0, 0)

    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    if bg:
        fig_name = fig_name + '_bg'
    plt.savefig(save_dir+'{}.png'.format(fig_name))
    plt.close()
    return avg_theta
def plot_state_hinton(rho, title='', figsize=None):
    """Plot a hinton diagram for the quanum state.

    Args:
        rho (ndarray): Numpy array for state vector or density matrix.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches.
    Returns:
         matplotlib.Figure: The matplotlib.Figure of the visualization

    Raises:
        ImportError: Requires matplotlib.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed.')
    rho = _validate_input_state(rho)
    if figsize is None:
        figsize = (8, 5)
    num = int(np.log2(len(rho)))
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    max_weight = 2**np.ceil(np.log(np.abs(rho).max()) / np.log(2))
    datareal = np.real(rho)
    dataimag = np.imag(rho)
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    row_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    lx = len(datareal[0])  # Work out matrix dimensions
    ly = len(datareal[:, 0])
    # Real
    ax1.patch.set_facecolor('gray')
    ax1.set_aspect('equal', 'box')
    ax1.xaxis.set_major_locator(plt.NullLocator())
    ax1.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in np.ndenumerate(datareal):
        color = 'white' if w > 0 else 'black'
        size = np.sqrt(np.abs(w) / max_weight)
        rect = plt.Rectangle([x - size / 2, y - size / 2],
                             size,
                             size,
                             facecolor=color,
                             edgecolor=color)
        ax1.add_patch(rect)

    ax1.set_xticks(np.arange(0, lx + 0.5, 1))
    ax1.set_yticks(np.arange(0, ly + 0.5, 1))
    ax1.set_yticklabels(row_names, fontsize=14)
    ax1.set_xticklabels(column_names, fontsize=14, rotation=90)
    ax1.autoscale_view()
    ax1.invert_yaxis()
    ax1.set_title('Real[rho]', fontsize=14)
    # Imaginary
    ax2.patch.set_facecolor('gray')
    ax2.set_aspect('equal', 'box')
    ax2.xaxis.set_major_locator(plt.NullLocator())
    ax2.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in np.ndenumerate(dataimag):
        color = 'white' if w > 0 else 'black'
        size = np.sqrt(np.abs(w) / max_weight)
        rect = plt.Rectangle([x - size / 2, y - size / 2],
                             size,
                             size,
                             facecolor=color,
                             edgecolor=color)
        ax2.add_patch(rect)
    if np.any(dataimag != 0):
        ax2.set_xticks(np.arange(0, lx + 0.5, 1))
        ax2.set_yticks(np.arange(0, ly + 0.5, 1))
        ax2.set_yticklabels(row_names, fontsize=14)
        ax2.set_xticklabels(column_names, fontsize=14, rotation=90)
    ax2.autoscale_view()
    ax2.invert_yaxis()
    ax2.set_title('Imag[rho]', fontsize=14)
    if title:
        fig.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.close(fig)
    return fig
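A hedged usage sketch, assuming the surrounding Qiskit-style helpers (_validate_input_state, HAS_MATPLOTLIB) are available and accept a plain density matrix:

import numpy as np

# Hinton diagram of the two-qubit Bell state (|00> + |11>)/sqrt(2).
bell = np.zeros((4, 4), dtype=complex)
bell[0, 0] = bell[0, 3] = bell[3, 0] = bell[3, 3] = 0.5
fig = plot_state_hinton(bell, title='Bell state')
fig.savefig('bell_hinton.png')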
Example #10
def fk_reconstruct(st, slopes=[-10,10], deltaslope=0.05, slopepicking=False, smoothpicks=False, dist=0.5, maskshape=['boxcar',None],
                    method='denoise', solver="iterative",  mu=5e-2, tol=1e-12, fulloutput=False, peakinput=False, alpha=0.9):
    """
    This function reconstructs missing signals in the f-k domain, using the original data
    (gaps filled with zeros) and its mask array (see makeMask and slope_distribution).
    If all traces are available it is a useful method to de-noise the data.
    Uses the following cost function to minimize:

            J = ||dv - T FHmtx2D Yw Dv ||^{2}_{2} + mu^2 ||Dv||^{2}_{2}

            J := Cost function
            dv:= Column-wise-ordered long vector of the 2D signal d (columns: t-domain, rows: x-domain)
            DV:= Column-wise-ordered long vector of the	f-k-spectrum D ( columns: f-domain, rows: k-domain)
            Yw := Diagonal matrix built from the column-wise-ordered long vector of Mask
            T := Sampling matrix which maps the fully sampled desired seismic data to the available samples.
                 For de-noising problems T = I (identity matrix)
            mu := Trade-off parameter between misfit and model norm


    Minimization is done with the LSMR solver; de-noising takes 1-2 iterations, reconstruction 8-10 iterations.
    T, FHmtx2D and Yw are combined into a single matrix A, so the equation system that is finally solved has the form:

                            |   A    |          | dv |
                            |        | * Dv  =  |    |
                            | mu * I |          | 0  |


    :param st: Stream with missing traces, to be reconstructed or complete stream to be de-noised
    :type  st: obspy.core.stream.Stream

    :param slopes: Range of slopes to investigate for mask-function
    :type  slopes: list

    :param deltaslope: step size between slopes.
    :type  deltaslope: float

    :param slopepicking: If True peaks of slopedistribution can be picked by hand.
    :type  slopepicking: bool

    :param smoothpicks: Determines the smoothing of the slope distribution; off by default. If enabled, the
                        distribution is smoothed by convolving it with a boxcar of size smoothpicks.
    :type  smoothpicks: int

    :param dist: Minimum distance between maximum picks.
    :type  dist: float

    :param maskshape: maskshape[0] describes the shape of the lobes of the mask. Possible inputs are:
                 -boxcar (default)
                 -taper
                 -butterworth

                  maskshape[1] is an additional attribute to the shape of taper and butterworth, for:
                 -taper: maskshape[1] = slope of sides
                 -butterworth: maskshape[1] = number of poles

                 e.g. maskshape = ['taper', 2] produces a symmetric taper with a side slope of 2.


    :type  maskshape: list

    :param method: Desired fk-method, options are 'denoise' and 'interpolate'
    :type  method: string

    :param solver: Solver used for method. Options are 'lsqr' and 'iterative'.
                   If method is 'denoise' only the iterative solver is used.
    :type  solver: string

    :param mu:	Damping parameter for the solver
    :type  mu:	float

    :param tol: Tolerance for solver to abort iteration.
    :type  tol: float

    :param fulloutput: If True, the function additionally outputs FH, dv, Dv, Ts and Yw
    :type  fulloutput: bool

    :param peakinput: Chosen peaks of the distribution; provide them here if the peaks should not be recalculated.
    :type  peakinput: np.ndarray

    ######  returns:

    :param st_rec: Stream with reconstructed signals on the missing traces
    :type  st_rec: obspy.core.stream.Stream

    ## if fulloutput=True

    :param st_rec: Stream with reconstructed signals on the missing traces
    :type  st_rec: obspy.core.stream.Stream

    :param FH: 2DiFFT-matrix for column-wise ordered longvector of the f-k spectrum
    :type  FH: scipy.sparse.csc.csc_matrix

    :param dv: Column-wise ordered longvector of the t-x data
    :type  dv: numpy.ndarray

    :param Dv: Column-wise ordered longvector of the f-k spectrum of dv
    :type  Dv: numpy.ndarray

    :param Ts: Sampling-matrix, which maps desired to available data
    :type  Ts: scipy.sparse.dia.dia_matrix

    :param Yw: Diagonal matrix constructed of the column-wise ordered longvector of the mask-matrix
    :type  Yw: scipy.sparse.dia.dia_matrix

    Example:
                from obspy import read as read_st
                import sipy

                stream = read_st("../data/synthetics_uniform/SUNEW.QHD")

                #Example around PP.
                stream_org = stream.copy()
                d = bowpy.util.array_util.stream2array(stream_org)
                ArrayData = np.zeros((d.shape[0], 300))
                for i, trace in enumerate(d):
                    ArrayData[i,:]=trace[400:700]
                stream = bowpy.util.array_util.array2stream(ArrayData, stream_org)

                dssa = bowpy.filter.fk.fk_reconstruct(stream, mu=5e-2, method='interpolate')

                stream_ssa = bowpy.util.array_util.array2stream(dssa, stream)

                bowpy.util.fkutil.plot(stream_ssa)

    Author: S. Schneider, 2016
    Reference:	Mostafa Naghizadeh, Seismic data interpolation and de-noising in the frequency-wavenumber
                domain, 2012, GEOPHYSICS
    """

    # Prepare data.
    st_tmp 		= st.copy()
    ArrayData	= stream2array(st_tmp, normalize=False)
    ADT 		= ArrayData.copy().transpose()

    fkData 		= np.fft.fft2(ArrayData)
    fkDT 		= np.fft.fft2(ADT)

    # Look for missing Traces
    recon_list 	= []

    for i, trace in enumerate(st_tmp):
        try:
            if trace.stats.zerotrace == 'True':
                recon_list.append(i)
        except AttributeError:
            if sum(trace.data) == 0. :
                recon_list.append(i)
        except:
            continue
    print(recon_list)

    # Calculate mask-function W.
    try:
        if peakinput.any():
            peaks = peakinput
    except:
        print("Calculating slope distribution...\n")
        M, prange, peaks = slope_distribution(fkData, slopes, deltaslope, peakpick=None, mindist=dist, smoothing=smoothpicks, interactive=slopepicking)
        if fulloutput:
            kin = 'n'
            while kin in ('n', 'N'):
                plt.figure()
                plt.title('Magnitude-Distribution')
                plt.xlabel('Slope in fk-domain')
                plt.ylabel('Magnitude of slope')
                plt.plot(prange, M)
                plt.plot(peaks[0], peaks[1]/peaks[1].max()*M.max(), 'ro')
                plt.show()
                kin = raw_input("Use picks? (y/n) \n")
                if kin in ['y' , 'Y']:
                    print("Using picks, continue \n")
                elif kin in ['n', 'N']:
                    print("Don't use picks, please re-pick \n")
                    M, prange, peaks = slope_distribution(fkData, slopes, deltaslope, peakpick=None, mindist=dist, smoothing=smoothpicks, interactive=True)

    print("Creating mask function with %i significant linear events \n" % len(peaks[0]) )
    W = makeMask(fkData, peaks[0], maskshape)

    # If fulloutput is desired, a bunch of messages and user interaction appears.
    if fulloutput:
        plt.figure()
        plt.subplot(3,1,1)
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.title("fk-spectrum")
        plt.imshow(abs(np.fft.fftshift(fkData)), aspect='auto', interpolation='none')
        plt.subplot(3,1,2)
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.title("Mask-function")
        plt.imshow(np.fft.fftshift(W), aspect='auto', interpolation='none')
        plt.subplot(3,1,3)
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.title("Applied mask-function")
        plt.imshow(abs(np.fft.fftshift(W*fkData)), aspect='auto', interpolation='none')
        plt.show()
        kin = raw_input("Use Mask? (y/n) \n")
        if kin in ['y' , 'Y']:
            print("Using Mask, continue \n")
        elif kin in ['n', 'N']:
            msg="Don't use Mask, exit"
            raise IOError(msg)

    # Checking for number of iteration and reconstruction behavior.
    maxiter=None
    interpol = False
    if isinstance(method, str):
        if method == "denoise":
            maxiter = 2
            recon_list = []
        elif method == "interpolate":
            maxiter = 10
            interpol = True

    elif isinstance(method, int):
        maxiter=method

    print("maximum %i" %maxiter)
    if solver in ("lsqr", "leastsquares", "ilsmr", "iterative", "cg", "fmin"):
        pocs = False
        # To keep the order it would be better to transpose W to WT
        # but for creation of Y, WT has to be transposed again,
        # so this step can be skipped.
        Y 	= W.reshape(1,W.size)[0]
        Yw 	= sparse.diags(Y)

        # Initialize arrays for cost-function.
        dv 	= ADT.transpose().reshape(1, ADT.size)[0]
        Dv	= fkDT.transpose().reshape(1, fkDT.size)[0]

        T = np.ones((ArrayData.shape[0], ArrayData.shape[1]))
        T[recon_list] = 0.
        T = T.reshape(1, T.size)[0]

        Ts = sparse.diags(T)


        # Create sparse-matrix with iFFT operations.
        print("Creating iFFT2 operator as a %ix%i matrix ...\n" %(fkDT.shape[0]*fkDT.shape[1], fkDT.shape[0]*fkDT.shape[1]))

        FH = create_iFFT2mtx(fkDT.shape[0], fkDT.shape[1])
        print("... finished\n")

        # Create model matrix A.
        print("Creating sparse %ix%i matrix A ...\n" %(FH.shape[0], FH.shape[1]))
        A =  Ts.dot(FH.dot(Yw))
        print("Starting reconstruction...\n")

        if solver in ("lsqr", "leastsquares"):
            print(" ...using iterative least-squares solver...\n")
            x = sparse.linalg.lsqr(A, dv, mu, atol=tol, btol=tol, conlim=tol, iter_lim=maxiter)
            print("istop = %i \n" % x[1])
            print("Used iterations = %i \n" % x[2])
            print("residual Norm ||x||_2 = %f \n " % x[8])
            print("Misfit ||Ax - b||_2= %f \n" % x[4])
            print("Condition number = %f \n" % x[6])

            Dv_rec = x[0]

        elif solver in ("ilsmr", "iterative"):
            print(" ...using iterative LSMR solver...\n")
            x = sparse.linalg.lsmr(A,dv,mu, atol=tol, btol=tol, conlim=tol, maxiter=maxiter)
            print("istop = %i \n" % x[1])
            print("Used iterations = %i \n" % x[2])
            print("Misfit = %f \n " % x[3])
            print("Modelnorm = %f \n" % x[4])
            print("Condition number = %f \n" % x[5])
            print("Norm of Dv = %f \n" % x[6])
            Dv_rec = x[0]

        elif solver in ("cg"):
            A 		= Ts.dot(FH.dot(Yw))
            Ah 		= A.conjugate().transpose()
            madj 	= Ah.dot(dv)
            E 		= mu * sparse.eye(A.shape[0])
            B 		= A + E
            Binv 	= sparse.linalg.inv(B)
            x 		= sparse.linalg.cg(Binv, madj, maxiter=maxiter)
            Dv_rec 	= x[0]

        elif solver == 'fmin':
            A 		= Ts.dot(FH.dot(Yw))
            global arg1
            global arg2
            global arg3
            arg1 = dv
            arg2 = A
            arg3 = mu

            def J(x):
                COST = np.linalg.norm(arg1 - arg2.dot(x), 2)**2. + arg3*np.linalg.norm(x,2)**2.
                return COST

            Dv_rec = sp.optimize.fmin_cg(J, x0=Dv, maxiter=10)

        data_rec = np.fft.ifft2(Dv_rec.reshape(fkData.shape)).real

    elif solver in ("pocs"):
        pocs=True
        threshold = abs( (fkData*W.astype('complex').max()) )

        for i in range(maxiter):
            data_tmp 								= ArrayData.copy()
            fkdata 									= np.fft.fft2(data_tmp) * W.astype('complex')
            fkdata[ np.where(abs(fkdata) < threshold)] 	= 0. + 0j
            threshold = threshold * alpha
            #if i % 10 == 0.:
            #	plt.imshow(abs(fkdata), aspect='auto', interpolation='none')
            #	plt.savefig("%s.png" % i)
            data_tmp 								= np.fft.ifft2(fkdata).real.copy()
            ArrayData[recon_list] 					= data_tmp[recon_list]

        data_rec = ArrayData.copy()
    else:
        print("No solver or method specified.")
        return



    if interpol:
        st_rec = st.copy()
        for i in recon_list:
            st_rec[i].data = data_rec[i,:]
            st_rec[i].stats.zerotrace = 'reconstructed'


    else:
        st_rec = array2stream(data_rec, st)

    if fulloutput and not pocs:
        return st_rec, FH, dv, Dv, Dv_rec, Ts, Yw, W
    else:
        return st_rec
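

# A self-contained sketch of the POCS scheme used in the "pocs" branch above
# (numpy is assumed to be imported as np, as elsewhere in this file): FFT the
# current estimate, hard-threshold the spectrum, inverse FFT, re-insert the
# known samples, and lower the threshold each pass. The helper name
# `pocs_sketch` and the default `niter`/`alpha` values are illustrative only.
def pocs_sketch(data, mask, niter=50, alpha=0.9):
    """data: 2-D array with zeros at the missing traces; mask: 1 where data is known."""
    threshold = np.abs(np.fft.fft2(data)).max()
    rec = data.copy()
    for _ in range(niter):
        fk = np.fft.fft2(rec)
        fk[np.abs(fk) < threshold] = 0. + 0j
        threshold *= alpha
        rec = np.fft.ifft2(fk).real
        rec[mask == 1] = data[mask == 1]  # keep the measured samples untouched
    return rec
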
def plot_warpingpaths(s1, s2, paths, path=None, filename=None, shownumbers=False):
    """Plot the warping paths matrix.

    :param s1: Series 1
    :param s2: Series 2
    :param paths: Warping paths matrix
    :param path: Path to draw (typically this is the best path)
    :param filename: Filename for the image (optional)
    :param shownumbers: Show distances also as numbers
    """
    from matplotlib import pyplot as plt
    from matplotlib import gridspec
    from matplotlib.ticker import FuncFormatter

    ratio = max(len(s1), len(s2))
    min_y = min(np.min(s1), np.min(s2))
    max_y = max(np.max(s1), np.max(s2))

    fig = plt.figure(figsize=(10, 10), frameon=True)
    gs = gridspec.GridSpec(2, 2, wspace=1, hspace=1,
                           left=0, right=1.0, bottom=0, top=1.0,
                           height_ratios=[1, 6],
                           width_ratios=[1, 6])
    max_s2_x = np.max(s2)
    max_s2_y = len(s2)
    max_s1_x = np.max(s1)
    min_s1_x = np.min(s1)
    max_s1_y = len(s1)

    if path is None:
        p = dtw.best_path(paths)
    else:
        p = path

    def format_fn2_x(tick_val, tick_pos):
        return max_s2_x - tick_val

    def format_fn2_y(tick_val, tick_pos):
        return int(max_s2_y - tick_val)

    ax0 = fig.add_subplot(gs[0, 0])
    ax0.set_axis_off()
    ax0.text(0, 0, "Dist = {:.4f}".format(paths[p[-1][0], p[-1][1]]))
    ax0.xaxis.set_major_locator(plt.NullLocator())
    ax0.yaxis.set_major_locator(plt.NullLocator())

    ax1 = fig.add_subplot(gs[0, 1:])
    ax1.set_ylim([min_y, max_y])
    ax1.set_axis_off()
    ax1.xaxis.tick_top()
    # ax1.set_aspect(0.454)
    ax1.plot(range(len(s2)), s2, ".-")
    ax1.xaxis.set_major_locator(plt.NullLocator())
    ax1.yaxis.set_major_locator(plt.NullLocator())

    ax2 = fig.add_subplot(gs[1:, 0])
    ax2.set_xlim([-max_y, -min_y])
    ax2.set_axis_off()
    # ax2.set_aspect(0.8)
    # ax2.xaxis.set_major_formatter(FuncFormatter(format_fn2_x))
    # ax2.yaxis.set_major_formatter(FuncFormatter(format_fn2_y))
    ax2.xaxis.set_major_locator(plt.NullLocator())
    ax2.yaxis.set_major_locator(plt.NullLocator())
    ax2.plot(-s1, range(max_s1_y, 0, -1), ".-")

    ax3 = fig.add_subplot(gs[1:, 1:])
    # ax3.set_aspect(1)
    ax3.matshow(paths[1:, 1:])
    # ax3.grid(which='major', color='w', linestyle='-', linewidth=0)
    # ax3.set_axis_off()
    py, px = zip(*p)
    ax3.plot(px, py, ".-", color="red")
    # ax3.xaxis.set_major_locator(plt.NullLocator())
    # ax3.yaxis.set_major_locator(plt.NullLocator())
    if shownumbers:
        for r in range(1, paths.shape[0]):
            for c in range(1, paths.shape[1]):
                ax3.text(c - 1, r - 1, "{:.2f}".format(paths[r, c]))

    gs.tight_layout(fig, pad=1.0, h_pad=1.0, w_pad=1.0)
    # fig.subplots_adjust(hspace=0, wspace=0)

    ax = fig.axes

    if filename:
        if type(filename) != str:
            filename = str(filename)
        plt.savefig(filename)
        plt.close()
        fig, ax = None, None
    return fig, ax
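

# A minimal usage sketch for plot_warpingpaths. It assumes the dtaidistance
# package (whose dtw.best_path is already used above) and that numpy is
# imported as np; dtw.warping_paths returns the DTW distance and the
# accumulated-cost matrix this plot expects. The series values and the
# output filename are made-up examples.
from dtaidistance import dtw

s1_demo = np.array([0., 0.5, 1.5, 1., 0., 0.5, 0.])
s2_demo = np.array([0., 1., 1.5, 0.5, 0., 0.])
_, paths_demo = dtw.warping_paths(s1_demo, s2_demo)
plot_warpingpaths(s1_demo, s2_demo, paths_demo,
                  path=dtw.best_path(paths_demo),
                  filename="warpingpaths_demo.png")
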
def show_single(image, target, classes, save):
    """
    Show the image, with or without the target
    Arguments:
        image (tensor[3, H, W]): RGB channels, value range: [0.0, 1.0]
        target (dict[tensor]): currently supports "boxes", "labels", "scores", "masks";
           all tensors should have the same length N
           masks: shape=[N, H, W], dtype=torch.float
        classes (tuple): class names
        save (str): path where to save the figure
    """
    image = image.clone()
    if target and "masks" in target:
        masks = target["masks"].unsqueeze(1)
        masks = masks.repeat(1, 3, 1, 1)
        for i, m in enumerate(masks):
            f = torch.tensor(factor(i)).reshape(3, 1, 1).to(image)
            value = f * m
            image += value
            
    image = image.clamp(0, 1)
    H, W = image.shape[-2:]
    fig = plt.figure(figsize=(W / 72, H / 72))
    ax = fig.add_subplot(111)
    
    im = image.cpu().numpy()
    ax.imshow(im.transpose(1, 2, 0)) # RGB
    ax.set_title("W: {}   H: {}".format(W, H))
    ax.axis("off")

    if target:
        if "labels" in target:
            if classes is None:
                raise ValueError("'classes' should not be None when 'target' has 'labels'!")
            tags = {l: i for i, l in enumerate(tuple(set(target["labels"].tolist())))}
            
        index = 0
        if "boxes" in target:
            boxes = target["boxes"]
            boxes = xyxy2xywh(boxes).cpu().detach()
            for i, b in enumerate(boxes):
                if "labels" in target:
                    l = target["labels"][i].item()
                    index = tags[l]
                    txt = classes[l]
                    if "scores" in target:
                        s = target["scores"][i]
                        s = round(s.item() * 100)
                        txt = "{} {}%".format(txt, s)
                    ax.text(
                        b[0], b[1], txt, fontsize=10, color=factor(index),  
                        horizontalalignment="left", verticalalignment="bottom",
                        bbox=dict(boxstyle="square", fc="black", lw=1, alpha=1)
                    )
                    
                    
                rect = patches.Rectangle(b[:2], b[2], b[3], linewidth=2, edgecolor=factor(index), facecolor="none")
                ax.add_patch(rect)

    if save:
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.savefig(save)
    plt.show()
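

# A minimal usage sketch for show_single with made-up data. It assumes torch
# is imported and that factor() and xyxy2xywh() (used inside show_single) are
# defined elsewhere in this module; the class names and box are illustrative.
image_demo = torch.rand(3, 240, 320)
target_demo = {
    "boxes": torch.tensor([[30., 40., 120., 160.]]),  # xyxy box
    "labels": torch.tensor([1]),
    "scores": torch.tensor([0.87]),
}
show_single(image_demo, target_demo, classes=("background", "cat"), save=None)
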
Exemple #13
0
def show_data(indir,
              plot_type,
              print_flag,
              save_flag,
              data_only,
              vOut=None,
              pOut=None,
              tOut=None):
    pfx = 'mpet.'
    sStr = "_"
    ttl_fmt = "% = {perc:2.1f}"
    # Read in the simulation results and calculated data
    dataFileName = "output_data.mat"
    dataFile = os.path.join(indir, dataFileName)
    data = sio.loadmat(dataFile)
    try:
        data[pfx + 'current'][0][0]
    except KeyError:
        pfx = ''
    try:
        data[pfx + "partTrodecvol0part0" + sStr + "cbar"]
    except KeyError:
        sStr = "."
    # Read in the parameters used to define the simulation
    dD_s, ndD_s = IO.read_dicts(os.path.join(indir, "input_dict_system"))
    # simulated (porous) electrodes
    Nvol = ndD_s["Nvol"]
    trodes = ndD_s["trodes"]
    dD_e = {}
    ndD_e = {}
    for trode in trodes:
        dD_e[trode], ndD_e[trode] = IO.read_dicts(
            os.path.join(indir, "input_dict_{t}".format(t=trode)))
    # Pick out some useful constants/calculated values
    k = dD_s['k']  # Boltzmann constant, J/(K Li)
    Tref = dD_s['Tref']  # Temp, K
    e = dD_s['e']  # Charge of proton, C
    F = dD_s['F']  # C/mol
    td = dD_s["td"]
    Etheta = {"a": 0.}
    for trode in trodes:
        Etheta[trode] = -(k * Tref / e) * ndD_s["phiRef"][trode]
#        Etheta[trode] = -(k*Tref/e) * ndD_e[trode]["muR_ref"]
    Vstd = Etheta["c"] - Etheta["a"]
    Nvol = ndD_s["Nvol"]
    Npart = ndD_s["Npart"]
    psd_len = dD_s["psd_len"]
    # Discretization (and associated porosity)
    Lfac = 1e6
    Lunit = r"$\mu$m"
    dxc = ndD_s["L"]["c"] / Nvol["c"]
    dxvec = np.array(Nvol["c"] * [dxc])
    porosvec = np.array(Nvol["c"] * [ndD_s["poros"]["c"]])
    cellsvec = dxc * np.arange(Nvol["c"]) + dxc / 2.
    if Nvol["s"]:
        dxs = ndD_s["L"]["s"] / Nvol["s"]
        dxvec_s = np.array(Nvol["s"] * [dxs])
        dxvec = np.hstack((dxvec_s, dxvec))
        poros_s = np.array(Nvol["s"] * [ndD_s["poros"]["s"]])
        porosvec = np.hstack((poros_s, porosvec))
        cellsvec += dD_s["L"]["s"] / dD_s["L"]["c"]
        cellsvec_s = dxs * np.arange(Nvol["s"]) + dxs / 2.
        cellsvec = np.hstack((cellsvec_s, cellsvec))
    if "a" in trodes:
        dxa = ndD_s["L"]["a"] / Nvol["a"]
        dxvec_a = np.array(Nvol["a"] * [dxa])
        dxvec = np.hstack((dxvec_a, dxvec))
        poros_a = np.array(Nvol["a"] * [ndD_s["poros"]["a"]])
        porosvec = np.hstack((poros_a, porosvec))
        cellsvec += dD_s["L"]["a"] / dD_s["L"]["c"]
        cellsvec_a = dxa * np.arange(Nvol["a"]) + dxa / 2.
        cellsvec = np.hstack((cellsvec_a, cellsvec))
    cellsvec *= dD_s["Lref"] * Lfac
    facesvec = np.insert(np.cumsum(dxvec), 0, 0.) * dD_s["Lref"] * Lfac
    # Extract the reported simulation times
    times = data[pfx + 'phi_applied_times'][0]
    numtimes = len(times)
    tmin = np.min(times)
    tmax = np.max(times)
    # Simulation type
    profileType = ndD_s['profileType']
    # Colors for plotting concentrations
    to_yellow = 0.3
    to_red = 0.7
    scl = 1.0  # static
    #    scl = 1.2  # movies
    figsize = (scl * 6, scl * 4)

    # Print relevant simulation info
    if print_flag:
        print("profileType:", profileType)
        #        for i in trodes:
        #            print "PSD_{l}:".format(l=l)
        #            print psd_len[l].transpose()
        #            print "Actual psd_mean [nm]:", np.mean(psd_len[l])
        #            print "Actual psd_stddev [nm]:", np.std(psd_len[l])
        print("Cell structure:")
        print(("porous anode | " if Nvol["a"] else "flat anode | ") +
              ("sep | " if Nvol["s"] else "") + "porous cathode")
        if Nvol["a"]:
            print("capacity ratio cathode:anode, 'z':", ndD_s["z"])
        for trode in trodes:
            print("solidType_{t}:".format(t=trode), ndD_e[trode]['type'])
            print("solidShape_{t}".format(t=trode), ndD_e[trode]['shape'])
            print("rxnType_{t}:".format(t=trode), ndD_e[trode]['rxnType'])
        if profileType == "CC":
            print("C_rate:", dD_s['Crate'])
            print("current:", dD_s['currset'], "A/m^2")
        else:  # CV
            print("Vset:", dD_s['Vset'])
        print("Specified psd_mean, c [{unit}]:".format(unit=Lunit),
              np.array(dD_s['psd_mean']["c"]) * Lfac)
        print("Specified psd_stddev, c [{unit}]:".format(unit=Lunit),
              np.array(dD_s['psd_stddev']["c"]) * Lfac)
        #        print "reg sln params:"
        #        print ndD["Omga"]
        print("ndim B_c:", ndD_e["c"]["B"])
        if Nvol["s"]:
            print("Nvol_s:", Nvol["s"])
        print("Nvol_c:", Nvol["c"])
        if Nvol["a"]:
            print("Nvol_a:", Nvol["a"])
        print("Npart_c:", Npart["c"])
        if Nvol["a"]:
            print("Npart_a:", Npart["a"])
        print("Dp [m^2/s]:", dD_s['Dp'])
        print("Dm [m^2/s]:", dD_s['Dm'])
        print("Damb [m^2/s]:", dD_s['Damb'])
        print("td [s]:", dD_s["td"])
        for trode in trodes:
            print("k0_{t} [A/m^2]:".format(t=trode), dD_e[trode]['k0'])
            rxnType = ndD_e[trode]['rxnType']
            if rxnType == "BV":
                print("alpha_" + trode + ":", ndD_e[trode]['alpha'])
            elif rxnType in ["Marcus", "MHC"]:
                print("lambda_" + trode + "/(kTref):", ndD_e[trode]["lambda"])
            if ndD_s['simBulkCond'][trode]:
                print(trode + " bulk conductivity loss: Yes -- " +
                      "sigma_s [S/m]: " + str(dD_s['sigma_s'][trode]))
            else:
                print(trode + " bulk conductivity loss: No")
            try:
                simSurfCond = ndD_e[trode]['simSurfCond']
                if simSurfCond:
                    print(trode + " surface conductivity loss: Yes -- " +
                          "dim_scond [S]: " + str(dD_e[trode]['scond']))
                else:
                    print(trode + " surface conductivity loss: No")
            except KeyError:
                pass
#            if ndD['simSurfCond'][l]:
#                print (l + " surface conductivity loss: Yes -- " +
#                        "dim_scond [S]: " + str(dD['scond'][l]))
#            else:
#                print l + " surface conductivity loss: No"

    if plot_type in ["params"]:
        return ndD_s, dD_s, ndD_e, dD_e
    if plot_type in ["discData"]:
        return cellsvec / Lfac, facesvec / Lfac

    # Plot voltage profile
    if plot_type in ["v", "vt"]:
        voltage = (Vstd - (k * Tref / e) * data[pfx + 'phi_applied'][0])
        ffvec = data[pfx + 'ffrac_c'][0]
        fig, ax = plt.subplots(figsize=figsize)
        if plot_type == "v":
            if data_only:
                plt.close(fig)
                return ffvec, voltage
            ax.plot(ffvec, voltage)
            xmin = 0.
            xmax = 1.
            ax.set_xlim((xmin, xmax))
            ax.set_xlabel("Cathode Filling Fraction [dimensionless]")
        elif plot_type == "vt":
            if data_only:
                plt.close(fig)
                return times * td, voltage
            ax.plot(times * td, voltage)
            ax.set_xlabel("Time [s]")
        ax.set_ylabel("Voltage [V]")
        if save_flag:
            fig.savefig("mpet_v.pdf", bbox_inches="tight")
        return fig, ax

    # Plot surface conc.
    if plot_type[:-2] in ["surf"]:
        trode = plot_type[-1]
        str_base = (
            pfx +
            "partTrode{trode}vol{{vInd}}part{{pInd}}".format(trode=trode) +
            sStr + "c")
        if data_only:
            sol_str = str_base.format(pInd=pOut, vInd=vOut)
            datay = data[sol_str][:, -1]
            return times * td, datay
        fig, ax = plt.subplots(Npart[trode],
                               Nvol[trode],
                               squeeze=False,
                               sharey=True,
                               figsize=figsize)
        ylim = (0, 1.01)
        datax = times
        for pInd in range(Npart[trode]):
            for vInd in range(Nvol[trode]):
                sol_str = str_base.format(pInd=pInd, vInd=vInd)
                # Remove axis ticks
                ax[pInd, vInd].xaxis.set_major_locator(plt.NullLocator())
                datay = data[sol_str][:, -1]
                line, = ax[pInd, vInd].plot(times, datay)
        return fig, ax

    # Plot SoC profile
    if plot_type[:-2] in ["soc"]:
        trode = plot_type[-1]
        ffvec = data[pfx + 'ffrac_{trode}'.format(trode=trode)][0]
        if data_only:
            return times * td, ffvec
        fig, ax = plt.subplots(figsize=figsize)
        print(ffvec[-1])
        ax.plot(times * td, ffvec)
        xmin = np.min(ffvec)
        xmax = np.max(ffvec)
        ax.set_ylim((0, 1.05))
        ax.set_xlabel("Time [s]")
        ax.set_ylabel("Filling Fraciton [dimless]")
        if save_flag:
            fig.savefig("mpet_soc.pdf", bbox_inches="tight")
        return fig, ax

    # Check to make sure mass is conserved in elyte
    if plot_type == "elytecons":
        fig, ax = plt.subplots(figsize=figsize)
        eps = 1e-2
        ymin = 1 - eps
        ymax = 1 + eps
        #        ax.set_ylim((ymin, ymax))
        ax.set_ylabel('Avg. Concentration of electrolyte [nondim]')
        sep = pfx + 'c_lyte_s'
        anode = pfx + 'c_lyte_a'
        cath = pfx + 'c_lyte_c'
        ax.set_xlabel('Time [s]')
        cvec = data[cath]
        if Nvol["s"]:
            cvec_s = data[sep]
            cvec = np.hstack((cvec_s, cvec))
        if "a" in trodes:
            cvec_a = data[anode]
            cvec = np.hstack((cvec_a, cvec))
        cavg = np.sum(porosvec * dxvec * cvec, axis=1) / np.sum(
            porosvec * dxvec)
        if data_only:
            plt.close(fig)
            return times * td, cavg
        np.set_printoptions(precision=8)
        print(cavg)
        ax.plot(times * td, cavg)
        return fig, ax

    # Plot current profile
    if plot_type == "curr":
        current = data[pfx + "current"][0] * 3600 / td
        ffvec = data[pfx + 'ffrac_c'][0]
        if data_only:
            return times * td, current
        fig, ax = plt.subplots(figsize=figsize)
        ax.plot(times * td, current)
        xmin = np.min(ffvec)
        xmax = np.max(ffvec)
        ax.set_xlabel("Time [s]")
        ax.set_ylabel("Current [C-rate]")
        if save_flag:
            fig.savefig("mpet_current.png", bbox_inches="tight")
        return fig, ax

    # Plot electrolyte concentration or potential
    elif plot_type in [
            "elytec", "elytep", "elytecf", "elytepf", "elytei", "elyteif",
            "elytedivi", "elytedivif"
    ]:
        fplot = (True if plot_type[-1] == "f" else False)
        t0ind = (0 if not fplot else -1)
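        # Monkey-patch matplotlib's Animation blitting so the title text above
        # the axes can be redrawn during the animation (same hack as noted
        # later for the cbar plots).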
        mpl.animation.Animation._blit_draw = _blit_draw
        datax = cellsvec
        c_sep, p_sep = pfx + 'c_lyte_s', pfx + 'phi_lyte_s'
        c_anode, p_anode = pfx + 'c_lyte_a', pfx + 'phi_lyte_a'
        c_cath, p_cath = pfx + 'c_lyte_c', pfx + 'phi_lyte_c'
        datay_c = data[c_cath]
        datay_p = data[p_cath]
        L_c = dD_s['L']["c"] * Lfac
        Ltot = L_c
        if Nvol["s"]:
            datay_s_c = data[c_sep]
            datay_s_p = data[p_sep]
            datay_c = np.hstack((datay_s_c, datay_c))
            datay_p = np.hstack((datay_s_p, datay_p))
            L_s = dD_s['L']["s"] * Lfac
            Ltot += L_s
        else:
            L_s = 0
        if "a" in trodes:
            datay_a_c = data[c_anode]
            datay_a_p = data[p_anode]
            datay_c = np.hstack((datay_a_c, datay_c))
            datay_p = np.hstack((datay_a_p, datay_p))
            L_a = dD_s['L']["a"] * Lfac
            Ltot += L_a
        else:
            L_a = 0
        xmin = 0
        xmax = Ltot
        if plot_type in ["elytec", "elytecf"]:
            ymin = 0
            ymax = 2.2
            ylbl = 'Concentration of electrolyte [M]'
            datay = datay_c * dD_s["cref"] / 1000.
        elif plot_type in ["elytep", "elytepf"]:
            ymin = -50
            ymax = 50
            ylbl = 'Potential of electrolyte [V]'
            datay = datay_p * (k * Tref / e) - Vstd
        elif plot_type in ["elytei", "elyteif", "elytedivi", "elytedivif"]:
            cGP_L, pGP_L = data["c_lyteGP_L"], data["phi_lyteGP_L"]
            cmat = np.hstack((cGP_L.T, datay_c, datay_c[:, -1].reshape(
                (numtimes, 1))))
            pmat = np.hstack((pGP_L.T, datay_p, datay_p[:, -1].reshape(
                (numtimes, 1))))
            disc = geom.get_elyte_disc(Nvol, ndD_s["L"], ndD_s["poros"],
                                       ndD_s["BruggExp"])
            i_edges = np.zeros((numtimes, len(facesvec)))
            for tInd in range(numtimes):
                i_edges[tInd, :] = mod_cell.get_lyte_internal_fluxes(
                    cmat[tInd, :], pmat[tInd, :], disc["dxd1"],
                    disc["eps_o_tau_edges"], ndD_s)[1]
            if plot_type in ["elytei", "elyteif"]:
                ylbl = r'Current density of electrolyte [A/m$^2$]'
                datax = facesvec
                datay = i_edges * (F * dD_s["cref"] * dD_s["Dref"] /
                                   dD_s["Lref"])
            elif plot_type in ["elytedivi", "elytedivif"]:
                ylbl = r'Divergence of electrolyte current density [A/m$^3$]'
                datax = cellsvec
                datay = np.diff(i_edges, axis=1) / disc["dxd2"]
                datay *= (F * dD_s["cref"] * dD_s["Dref"] / dD_s["Lref"]**2)
        if fplot:
            datay = datay[t0ind]
        if data_only:
            return datax, datay, L_a, L_s
        dataMin, dataMax = np.min(datay), np.max(datay)
        dataRange = dataMax - dataMin
        ymin = max(0, dataMin - 0.05 * dataRange)
        ymax = dataMax + 0.05 * dataRange
        fig, ax = plt.subplots(figsize=figsize)
        ax.set_xlabel('Battery Position [{unit}]'.format(unit=Lunit))
        ax.set_ylabel(ylbl)
        ttl = ax.text(0.5,
                      1.05,
                      ttl_fmt.format(perc=0),
                      transform=ax.transAxes,
                      verticalalignment="center",
                      horizontalalignment="center")
        ax.set_ylim((ymin, ymax))
        ax.set_xlim((xmin, xmax))
        # returns a tuple of line objects, hence the comma
        if fplot:
            line1, = ax.plot(datax, datay, '-')
        else:
            line1, = ax.plot(datax, datay[t0ind, :], '-')
        ax.axvline(x=L_a, linestyle='--', color='g')
        ax.axvline(x=(L_a + L_s), linestyle='--', color='g')
        if fplot:
            print("time =", times[t0ind] * td, "s")
            if save_flag:
                fig.savefig("mpet_{pt}.png".format(pt=plot_type),
                            bbox_inches="tight")
            return fig, ax

        def init():
            line1.set_ydata(np.ma.array(datax, mask=True))
            ttl.set_text('')
            return line1, ttl

        def animate(tind):
            line1.set_ydata(datay[tind])
            t_current = times[tind]
            tfrac = (t_current - tmin) / (tmax - tmin) * 100
            ttl.set_text(ttl_fmt.format(perc=tfrac))
            return line1, ttl

    # Plot solid particle-average concentrations
    elif plot_type[:-2] in ["cbarLine", "dcbardtLine"]:
        trode = plot_type[-1]
        fig, ax = plt.subplots(Npart[trode],
                               Nvol[trode],
                               squeeze=False,
                               sharey=True,
                               figsize=figsize)
        partStr = "partTrode{trode}vol{{vInd}}part{{pInd}}".format(
            trode=trode) + sStr
        type2c = False
        if ndD_e[trode]["type"] in ndD_s["1varTypes"]:
            if plot_type[:-2] in ["cbarLine"]:
                str_base = pfx + partStr + "cbar"
            elif plot_type[:-2] in ["dcbardtLine"]:
                str_base = pfx + partStr + "dcbardt"
        elif ndD_e[trode]["type"] in ndD_s["2varTypes"]:
            type2c = True
            if plot_type[:-2] in ["cbarLine"]:
                str1_base = pfx + partStr + "c1bar"
                str2_base = pfx + partStr + "c2bar"
            elif plot_type[:-2] in ["dcbardtLine"]:
                str1_base = pfx + partStr + "dc1bardt"
                str2_base = pfx + partStr + "dc2bardt"
        ylim = (0, 1.01)
        datax = times * td
        if data_only:
            plt.close(fig)
            if type2c:
                sol1_str = str1_base.format(pInd=pOut, vInd=vOut)
                sol2_str = str2_base.format(pInd=pOut, vInd=vOut)
                datay1 = data[sol1_str][0]
                datay2 = data[sol2_str][0]
                datay = (datay1, datay2)
            else:
                sol_str = str_base.format(pInd=pOut, vInd=vOut)
                datay = data[sol_str][0]
            return datax, datay
        xLblNCutoff = 4
        xLbl = "Time [s]"
        yLbl = "Particle Average Filling Fraction"
        for pInd in range(Npart[trode]):
            for vInd in range(Nvol[trode]):
                if type2c:
                    sol1_str = str1_base.format(pInd=pInd, vInd=vInd)
                    sol2_str = str2_base.format(pInd=pInd, vInd=vInd)
                    if Nvol[trode] > xLblNCutoff:
                        # Remove axis ticks
                        ax[pInd,
                           vInd].xaxis.set_major_locator(plt.NullLocator())
                    else:
                        ax[pInd, vInd].set_xlabel(xLbl)
                        ax[pInd, vInd].set_ylabel(yLbl)
                    datay1 = data[sol1_str][0]
                    datay2 = data[sol2_str][0]
                    line1, = ax[pInd, vInd].plot(times, datay1)
                    line2, = ax[pInd, vInd].plot(times, datay2)
                else:
                    sol_str = str_base.format(pInd=pInd, vInd=vInd)
                    if Nvol[trode] > xLblNCutoff:
                        # Remove axis ticks
                        ax[pInd,
                           vInd].xaxis.set_major_locator(plt.NullLocator())
                    else:
                        ax[pInd, vInd].set_xlabel(xLbl)
                        ax[pInd, vInd].set_ylabel(yLbl)
                    datay = data[sol_str][0]
                    line, = ax[pInd, vInd].plot(times, datay)
        return fig, ax

    # Plot all solid concentrations or potentials
    elif plot_type[:-2] in ["csld", "musld"]:
        timettl = False  # Plot the current simulation time as title
        # Plot title in seconds
        ttlscl, ttlunit = 1, "s"
        # For example, to plot title in hours:
        # ttlscl, ttlunit = 1./3600, "hr"
        save_shot = False
        if save_shot:
            t0ind = 300
            print("Time at screenshot: {ts} s".format(ts=times[t0ind] * td))
        else:
            t0ind = 0
        trode = plot_type[-1]
        if plot_type[0] == "c":
            plt_cavg = True
        else:
            plt_cavg = False
        plt_legend = True
        plt_axlabels = True
        if ndD_e[trode]["type"] in ndD_s["1varTypes"]:
            type2c = False
        elif ndD_e[trode]["type"] in ndD_s["2varTypes"]:
            type2c = True
        Nv, Np = Nvol[trode], Npart[trode]
        partStr = "partTrode{trode}vol{vInd}part{pInd}" + sStr
        fig, ax = plt.subplots(Np,
                               Nv,
                               squeeze=False,
                               sharey=True,
                               figsize=figsize)
        if not type2c:
            cstr_base = pfx + partStr + "c"
            cbarstr_base = pfx + partStr + "cbar"
            cstr = np.empty((Np, Nv), dtype=object)
            cbarstr = np.empty((Np, Nv), dtype=object)
            lines = np.empty((Np, Nv), dtype=object)
        elif type2c:
            c1str_base = pfx + partStr + "c1"
            c2str_base = pfx + partStr + "c2"
            c1barstr_base = pfx + partStr + "c1bar"
            c2barstr_base = pfx + partStr + "c2bar"
            c1str = np.empty((Np, Nv), dtype=object)
            c2str = np.empty((Np, Nv), dtype=object)
            c1barstr = np.empty((Np, Nv), dtype=object)
            c2barstr = np.empty((Np, Nv), dtype=object)
            lines1 = np.empty((Np, Nv), dtype=object)
            lines2 = np.empty((Np, Nv), dtype=object)
            lines3 = np.empty((Np, Nv), dtype=object)
        lens = np.zeros((Np, Nv))
        if data_only:
            print("tInd_{}".format(tOut), "time =", times[tOut] * td, "s")
            lenval = psd_len[trode][vOut, pOut]
            if type2c:
                c1str = c1str_base.format(trode=trode, pInd=pOut, vInd=vOut)
                c2str = c2str_base.format(trode=trode, pInd=pOut, vInd=vOut)
                c1barstr = c1barstr_base.format(trode=trode,
                                                pInd=pOut,
                                                vInd=vOut)
                c2barstr = c2barstr_base.format(trode=trode,
                                                pInd=pOut,
                                                vInd=vOut)
                datay1 = data[c1str[pOut, vOut]][tOut]
                datay2 = data[c2str[pOut, vOut]][tOut]
                if plot_type[:-2] in ["musld"]:
                    c1bar = data[c1barstr[pOut, vOut]][0][tOut]
                    c2bar = data[c2barstr[pOut, vOut]][0][tOut]
                    muRfunc = props_am.muRfuncs(
                        ndD_s["T"], ndD_e[trode]["indvPart"][vOut,
                                                             pOut]).muRfunc
                    datay1, datay2 = muRfunc((datay1, datay2), (c1bar, c2bar),
                                             ndD_e[trode]["muR_ref"])[0]
                datay = (datay1, datay2)
                numy = len(datay1)
            else:
                cstr = cstr_base.format(trode=trode, pInd=pOut, vInd=vOut)
                cbarstr = cbarstr_base.format(trode=trode,
                                              pInd=pOut,
                                              vInd=vOut)
                datay = data[cstr][tOut]
                if plot_type[:-2] in ["musld"]:
                    cbar = data[cbarstr[pOut, vOut]][0][tOut]
                    muRfunc = props_am.muRfuncs(
                        ndD_s["T"], ndD_e[trode]["indvPart"][vOut,
                                                             pOut]).muRfunc
                    datay = muRfunc(datay, cbar, ndD_e[trode]["muR_ref"])[0]
                numy = len(datay)
            datax = np.linspace(0, lenval * Lfac, numy)
            plt.close(fig)
            return datax, datay
        if plot_type[:-2] in ["csld"]:
            ylim = (0, 1.01)
        elif plot_type[:-2] in ["musld"]:
            ylim = (-4, 4)
        for pInd in range(Np):
            for vInd in range(Nv):
                lens[pInd, vInd] = psd_len[trode][vInd, pInd]
                if type2c:
                    c1str[pInd, vInd] = c1str_base.format(trode=trode,
                                                          pInd=pInd,
                                                          vInd=vInd)
                    c2str[pInd, vInd] = c2str_base.format(trode=trode,
                                                          pInd=pInd,
                                                          vInd=vInd)
                    c1barstr[pInd, vInd] = c1barstr_base.format(trode=trode,
                                                                pInd=pInd,
                                                                vInd=vInd)
                    c2barstr[pInd, vInd] = c2barstr_base.format(trode=trode,
                                                                pInd=pInd,
                                                                vInd=vInd)
                    datay1 = data[c1str[pInd, vInd]][t0ind]
                    datay2 = data[c2str[pInd, vInd]][t0ind]
                    datay3 = 0.5 * (datay1 + datay2)
                    lbl1, lbl2 = r"$\widetilde{c}_1$", r"$\widetilde{c}_2$"
                    lbl3 = r"$\overline{c}$"
                    if plot_type[:-2] in ["musld"]:
                        lbl1, lbl2 = r"$\mu_1/k_\mathrm{B}T$", r"$\mu_2/k_\mathrm{B}T$"
                        c1bar = data[c1barstr[pInd, vInd]][0][t0ind]
                        c2bar = data[c2barstr[pInd, vInd]][0][t0ind]
                        muRfunc = props_am.muRfuncs(
                            ndD_s["T"], ndD_e[trode]["indvPart"][vInd,
                                                                 pInd]).muRfunc
                        datay1, datay2 = muRfunc((datay1, datay2),
                                                 (c1bar, c2bar),
                                                 ndD_e[trode]["muR_ref"])[0]
                    numy = len(datay1)
                    datax = np.linspace(0, lens[pInd, vInd] * Lfac, numy)
                    line1, = ax[pInd, vInd].plot(datax, datay1, label=lbl1)
                    line2, = ax[pInd, vInd].plot(datax, datay2, label=lbl2)
                    if plt_cavg:
                        line3, = ax[pInd, vInd].plot(datax,
                                                     datay3,
                                                     '--',
                                                     label=lbl3)
                        lines3[pInd, vInd] = line3
                    lines1[pInd, vInd] = line1
                    lines2[pInd, vInd] = line2
                else:
                    cstr[pInd, vInd] = cstr_base.format(trode=trode,
                                                        pInd=pInd,
                                                        vInd=vInd)
                    cbarstr[pInd, vInd] = cbarstr_base.format(trode=trode,
                                                              pInd=pInd,
                                                              vInd=vInd)
                    datay = data[cstr[pInd, vInd]][t0ind]
                    if plot_type[:-2] in ["musld"]:
                        cbar = np.array(data[cbarstr[pInd, vInd]][0][t0ind])
                        muRfunc = props_am.muRfuncs(
                            ndD_s["T"], ndD_e[trode]["indvPart"][vInd,
                                                                 pInd]).muRfunc
                        datay = muRfunc(datay, cbar,
                                        ndD_e[trode]["muR_ref"])[0]
                    numy = len(datay)
                    datax = np.linspace(0, lens[pInd, vInd] * Lfac, numy)
                    line, = ax[pInd, vInd].plot(datax, datay)
                    lines[pInd, vInd] = line
                ax[pInd, vInd].set_ylim(ylim)
                ax[pInd, vInd].set_xlim((0, lens[pInd, vInd] * Lfac))
                if plt_legend:
                    ax[pInd, vInd].legend(loc="best")
                if plt_axlabels:
                    ax[pInd,
                       vInd].set_xlabel(r"$r$ [{Lunit}]".format(Lunit=Lunit))
                    if plot_type[0] == "c":
                        ax[pInd, vInd].set_ylabel(r"$\widetilde{{c}}$")
                    elif plot_type[:2] == "mu":
                        ax[pInd, vInd].set_ylabel(r"$\mu/k_\mathrm{B}T$")
                if timettl:
                    mpl.animation.Animation._blit_draw = _blit_draw
                    ttl = ax[pInd,
                             vInd].text(0.5,
                                        1.04,
                                        "t = {tval:3.3f} {ttlu}".format(
                                            tval=times[t0ind] * td * ttlscl,
                                            ttlu=ttlunit),
                                        verticalalignment="center",
                                        horizontalalignment="center",
                                        transform=ax[pInd, vInd].transAxes)
        if save_shot:
            fig.savefig("mpet_{pt}.pdf".format(pt=plot_type),
                        bbox_inches="tight")

        def init():
            toblit = []
            for pInd in range(Npart[trode]):
                for vInd in range(Nvol[trode]):
                    if type2c:
                        numy = len(data[c1str[pInd, vInd]][t0ind])
                        maskTmp = np.zeros(numy)
                        lines1[pInd,
                               vInd].set_ydata(np.ma.array(maskTmp, mask=True))
                        lines2[pInd,
                               vInd].set_ydata(np.ma.array(maskTmp, mask=True))
                        lines_local = np.vstack((lines1, lines2))
                        if plt_cavg:
                            lines3[pInd, vInd].set_ydata(
                                np.ma.array(maskTmp, mask=True))
                            lines_local = np.vstack((lines_local, lines3))
                    else:
                        numy = len(data[cstr[pInd, vInd]][t0ind])
                        maskTmp = np.zeros(numy)
                        lines[pInd,
                              vInd].set_ydata(np.ma.array(maskTmp, mask=True))
                        lines_local = lines.copy()
                    toblit.extend(lines_local.reshape(-1))
                    if timettl:
                        ttl.set_text("")
                        toblit.extend([ttl])
            return tuple(toblit)

        def animate(tind):
            toblit = []
            for pInd in range(Npart[trode]):
                for vInd in range(Nvol[trode]):
                    if type2c:
                        datay1 = data[c1str[pInd, vInd]][tind]
                        datay2 = data[c2str[pInd, vInd]][tind]
                        datay3 = 0.5 * (datay1 + datay2)
                        if plot_type[:-2] in ["musld"]:
                            c1bar = data[c1barstr[pInd, vInd]][0][tind]
                            c2bar = data[c2barstr[pInd, vInd]][0][tind]
                            muRfunc = props_am.muRfuncs(
                                ndD_s["T"],
                                ndD_e[trode]["indvPart"][vInd, pInd]).muRfunc
                            datay1, datay2 = muRfunc(
                                (datay1, datay2), (c1bar, c2bar),
                                ndD_e[trode]["muR_ref"])[0]
                        lines1[pInd, vInd].set_ydata(datay1)
                        lines2[pInd, vInd].set_ydata(datay2)
                        lines_local = np.vstack((lines1, lines2))
                        if plt_cavg:
                            lines3[pInd, vInd].set_ydata(datay3)
                            lines_local = np.vstack((lines_local, lines3))
                    else:
                        datay = data[cstr[pInd, vInd]][tind]
                        if plot_type[:-2] in ["musld"]:
                            cbar = data[cbarstr[pInd, vInd]][0][tind]
                            muRfunc = props_am.muRfuncs(
                                ndD_s["T"],
                                ndD_e[trode]["indvPart"][vInd, pInd]).muRfunc
                            datay = muRfunc(datay, cbar,
                                            ndD_e[trode]["muR_ref"])[0]
                        lines[pInd, vInd].set_ydata(datay)
                        lines_local = lines.copy()
                    toblit.extend(lines_local.reshape(-1))
                    if timettl:
                        ttl.set_text("t = {tval:3.3f} {ttlu}".format(
                            tval=times[tind] * td * ttlscl, ttlu=ttlunit))
                        toblit.extend([ttl])
            return tuple(toblit)

    # Plot average solid concentrations
    elif plot_type in ["cbar_c", "cbar_a", "cbar_full"]:
        if plot_type[-4:] == "full":
            trvec = ["a", "c"]
        elif plot_type[-1] == "a":
            trvec = ["a"]
        else:
            trvec = ["c"]
        dataCbar = {}
        for trode in trodes:
            dataCbar[trode] = np.zeros((numtimes, Nvol[trode], Npart[trode]))
            for tInd in range(numtimes):
                for vInd in range(Nvol[trode]):
                    for pInd in range(Npart[trode]):
                        dataStr = (pfx +
                                   "partTrode{t}vol{vInd}part{pInd}".format(
                                       t=trode, vInd=vInd, pInd=pInd) + sStr +
                                   "cbar")
                        dataCbar[trode][tInd, vInd,
                                        pInd] = (data[dataStr][0][tInd])
        if data_only:
            return dataCbar
        # Set up colors.
        # Define if you want smooth or discrete color changes
        # Option: "smooth" or "discrete"
        color_changes = "discrete"
        #        color_changes = "smooth"
        # Discrete color changes:
        if color_changes == "discrete":
            # Make a discrete colormap that goes from green to yellow
            # to red instantaneously
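            # Each cdict entry is a list of (x, value_left, value_right) tuples;
            # repeating an x position with different values is what produces the
            # sharp colour jumps at to_yellow and to_red.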
            cdict = {
                "red": [(0.0, 0.0, 0.0), (to_yellow, 0.0, 1.0),
                        (1.0, 1.0, 1.0)],
                "green": [(0.0, 0.502, 0.502), (to_yellow, 0.502, 1.0),
                          (to_red, 1.0, 0.0), (1.0, 0.0, 0.0)],
                "blue": [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
            }
            cmap = mpl.colors.LinearSegmentedColormap("discrete", cdict)
        # Smooth colormap changes:
        if color_changes == "smooth":
            # generated with colormap.org
            cmaps = np.load("colormaps_custom.npz")
            cmap_data = cmaps["GnYlRd_3"]
            cmap = mpl.colors.ListedColormap(cmap_data / 255.)

        # Implement hack to be able to animate title
        mpl.animation.Animation._blit_draw = _blit_draw
        size_frac_min = 0.10
        fig, axs = plt.subplots(1, len(trvec), squeeze=False, figsize=figsize)
        ttlx = 0.5 if len(trvec) < 2 else 1.1
        ttl = axs[0, 0].text(ttlx,
                             1.05,
                             ttl_fmt.format(perc=0),
                             transform=axs[0, 0].transAxes,
                             verticalalignment="center",
                             horizontalalignment="center")
        collection = np.empty(len(trvec), dtype=object)
        for indx, trode in enumerate(trvec):
            ax = axs[0, indx]
            # Get particle sizes (and max size) (length-based)
            lens = psd_len[trode]
            len_max = np.max(lens)
            len_min = np.min(lens)
            ax.patch.set_facecolor('white')
            # Don't stretch axes to fit figure -- keep 1:1 x:y ratio.
            ax.set_aspect('equal', 'box')
            # Don't show axis ticks
            ax.xaxis.set_major_locator(plt.NullLocator())
            ax.yaxis.set_major_locator(plt.NullLocator())
            ax.set_xlim(0, 1.)
            ax.set_ylim(0, float(Npart[trode]) / Nvol[trode])
            # Label parts of the figure
            #            ylft = ax.text(-0.07, 0.5, "Separator",
            #                    transform=ax.transAxes, rotation=90,
            #                    verticalalignment="center",
            #                    horizontalalignment="center")
            #            yrht = ax.text(1.09, 0.5, "Current Collector",
            #                    transform=ax.transAxes, rotation=90,
            #                    verticalalignment="center",
            #                    horizontalalignment="center")
            #            xbtm = ax.text(.50, -0.05, "Electrode Depth -->",
            #                    transform=ax.transAxes, rotation=0,
            #                    verticalalignment="center",
            #                    horizontalalignment="center")
            # Geometric parameters for placing the rectangles on the axes
            spacing = 1.0 / Nvol[trode]
            size_fracs = 0.4 * np.ones((Nvol[trode], Npart[trode]))
            if len_max != len_min:
                size_fracs = (lens - len_min) / (len_max - len_min)
            sizes = (size_fracs *
                     (1 - size_frac_min) + size_frac_min) / Nvol[trode]
            # Create rectangle "patches" to add to figure axes.
            rects = np.empty((Nvol[trode], Npart[trode]), dtype=object)
            color = 'green'  # value is irrelevant -- it will be animated
            for (vInd, pInd), c in np.ndenumerate(sizes):
                size = sizes[vInd, pInd]
                center = np.array(
                    [spacing * (vInd + 0.5), spacing * (pInd + 0.5)])
                bottom_left = center - size / 2
                rects[vInd, pInd] = plt.Rectangle(bottom_left,
                                                  size,
                                                  size,
                                                  color=color)
            # Create a group of rectangle "patches" from the rects array
            collection[indx] = mcollect.PatchCollection(rects.reshape(-1))
            # Put them on the axes
            ax.add_collection(collection[indx])
        # Have a "background" image of rectanges representing the
        # initial state of the system.

        def init():
            for indx, trode in enumerate(trvec):
                cbar_mat = dataCbar[trode][0, :, :]
                colors = cmap(cbar_mat.reshape(-1))
                collection[indx].set_color(colors)
                ttl.set_text('')
            out = [collection[i] for i in range(len(collection))]
            out.append(ttl)
            out = tuple(out)
            return out

        def animate(tind):
            for indx, trode in enumerate(trvec):
                cbar_mat = dataCbar[trode][tind, :, :]
                colors = cmap(cbar_mat.reshape(-1))
                collection[indx].set_color(colors)
            t_current = times[tind]
            tfrac = (t_current - tmin) / (tmax - tmin) * 100
            ttl.set_text(ttl_fmt.format(perc=tfrac))
            out = [collection[i] for i in range(len(collection))]
            out.append(ttl)
            out = tuple(out)
            return out

    # Plot cathode potential
    elif plot_type[0:5] in ["bulkp"]:
        trode = plot_type[-1]
        fplot = (True if plot_type[-3] == "f" else False)
        t0ind = (0 if not fplot else -1)
        mpl.animation.Animation._blit_draw = _blit_draw
        fig, ax = plt.subplots(figsize=figsize)
        ax.set_xlabel('Position in electrode [{unit}]'.format(unit=Lunit))
        ax.set_ylabel('Potential of cathode [nondim]')
        ttl = ax.text(0.5,
                      1.05,
                      ttl_fmt.format(perc=0),
                      transform=ax.transAxes,
                      verticalalignment="center",
                      horizontalalignment="center")
        bulkp = pfx + 'phi_bulk_{trode}'.format(trode=trode)
        datay = data[bulkp]
        ymin = np.min(datay) - 0.2
        ymax = np.max(datay) + 0.2
        if trode == "a":
            datax = cellsvec[:Nvol["a"]]
        elif trode == "c":
            datax = cellsvec[-Nvol["c"]:]
        if data_only:
            plt.close(fig)
            return datax, datay[t0ind]
        # returns a tuple of line objects, hence the comma
        line1, = ax.plot(datax, datay[t0ind])

        def init():
            line1.set_ydata(np.ma.array(datax, mask=True))
            ttl.set_text('')
            return line1, ttl

        def animate(tind):
            line1.set_ydata(datay[tind])
            t_current = times[tind]
            tfrac = (t_current - tmin) / (tmax - tmin) * 100
            ttl.set_text(ttl_fmt.format(perc=tfrac))
            return line1, ttl

    else:
        raise Exception("Unexpected plot type argument. See README.md.")

    ani = manim.FuncAnimation(fig,
                              animate,
                              frames=numtimes,
                              interval=50,
                              blit=True,
                              repeat=False,
                              init_func=init)
    if save_flag:
        fig.tight_layout()
        ani.save("mpet_{type}.mp4".format(type=plot_type),
                 fps=25,
                 bitrate=5500)

    return fig, ax, ani
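

# A minimal usage sketch for show_data with an assumed simulation directory
# ("sim_output" must contain the output_data.mat and input_dict_* files written
# by MPET). With data_only=True, the "v" plot type returns the raw filling
# fraction and voltage arrays instead of a figure.
ff_demo, volt_demo = show_data("sim_output", plot_type="v", print_flag=False,
                               save_flag=False, data_only=True)
plt.plot(ff_demo, volt_demo)
plt.xlabel("Cathode filling fraction")
plt.ylabel("Voltage [V]")
plt.show()
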
Exemple #14
0
def centroider(target,
               sources,
               output_plots=False,
               gif=False,
               restore=False,
               box_w=8):
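    """Measure centroids for the target and its reference stars in every reduced
    image of a PINES observation (summary inferred from the code below; this
    docstring is not original documentation).

    sources is expected to provide 'Name', 'Source Detect X' and 'Source Detect Y'
    columns; box_w is the half-width in pixels of the cutout used for centroiding;
    restore=True simply re-reads a previously saved centroid CSV; output_plots
    saves per-source diagnostic plots. Builds a DataFrame with per-image centroid
    positions, seeing, time (JD UTC) and airmass.
    """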
    matplotlib.use('TkAgg')
    plt.ioff()
    t1 = time.time()
    pines_path = pines_dir_check()
    short_name = short_name_creator(target)

    kernel = Gaussian2DKernel(x_stddev=1)  #For fixing nans in cutouts.

    #If restore == True, read in existing output and return.
    if restore:
        centroid_df = pd.read_csv(
            pines_path / ('Objects/' + short_name +
                          '/sources/target_and_references_centroids.csv'),
            converters={
                'X Centroids': eval,
                'Y Centroids': eval
            })
        print('Restoring centroider output from {}.'.format(
            pines_path / ('Objects/' + short_name +
                          '/sources/target_and_references_centroids.csv')))
        print('')
        return centroid_df

    #Create subdirectories in sources folder to contain output plots.
    if output_plots:
        subdirs = glob(
            str(pines_path / ('Objects/' + short_name + '/sources')) + '/*/')
        #Delete any source directories that are already there.
        for name in subdirs:
            shutil.rmtree(name)

        #Create new source directories.
        for name in sources['Name']:
            source_path = (
                pines_path /
                ('Objects/' + short_name + '/sources/' + name + '/'))
            os.mkdir(source_path)

    #Read in extra shifts, in case the master image wasn't used for source detection.
    extra_shift_path = pines_path / ('Objects/' + short_name +
                                     '/sources/extra_shifts.txt')
    extra_shifts = pd.read_csv(extra_shift_path,
                               delimiter=' ',
                               names=['Extra X shift', 'Extra Y shift'])
    extra_x_shift = extra_shifts['Extra X shift'][0]
    extra_y_shift = extra_shifts['Extra Y shift'][0]

    np.seterr(
        divide='ignore', invalid='ignore'
    )  #Suppress some warnings we don't care about in median combining.

    #Get list of reduced files for target.
    reduced_path = pines_path / ('Objects/' + short_name + '/reduced')
    reduced_filenames = natsort.natsorted(
        [x.name for x in reduced_path.glob('*red.fits')])
    reduced_files = np.array([reduced_path / i for i in reduced_filenames])

    #Declare a new dataframe to hold the centroid information for all sources we want to track.
    columns = []
    columns.append('Filename')
    columns.append('Seeing')
    columns.append('Time (JD UTC)')
    columns.append('Airmass')

    #Add x/y positions and centroid flags for every tracked source
    for i in range(0, len(sources)):
        columns.append(sources['Name'][i] + ' Image X')
        columns.append(sources['Name'][i] + ' Image Y')
        columns.append(sources['Name'][i] + ' Cutout X')
        columns.append(sources['Name'][i] + ' Cutout Y')
        columns.append(sources['Name'][i] + ' Centroid Warning')

    centroid_df = pd.DataFrame(index=range(len(reduced_files)),
                               columns=columns)

    log_path = pines_path / ('Logs/')
    log_dates = np.array(
        natsort.natsorted(
            [x.name.split('_')[0] for x in log_path.glob('*.txt')]))

    #Make sure we have logs for all the nights of these data. Need them to account for image shifts.
    nights = list(set([i.name.split('.')[0] for i in reduced_files]))
    for i in nights:
        if i not in log_dates:
            print('ERROR: {} not in {}. Download it from the PINES server.'.
                  format(i + '_log.txt', log_path))
            pdb.set_trace()

    shift_tolerance = 2.0  #Number of pixels that the measured centroid can be away from the expected position in either x or y before trying other centroiding algorithms.
    for i in range(len(sources)):
        #Get the initial source position.
        x_pos = sources['Source Detect X'][i]
        y_pos = sources['Source Detect Y'][i]
        print('')
        print(
            'Getting centroids for {}, ({:3.1f}, {:3.1f}) in source detection image. Source {} of {}.'
            .format(sources['Name'][i], x_pos, y_pos, i + 1, len(sources)))
        if output_plots:
            print('Saving centroid plots to {}.'.format(
                pines_path / ('Objects/' + short_name + '/sources/' +
                              sources['Name'][i] + '/')))
        pbar = ProgressBar()
        for j in pbar(range(len(reduced_files))):
            centroid_df[sources['Name'][i] + ' Centroid Warning'][j] = 0
            file = reduced_files[j]
            image = fits.open(file)[0].data
            #Get the measured image shift for this image.
            log = pines_log_reader(log_path /
                                   (file.name.split('.')[0] + '_log.txt'))
            log_ind = np.where(log['Filename'] == file.name.split('_')[0] +
                               '.fits')[0][0]

            x_shift = float(log['X shift'][log_ind])
            y_shift = float(log['Y shift'][log_ind])

            #Save the filename for readability. Save the seeing for use in variable aperture photometry. Save the time for diagnostic plots.
            if i == 0:
                centroid_df['Filename'][j] = file.name.split('_')[0] + '.fits'
                centroid_df['Seeing'][j] = log['X seeing'][log_ind]
                time_str = fits.open(file)[0].header['DATE-OBS']

                #Correct some formatting issues that can occur in Mimir time stamps.
                if time_str.split(':')[-1] == '60.00':
                    time_str = time_str[0:14] + str(
                        int(time_str.split(':')[-2]) + 1) + ':00.00'
                elif time_str.split(':')[-1] == '010.00':
                    time_str = time_str[0:17] + time_str.split(':')[-1][1:]

                centroid_df['Time (JD UTC)'][j] = julian.to_jd(
                    datetime.datetime.strptime(time_str,
                                               '%Y-%m-%dT%H:%M:%S.%f'))
                centroid_df['Airmass'][j] = log['Airmass'][log_ind]

            nan_flag = False  #Flag indicating if you should not trust the log's shifts. Set to true if x_shift/y_shift are 'nan' or > 30 pixels.

            #If bad shifts were measured for this image, skip.
            if log['Shift quality flag'][log_ind] == 1:
                continue

            if np.isnan(x_shift) or np.isnan(y_shift):
                x_shift = 0
                y_shift = 0
                nan_flag = True

            #If there are clouds, shifts could have been erroneously high...just zero them?
            if abs(x_shift) > 200:
                #x_shift = 0
                nan_flag = True
            if abs(y_shift) > 200:
                #y_shift = 0
                nan_flag = True

            #Apply the shift. NOTE: This relies on having accurate x_shift and y_shift values from the log.
            #If they're incorrect, the cutout will not be in the right place.
            #x_pos = sources['Source Detect X'][i] - x_shift + extra_x_shift
            #y_pos = sources['Source Detect Y'][i] + y_shift - extra_y_shift

            x_pos = sources['Source Detect X'][i] - (x_shift - extra_x_shift)
            y_pos = sources['Source Detect Y'][i] + (y_shift - extra_y_shift)

            #TODO: Make all this its own function.

            #Cutout around the expected position and interpolate over any NaNs (which screw up source detection).
            cutout = interpolate_replace_nans(
                image[int(y_pos - box_w):int(y_pos + box_w) + 1,
                      int(x_pos - box_w):int(x_pos + box_w) + 1],
                kernel=Gaussian2DKernel(x_stddev=0.5))

            #interpolate_replace_nans struggles with edge pixels, so shave off edge_shave pixels in each direction of the cutout.
            edge_shave = 1
            cutout = cutout[edge_shave:len(cutout) - edge_shave,
                            edge_shave:len(cutout) - edge_shave]

            vals, lower, upper = sigmaclip(
                cutout, low=1.5,
                high=2.5)  #Get sigma clipped stats on the cutout
            med = np.nanmedian(vals)
            std = np.nanstd(vals)

            try:
                centroid_x_cutout, centroid_y_cutout = centroid_2dg(
                    cutout - med)  #Perform centroid detection on the cutout.
            except:
                pdb.set_trace()

            centroid_x = centroid_x_cutout + int(
                x_pos
            ) - box_w + edge_shave  #Translate the detected centroid from the cutout coordinates back to the full-frame coordinates.
            centroid_y = centroid_y_cutout + int(y_pos) - box_w + edge_shave

            # if i == 0:
            #     qp(cutout)
            #     plt.plot(centroid_x_cutout, centroid_y_cutout, 'rx')

            #     # qp(image)
            #     # plt.plot(centroid_x, centroid_y, 'rx')
            #     pdb.set_trace()

            #If the shifts in the log are not 'nan' or > 200 pixels, check if the measured shifts are within shift_tolerance pixels of the expected position.
            #   If they aren't, try alternate centroiding methods to try and find it.

            #Otherwise, use the shifts as measured with centroid_1dg. PINES_watchdog likely failed while observing, and we don't expect the centroids measured here to actually be at the expected position.
            if not nan_flag:
                #Try a 2D Gaussian detection.
                if (abs(centroid_x - x_pos) > shift_tolerance) or (
                        abs(centroid_y - y_pos) > shift_tolerance):
                    centroid_x_cutout, centroid_y_cutout = centroid_2dg(
                        cutout - med)
                    centroid_x = centroid_x_cutout + int(x_pos) - box_w
                    centroid_y = centroid_y_cutout + int(y_pos) - box_w

                    #If that fails, try a COM detection.
                    if (abs(centroid_x - x_pos) > shift_tolerance) or (
                            abs(centroid_y - y_pos) > shift_tolerance):
                        centroid_x_cutout, centroid_y_cutout = centroid_com(
                            cutout - med)
                        centroid_x = centroid_x_cutout + int(x_pos) - box_w
                        centroid_y = centroid_y_cutout + int(y_pos) - box_w

                        #If that fails, try masking source and interpolate over any bad pixels that aren't in the bad pixel mask, then redo 1D gaussian detection.
                        if (abs(centroid_x - x_pos) > shift_tolerance) or (
                                abs(centroid_y - y_pos) > shift_tolerance):
                            mask = make_source_mask(cutout,
                                                    nsigma=4,
                                                    npixels=5,
                                                    dilate_size=3)
                            vals, lo, hi = sigmaclip(cutout[~mask])
                            bad_locs = np.where((mask == False) & (
                                (cutout > hi) | (cutout < lo)))
                            cutout[bad_locs] = np.nan
                            cutout = interpolate_replace_nans(
                                cutout, kernel=Gaussian2DKernel(x_stddev=0.5))

                            centroid_x_cutout, centroid_y_cutout = centroid_1dg(
                                cutout - med)
                            centroid_x = centroid_x_cutout + int(x_pos) - box_w
                            centroid_y = centroid_y_cutout + int(y_pos) - box_w

                            #Try a 2D Gaussian detection on the interpolated cutout
                            if (abs(centroid_x - x_pos) > shift_tolerance) or (
                                    abs(centroid_y - y_pos) > shift_tolerance):
                                centroid_x_cutout, centroid_y_cutout = centroid_2dg(
                                    cutout - med)
                                centroid_x = centroid_x_cutout + int(
                                    x_pos) - box_w
                                centroid_y = centroid_y_cutout + int(
                                    y_pos) - box_w

                                #Try a COM on the interpolated cutout.
                                if (abs(centroid_x - x_pos) > shift_tolerance
                                    ) or (abs(centroid_y - y_pos) >
                                          shift_tolerance):
                                    centroid_x_cutout, centroid_y_cutout = centroid_com(
                                        cutout)
                                    centroid_x = centroid_x_cutout + int(
                                        x_pos) - box_w
                                    centroid_y = centroid_y_cutout + int(
                                        y_pos) - box_w

                                    #Last resort: try cutting off the edge of the cutout. Edge pixels can experience poor interpolation, and this sometimes helps.
                                    if (abs(centroid_x - x_pos) >
                                            shift_tolerance) or (
                                                abs(centroid_y - y_pos) >
                                                shift_tolerance):
                                        cutout = cutout[1:-1, 1:-1]
                                        centroid_x_cutout, centroid_y_cutout = centroid_1dg(
                                            cutout - med)
                                        centroid_x = centroid_x_cutout + int(
                                            x_pos) - box_w + 1
                                        centroid_y = centroid_y_cutout + int(
                                            y_pos) - box_w + 1

                                        #Try with a 2DG
                                        if (abs(centroid_x - x_pos) >
                                                shift_tolerance) or (
                                                    abs(centroid_y - y_pos) >
                                                    shift_tolerance):
                                            centroid_x_cutout, centroid_y_cutout = centroid_2dg(
                                                cutout - med)
                                            centroid_x = centroid_x_cutout + int(
                                                x_pos) - box_w + 1
                                            centroid_y = centroid_y_cutout + int(
                                                y_pos) - box_w + 1

                                            #If ALL that fails, report the expected position as the centroid.
                                            if (abs(centroid_x - x_pos) >
                                                    shift_tolerance) or (
                                                        abs(centroid_y - y_pos)
                                                        > shift_tolerance):
                                                print(
                                                    'WARNING: large centroid deviation measured, returning predicted position'
                                                )
                                                print('')
                                                centroid_df[
                                                    sources['Name'][i] +
                                                    ' Centroid Warning'][j] = 1
                                                centroid_x = x_pos
                                                centroid_y = y_pos
                                                #pdb.set_trace()

            #Check that your measured position is actually on the detector.
            if (centroid_x < 0) or (centroid_y < 0) or (centroid_x > 1023) or (
                    centroid_y > 1023):
                #Try a quick mask/interpolation of the cutout.
                mask = make_source_mask(cutout,
                                        nsigma=3,
                                        npixels=5,
                                        dilate_size=3)
                vals, lo, hi = sigmaclip(cutout[~mask])
                bad_locs = np.where((mask == False)
                                    & ((cutout > hi) | (cutout < lo)))
                cutout[bad_locs] = np.nan
                cutout = interpolate_replace_nans(
                    cutout, kernel=Gaussian2DKernel(x_stddev=0.5))
                centroid_x, centroid_y = centroid_2dg(cutout - med)
                centroid_x += int(x_pos) - box_w
                centroid_y += int(y_pos) - box_w
                if (centroid_x < 0) or (centroid_y < 0) or (
                        centroid_x > 1023) or (centroid_y > 1023):
                    print(
                        'WARNING: large centroid deviation measured, returning predicted position'
                    )
                    print('')
                    centroid_df[sources['Name'][i] +
                                ' Centroid Warning'][j] = 1
                    centroid_x = x_pos
                    centroid_y = y_pos
                    #pdb.set_trace()

            #Check to make sure you didn't measure nan's.
            if np.isnan(centroid_x):
                centroid_x = x_pos
                print(
                    'NaN returned from centroid algorithm, defaulting to target position in source_detect_image.'
                )
            if np.isnan(centroid_y):
                centroid_y = y_pos
                print(
                    'NaN returned from centroid algorithm, defaulting to target position in source_detect_image.'
                )

            #Record the image and relative cutout positions.
            centroid_df[sources['Name'][i] + ' Image X'][j] = centroid_x
            centroid_df[sources['Name'][i] + ' Image Y'][j] = centroid_y
            centroid_df[sources['Name'][i] +
                        ' Cutout X'][j] = centroid_x_cutout
            centroid_df[sources['Name'][i] +
                        ' Cutout Y'][j] = centroid_y_cutout

            if output_plots:
                #Plot
                lock_x = int(centroid_df[sources['Name'][i] + ' Image X'][0])
                lock_y = int(centroid_df[sources['Name'][i] + ' Image Y'][0])
                norm = ImageNormalize(data=cutout, interval=ZScaleInterval())
                plt.imshow(image, origin='lower', norm=norm)
                plt.plot(centroid_x, centroid_y, 'rx')
                ap = CircularAperture((centroid_x, centroid_y), r=5)
                ap.plot(lw=2, color='b')
                plt.ylim(lock_y - 30, lock_y + 30 - 1)
                plt.xlim(lock_x - 30, lock_x + 30 - 1)
                plt.title('CENTROID DIAGNOSTIC PLOT\n' + sources['Name'][i] +
                          ', ' + reduced_files[j].name + ' (image ' +
                          str(j + 1) + ' of ' + str(len(reduced_files)) + ')',
                          fontsize=10)
                plt.text(centroid_x,
                         centroid_y + 0.5,
                         '(' + str(np.round(centroid_x, 1)) + ', ' +
                         str(np.round(centroid_y, 1)) + ')',
                         color='r',
                         ha='center')
                plot_output_path = (
                    pines_path /
                    ('Objects/' + short_name + '/sources/' +
                     sources['Name'][i] + '/' + str(j).zfill(4) + '.jpg'))
                plt.gca().set_axis_off()
                plt.subplots_adjust(top=1,
                                    bottom=0,
                                    right=1,
                                    left=0,
                                    hspace=0,
                                    wspace=0)
                plt.margins(0, 0)
                plt.gca().xaxis.set_major_locator(plt.NullLocator())
                plt.gca().yaxis.set_major_locator(plt.NullLocator())
                plt.savefig(plot_output_path,
                            bbox_inches='tight',
                            pad_inches=0,
                            dpi=150)
                plt.close()

        if gif:
            gif_path = (pines_path / ('Objects/' + short_name + '/sources/' +
                                      sources['Name'][i] + '/'))
            gif_maker(path=gif_path, fps=10)

    output_filename = pines_path / (
        'Objects/' + short_name +
        '/sources/target_and_references_centroids.csv')
    #centroid_df.to_csv(pines_path/('Objects/'+short_name+'/sources/target_and_references_centroids.csv'))

    print('Saving centroiding output to {}.'.format(output_filename))
    with open(output_filename, 'w') as f:
        for j in range(len(centroid_df)):
            #Write the header line.
            if j == 0:
                f.write('{:<17s}, '.format('Filename'))
                f.write('{:<15s}, '.format('Time (JD UTC)'))
                f.write('{:<6s}, '.format('Seeing'))
                f.write('{:<7s}, '.format('Airmass'))
                for i in range(len(sources['Name'])):
                    n = sources['Name'][i]
                    if i != len(sources['Name']) - 1:
                        f.write(
                            '{:<23s}, {:<23s}, {:<24s}, {:<24s}, {:<34s}, '.
                            format(n + ' Image X', n + ' Image Y',
                                   n + ' Cutout X', n + ' Cutout Y',
                                   n + ' Centroid Warning'))
                    else:
                        f.write(
                            '{:<23s}, {:<23s}, {:<24s}, {:<24s}, {:<34s}\n'.
                            format(n + ' Image X', n + ' Image Y',
                                   n + ' Cutout X', n + ' Cutout Y',
                                   n + ' Centroid Warning'))

            #Write in the data lines.
            try:
                f.write('{:<17s}, '.format(centroid_df['Filename'][j]))
                f.write('{:<15.7f}, '.format(centroid_df['Time (JD UTC)'][j]))
                f.write('{:<6.1f}, '.format(float(centroid_df['Seeing'][j])))
                f.write('{:<7.2f}, '.format(centroid_df['Airmass'][j]))
            except:
                pdb.set_trace()

            for i in range(len(sources['Name'])):
                n = sources['Name'][i]
                if i != len(sources['Name']) - 1:
                    format_string = '{:<23.4f}, {:<23.4f}, {:<24.4f}, {:<24.4f}, {:<34d}, '
                else:
                    format_string = '{:<23.4f}, {:<23.4f}, {:<24.4f}, {:<24.4f}, {:<34d}\n'

                f.write(
                    format_string.format(
                        centroid_df[n + ' Image X'][j],
                        centroid_df[n + ' Image Y'][j],
                        centroid_df[n + ' Cutout X'][j],
                        centroid_df[n + ' Cutout Y'][j],
                        centroid_df[n + ' Centroid Warning'][j]))
    np.seterr(divide='warn', invalid='warn')
    print('')
    print('centroider runtime: {:.2f} minutes.'.format(
        (time.time() - t1) / 60))
    print('')
    return centroid_df
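The nested fallback chain above repeats the same translate-and-compare step for every centroiding method, which is what the TODO about factoring it out refers to. A minimal sketch of how that cascade could look as its own function is below; it assumes photutils' centroid_1dg / centroid_2dg / centroid_com and the same cutout / shift_tolerance conventions, and is not part of the pipeline itself.

# Hypothetical refactor sketch: try several centroiders in turn and fall back
# to the predicted position if none lands within tolerance.
from photutils.centroids import centroid_1dg, centroid_2dg, centroid_com

def cascade_centroid(cutout, med, x_pos, y_pos, box_w, shift_tolerance):
    for centroider in (centroid_2dg, centroid_com, centroid_1dg):
        cx_cut, cy_cut = centroider(cutout - med)
        cx = cx_cut + int(x_pos) - box_w   # cutout -> full-frame coordinates
        cy = cy_cut + int(y_pos) - box_w
        if abs(cx - x_pos) <= shift_tolerance and abs(cy - y_pos) <= shift_tolerance:
            return cx, cy, False           # converged, no warning flag
    return x_pos, y_pos, True              # give up: report the expected position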
Exemple #15
0
def plot_test_query(state,
                    better_action,
                    worse_action,
                    feature_mat,
                    equal_pref=False):

    plt.figure()
    ax = plt.axes()
    count = 0
    rows, cols = len(feature_mat), len(feature_mat[0])
    if better_action == "^":
        plot_arrow(state, cols, ax, "up")
    elif better_action == "v":
        plot_arrow(state, cols, ax, "down")
    elif better_action == ">":
        plot_arrow(state, cols, ax, "right")
    elif better_action == "<":
        plot_arrow(state, cols, ax, "left")

    if equal_pref:
        if worse_action == "^":
            plot_arrow(state, cols, ax, "up")
        elif worse_action == "v":
            plot_arrow(state, cols, ax, "down")
        elif worse_action == ">":
            plot_arrow(state, cols, ax, "right")
        elif worse_action == "<":
            plot_arrow(state, cols, ax, "left")

    else:

        if worse_action == "^":
            plot_dashed_arrow(state, cols, ax, "up")
        elif worse_action == "v":
            plot_dashed_arrow(state, cols, ax, "down")
        elif worse_action == ">":
            plot_dashed_arrow(state, cols, ax, "right")
        elif worse_action == "<":
            plot_dashed_arrow(state, cols, ax, "left")

    mat = [[0 if fvec is None else fvec.index(1) + 1 for fvec in row]
           for row in feature_mat]
    print(mat)
    #convert feature_mat into colors
    #heatmap =  plt.imshow(mat, cmap="Reds", interpolation='none', aspect='equal')
    cmap = colors.ListedColormap([
        'black', 'white', 'tab:red', 'tab:blue', 'tab:green', 'tab:purple',
        'tab:orange', 'tab:gray', 'tab:cyan'
    ])
    im = plt.imshow(mat, cmap=cmap, interpolation='none', aspect='equal')

    ax = plt.gca()

    ax.set_xticks(np.arange(-.5, cols, 1), minor=True)
    ax.set_yticks(np.arange(-.5, rows, 1), minor=True)
    #ax.grid(which='minor', axis='both', linestyle='-', linewidth=5, color='k')
    # Gridlines based on minor ticks
    ax.grid(which='minor', color='k', linestyle='-', linewidth=5)
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_locator(plt.NullLocator())
    ax.xaxis.set_major_locator(plt.NullLocator())
    #cbar = plt.colorbar(heatmap)
    #cbar.ax.tick_params(labelsize=20)
    plt.show()
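The if/elif ladders above dispatch on single action characters; a compact alternative is a lookup table. A small sketch, assuming the same plot_arrow / plot_dashed_arrow helpers used by this example:

# Sketch only: dispatch arrow drawing through a dict instead of if/elif chains.
ACTION_DIRS = {"^": "up", "v": "down", ">": "right", "<": "left"}

def draw_action(action, state, cols, ax, solid=True):
    direction = ACTION_DIRS.get(action)
    if direction is None:
        return                                              # unknown action character
    drawer = plot_arrow if solid else plot_dashed_arrow     # helpers assumed from this module
    drawer(state, cols, ax, direction)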
Exemple #16
0
def create_mosaic(input_volume, output_filepath=None, label_volume=None, generate_outline=True, mask_value=0, step=1, dim=2, cols=8, label_buffer=5, rotate_90=3, flip=True):

    """This creates a mosaic of 2D images from a 3D Volume.
    
    Parameters
    ----------
    input_volume : str or numpy.ndarray
        Any neuroimaging file with a filetype supported by qtim_tools, or existing numpy array.
    output_filepath : None, optional
        Where to save your output, in a filetype supported by matplotlib (e.g. .png). If None, the mosaic array is returned without saving an image.
    label_volume : None, optional
        Whether to create your mosaic with an attached label filepath / numpy array. Will not perform volume transforms from header (yet)
    generate_outline : bool, optional
        If True, will generate outlines for label_volumes, instead of filled-in areas. Default is True.
    mask_value : int, optional
        Background value for label volumes. Default is 0.
    step : int, optional
        Will generate an image for every [step] slice. Default is 1.
    dim : int, optional
        Mosaic images will be sliced along this dimension. Default is 2, which often corresponds to axial.
    cols : int, optional
        How many columns in your output mosaic. Rows will be determined automatically. Default is 8.
    label_buffer : int, optional
        Images more than [label_buffer] slices away from a slice containing a label pixel will not be included. Default is 5.
    rotate_90 : int, optional
        If the output mosaic is incorrectly rotated, you may rotate clockwise [rotate_90] times. Default is 3.
    flip : bool, optional
        If the output is incorrectly flipped, you may set to True to flip the data. Default is True.
    
    Returns
    -------
    output_array : 2D numpy.ndarray
        The generated mosaic array.
    
    """

    image_numpy = read_image_files(input_volume)
    if step is None:
        step = 1

    if label_volume is not None:

        label_numpy = read_image_files(label_volume)

        if generate_outline:
            label_numpy = generate_label_outlines(label_numpy, dim, mask_value)

        # This is fun in a wacky way, but could probably be done more concisely and efficiently.
        mosaic_selections = []
        for i in range(label_numpy.shape[dim]):
            label_slice = np.squeeze(label_numpy[tuple(slice(None) if k != dim else slice(i, i + 1) for k in range(3))])
            if np.sum(label_slice) != 0:
                mosaic_selections += range(i - label_buffer, i + label_buffer)
        mosaic_selections = np.unique(mosaic_selections)
        mosaic_selections = mosaic_selections[mosaic_selections >= 0]
        mosaic_selections = mosaic_selections[mosaic_selections <= image_numpy.shape[dim]]
        mosaic_selections = mosaic_selections[::step]

        color_range_image = [np.min(image_numpy), np.max(image_numpy)]
        color_range_label = [np.min(label_numpy), np.max(label_numpy)]

        # One day, specify rotations by affine matrix.
        # Is test slice necessary? Operate directly on shape if possible.
        test_slice = np.rot90(np.squeeze(image_numpy[tuple(slice(None) if k != dim else slice(0, 1) for k in range(3))]), rotate_90)
        slice_width = test_slice.shape[1]
        slice_height = test_slice.shape[0]

        mosaic_image_numpy = np.zeros((int(slice_height * np.ceil(float(len(mosaic_selections)) / float(cols))), int(test_slice.shape[1] * cols)), dtype=float)
        mosaic_label_numpy = np.zeros_like(mosaic_image_numpy)
        
        row_index = 0
        col_index = 0

        for i in mosaic_selections:
            image_slice = np.rot90(np.squeeze(image_numpy[tuple(slice(None) if k != dim else slice(i, i + 1) for k in range(3))]), rotate_90)
            label_slice = np.rot90(np.squeeze(label_numpy[tuple(slice(None) if k != dim else slice(i, i + 1) for k in range(3))]), rotate_90)

            # Again, specify from affine matrix if possible.
            if flip:
                image_slice = np.fliplr(image_slice)
                label_slice = np.fliplr(label_slice)

            if image_slice.size > 0:
                mosaic_image_numpy[int(row_index):int(row_index + slice_height), int(col_index):int(col_index + slice_width)] = image_slice
                mosaic_label_numpy[int(row_index):int(row_index + slice_height), int(col_index):int(col_index + slice_width)] = label_slice

            if col_index == mosaic_image_numpy.shape[1] - slice_width:
                col_index = 0
                row_index += slice_height 
            else:
                col_index += slice_width

        mosaic_label_numpy = np.ma.masked_where(mosaic_label_numpy == 0, mosaic_label_numpy)

        if output_filepath is not None:
            plt.figure(figsize=(mosaic_image_numpy.shape[0] / 100, mosaic_image_numpy.shape[1] / 100), dpi=100, frameon=False)
            plt.margins(0, 0)
            plt.gca().set_axis_off()
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.imshow(mosaic_image_numpy, 'gray', vmin=color_range_image[0], vmax=color_range_image[1], interpolation='none')
            plt.imshow(mosaic_label_numpy, 'jet', vmin=color_range_label[0], vmax=color_range_label[1], interpolation='none')
            
            plt.savefig(output_filepath, bbox_inches='tight', pad_inches=0.0, dpi=1000)
            plt.clf()
            plt.close()

        return mosaic_image_numpy

    else:

        color_range_image = [np.min(image_numpy), np.max(image_numpy)]

        test_slice = np.rot90(np.squeeze(image_numpy[tuple(slice(None) if k != dim else slice(0, 1) for k in range(3))]), rotate_90)
        slice_width = test_slice.shape[1]
        slice_height = test_slice.shape[0]

        mosaic_selections = np.arange(image_numpy.shape[dim])[::step]
        mosaic_image_numpy = np.zeros((int(slice_height * np.ceil(float(len(mosaic_selections)) / float(cols))), int(test_slice.shape[1] * cols)), dtype=float)

        row_index = 0
        col_index = 0

        for i in mosaic_selections:
            image_slice = np.squeeze(image_numpy[tuple(slice(None) if k != dim else slice(i, i + 1) for k in range(3))])

            image_slice = np.rot90(image_slice, rotate_90)
            
            if flip:
                image_slice = np.fliplr(image_slice)

            mosaic_image_numpy[int(row_index):int(row_index + slice_height), int(col_index):int(col_index + slice_width)] = image_slice

            if col_index == mosaic_image_numpy.shape[1] - slice_width:
                col_index = 0
                row_index += slice_height 
            else:
                col_index += slice_width

        if output_filepath is not None:
            plt.figure(figsize=(mosaic_image_numpy.shape[0] / 100, mosaic_image_numpy.shape[1] / 100), dpi=100, frameon=False)
            plt.margins(0, 0)
            plt.gca().set_axis_off()
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.imshow(mosaic_image_numpy, 'gray', vmin=color_range_image[0], vmax=color_range_image[1], interpolation='none')

            plt.savefig(output_filepath, bbox_inches='tight', pad_inches=0.0, dpi=500) 
            plt.clf()
            plt.close()

        return mosaic_image_numpy
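A possible call, with placeholder file names (anything readable by read_image_files should work):

# Hypothetical usage sketch; the file names below are placeholders.
mosaic = create_mosaic('t1_post.nii.gz',
                       output_filepath='t1_post_mosaic.png',
                       label_volume='tumor_label.nii.gz',
                       generate_outline=True,
                       dim=2,    # slice along the third dimension (often axial)
                       step=2,   # keep every other slice
                       cols=8)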
Exemple #17
0
def train(dataloader, model, optimizer, log, loss_file, args):

    losses = [AverageMeter() for _ in range(2)]
    length_loader = len(dataloader)
    D1s = [AverageMeter() for _ in range(2)]

    start_full_time = time.time()
    for batch_idx, (imgL, imgR, disp_L) in enumerate(dataloader):
        imgL = imgL.float().cuda()
        imgR = imgR.float().cuda()
        disp_L = disp_L.float().cuda()
        #print('train imgR size:', imgR.shape)

        optimizer.zero_grad()
        mask = disp_L > 0
        mask = mask * (disp_L < 192)
        mask.detach_()

        single_update_time = time.time()

        #outputs = model(imgL, imgR)
        if args.adaptation_type == "no_supervise":
            model.eval()
            with torch.no_grad():
                pred, mono_loss = model(imgL, imgR)

            outputs = [torch.squeeze(output, 1) for output in pred]

            num_out = len(pred)
            loss = [
                args.loss_weights[x] * F.smooth_l1_loss(
                    outputs[x][mask], disp_L[mask], size_average=True)
                for x in range(num_out)
            ]

            num_out = len(pred)

        elif args.adaptation_type == "self_supervise":
            model.train()

            pred, mono_loss = model(imgL, imgR)
            outputs = [torch.squeeze(output, 1) for output in pred]
            num_out = len(pred)
            loss = [
                args.loss_weights[x] * F.smooth_l1_loss(
                    outputs[x][mask], disp_L[mask], size_average=True)
                for x in range(num_out)
            ]

            sum(mono_loss).backward()

            optimizer.step()

        elif args.adaptation_type == "GT_supervise":
            model.train()

            pred, mono_loss = model(imgL, imgR)

            outputs = [torch.squeeze(output, 1) for output in pred]

            num_out = len(pred)
            loss = [
                args.loss_weights[x] * F.smooth_l1_loss(
                    outputs[x][mask], disp_L[mask], size_average=True)
                for x in range(num_out)
            ]

            sum(loss).backward()
            optimizer.step()

        print('single_update_time: {:.4f} seconds'.format(time.time() -
                                                           single_update_time))
        # image out and error estimation

        # three pixel error

        output = torch.squeeze(pred[1], 1)
        D1s[1].update(error_estimating(output, disp_L).item())
        print('output size:', output.shape)

        # save the adaptation disparity
        if args.save_disparity:

            plt.imshow(output.squeeze(0).cpu().detach().numpy())
            plt.axis('off')

            plt.gcf().set_size_inches(1216 / 100, 320 / 100)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.subplots_adjust(top=1,
                                bottom=0,
                                left=0,
                                right=1,
                                hspace=0,
                                wspace=0)
            plt.margins(0, 0)

            plt.savefig(args.save_path + '/disparity/{}.png'.format(batch_idx))

        # if args.save_disparity:
        #
        #     imgL = imgL.squeeze(0).permute(1,2,0)
        #     #print("imgL size:", imgL.shape)
        #     plt.imshow(imgL.cpu().detach().numpy())
        #     plt.axis('off')
        #
        #     plt.gcf().set_size_inches(1216 / 100, 320 / 100)
        #     plt.gca().xaxis.set_major_locator(plt.NullLocator())
        #     plt.gca().yaxis.set_major_locator(plt.NullLocator())
        #     plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
        #     plt.margins(0, 0)
        #
        #     plt.savefig(args.save_path + '/disparity/{}.png'.format(batch_idx))
        #

        loss_file.write('{:.4f}\n'.format(D1s[1].val))

        for idx in range(num_out):
            losses[idx].update(loss[idx].item())

        info_str = [
            'Stage {} = {:.2f}({:.2f})'.format(x, losses[x].val, losses[x].avg)
            for x in range(num_out)
        ]
        info_str = '\t'.join(info_str)

        log.info('Epoch{} [{}/{}] {}'.format(1, batch_idx, length_loader,
                                             info_str))

    end_time = time.time()

    log.info(
        'full training time = {:.2f} hours ({:.4f} seconds)'.format(
            (end_time - start_full_time) / 3600,
            end_time - start_full_time))

    # summary
    info_str = ', '.join(
        ['Stage {}={:.4f}'.format(x, D1s[x].avg) for x in range(num_out)])

    log.info('Average test 3-Pixel Error = ' + info_str)

    info_str = '\t'.join(
        ['Stage {} = {:.2f}'.format(x, losses[x].avg) for x in range(num_out)])
    log.info('Average train loss = ' + info_str)

    loss_file.close()
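train() relies on an AverageMeter helper that is not shown in this snippet; a minimal sketch of the usual value/running-average tracker, assuming that is what the repository's version does:

class AverageMeter(object):
    """Track the latest value and the running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count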
Exemple #18
0
    y_lut, x_luts = lut
    u_rv = numpy.random.random((N, 2))
    samples = numpy.zeros(u_rv.shape)
    for i, (x, y) in enumerate(u_rv):
        print(i)
        ys = y_lut(y)
        x_bin = int(ys / RESOLUTION)
        xs = x_luts[x_bin](x)
        samples[i][0] = xs
        samples[i][1] = ys

    return samples


if __name__ == '__main__':
    from skimage import io
    density_img = io.imread('batman.jpg', True)
    lut_2d = generate_lut(density_img)
    samples = sample_2d(lut_2d, 50000)

    from matplotlib import pyplot
    fig, (ax0, ax1) = pyplot.subplots(ncols=2, figsize=(9, 4))
    fig.canvas.set_window_title('Test 2D Sampling')
    ax0.imshow(density_img, cmap='gray')
    ax0.xaxis.set_major_locator(pyplot.NullLocator())
    ax0.yaxis.set_major_locator(pyplot.NullLocator())

    ax1.axis('equal')
    ax1.axis([0, 1, 0, 1])
    ax1.plot(samples[:, 0], samples[:, 1], 'k,')
    pyplot.show()
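sample_2d above expects lookup tables from a generate_lut step that is not included in this snippet. A minimal sketch of the 1D building block it implies (inverse-transform sampling through an interpolated CDF), assuming scipy.interpolate.interp1d:

# Sketch of a 1D inverse-CDF lookup; generate_lut presumably builds one of
# these per image row plus one for the marginal distribution over rows.
# Assumes a strictly positive density so the CDF is strictly increasing.
import numpy
from scipy.interpolate import interp1d

def inverse_cdf_lut(density, grid):
    cdf = numpy.cumsum(numpy.asarray(density, dtype=float))
    cdf /= cdf[-1]                      # normalize so the CDF ends at 1
    # maps a uniform random number in [0, 1] back to a grid position
    return interp1d(cdf, grid, bounds_error=False, fill_value=(grid[0], grid[-1]))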
Exemple #19
0
def plot():
    map_path = file_path + "/resources/sf_block_groups/sf_block_groups_nowater.geojson"
    coc_path = file_path + "/resources/sf_block_groups/coc"
    plot_path = file_path + "/resources/sf_data/sf_overspace_plot_data.json"
    fig_path = file_path + "/results/sf_change_overspace.pdf"

    # Read data.
    with open(plot_path, "r") as plot_file:
        data = json.loads(plot_file.read().strip("\n"))

    coc = gpd.read_file(coc_path)
    coc = coc[coc["GEOID"].astype("int") -
              coc["GEOID"].astype("int") % 1000000 == 6075000000]
    coc = coc[coc["GEOID"].astype("int") != 6075017902]
    coc = coc[coc["COCFLAG__1"] == 1]
    coc = coc.to_crs({"init": "epsg:4326"})

    map = gpd.read_file(map_path)
    map["geoid"] = map["stfid"].astype("int")
    map = map[["geoid", "geometry"]]
    map["bg_lng"] = map.centroid.apply(lambda p: p.x)
    map["bg_lat"] = map.centroid.apply(lambda p: p.y)
    map = map[map["geoid"] != 60750179021]

    # Get supply curve data
    sup = pd.DataFrame.from_dict(data["sup"])
    sup["geoid"] = data["index"]
    sup = sup[sup["geoid"] != 60750601001]
    sup = sup[sup["geoid"] != 60750604001]
    sup = sup[sup["geoid"] != 60750332011]
    sup = sup[sup["geoid"] != 60750610002]
    sup = sup[sup["geoid"] != 60750264022]
    sup = sup[sup["geoid"] != 60750258002]
    sup[sup["geoid"] == 60750610001] = 1
    sup = map.merge(sup, on="geoid", how="left")

    # Get price curve data
    pri = pd.DataFrame.from_dict(data["pri"])
    pri["geoid"] = data["index"]
    pri = map.merge(pri, on="geoid", how="left")

    # Plot parameter and setting.
    font = FontProperties()
    font.set_weight("bold")
    font.set_size(10)
    matplotlib.rcParams.update({"font.size": 6})
    alpha = 0.5
    alpha2 = 0.3
    k = 2
    bar_cons = 0.66
    bar_mv = 0.27
    for i in [0, 1, 2, 3, 4]:
        ax[i].set_xlim([-122.513, -122.355])
        ax[i].set_ylim([37.707, 37.833])
        ax[i].set_axis_off()
        ax[i].xaxis.set_major_locator(plt.NullLocator())
        ax[i].yaxis.set_major_locator(plt.NullLocator())
        coc.plot(ax=ax[i], linewidth=0.5, alpha=0)
    app_list = ["uber", "lyft", "taxi"]
    cmap = "RdYlGn"

    f = 0
    for i in [0, 1, 2]:
        sup["plot"] = sup[app_list[i]]  #/ sup["area"] * 581
        knn = neighbors.KNeighborsRegressor(k, "distance")  # Fill empty area.
        train_x = sup[["plot", "bg_lat",
                       "bg_lng"]].dropna()[["bg_lat", "bg_lng"]].values
        train_y = sup["plot"].dropna().values
        predict_x = sup[["bg_lat", "bg_lng"]].values
        sup["plot"] = knn.fit(train_x, train_y).predict(predict_x)
        vmin = sup["plot"].min()
        vmax = sup["plot"].quantile(0.95)
        # plot
        sup.plot(ax=ax[i],
                 linewidth=0,
                 column="plot",
                 cmap=cmap,
                 alpha=alpha,
                 k=10,
                 vmin=vmin,
                 vmax=vmax)
        ax[i].set_title(upperfirst(app_list[i]) + " Supply",
                        fontproperties=font)
        fig = ax[i].get_figure()
        cax = fig.add_axes([0.128 + 0.087 * i, 0.07, 0.07, 0.02])
        sm = plt.cm.ScalarMappable(cmap=cmap,
                                   norm=plt.Normalize(vmin=vmin, vmax=vmax))
        sm._A = []
        fig.colorbar(sm,
                     cax=cax,
                     alpha=alpha2,
                     extend="both",
                     orientation="horizontal")

    cmap = "RdYlGn_r"
    f = 2
    for i in [3, 4]:
        pri["plot"] = (pri[app_list[i - 3]] - 1) * 100
        knn = neighbors.KNeighborsRegressor(k, "distance")  # Fill empty area.
        train_x = pri[["plot", "bg_lat",
                       "bg_lng"]].dropna()[["bg_lat", "bg_lng"]].values
        train_y = pri["plot"].dropna().values
        predict_x = pri[["bg_lat", "bg_lng"]].values
        pri["plot"] = knn.fit(train_x, train_y).predict(predict_x)
        vmin = 0
        vmax = 12
        print(pri["plot"].max() - pri["plot"].min())
        print(pri["plot"].std())
        # plot
        pri.plot(ax=ax[i],
                 linewidth=0,
                 column="plot",
                 cmap=cmap,
                 alpha=alpha,
                 k=10,
                 vmin=vmin,
                 vmax=vmax)
        ax[i].set_title(upperfirst(app_list[i - 3]) + " Price",
                        fontproperties=font)
        fig = ax[i].get_figure()
        cax = fig.add_axes([0.128 + 0.087 * i, 0.07, 0.07, 0.02])
        sm = plt.cm.ScalarMappable(cmap=cmap,
                                   norm=plt.Normalize(vmin=vmin, vmax=vmax))
        sm._A = []
        fig.colorbar(sm,
                     cax=cax,
                     alpha=alpha2,
                     extend="both",
                     orientation="horizontal")

    map_path = file_path + "/resources/nyc_block_groups/nyc_bg_with_data_acs15.geojson"
    plot_path = file_path + "/resources/nyc_data/nyc_overspace_plot_data.json"
    fig_path = file_path + "/results/nyc_change_overspace.pdf"

    # Read data.
    with open(plot_path, "r") as plot_file:
        data = json.loads(plot_file.read().strip("\n"))

    map = gpd.read_file(map_path)
    coc = map.sort_values("income")[:80]
    map = map[map["population"].astype("float") > 10.0]
    map["geoid"] = map["geo_id"].astype("int")
    map = map[["geoid", "geometry"]]
    map["bg_lng"] = map.centroid.apply(lambda p: p.x)
    map["bg_lat"] = map.centroid.apply(lambda p: p.y)

    # Get supply curve data
    sup = pd.DataFrame.from_dict(data["sup"])
    sup["geoid"] = data["index"]
    sup = map.merge(sup, on="geoid", how="left")

    # Get price curve data
    pri = pd.DataFrame.from_dict(data["pri"])
    pri["geoid"] = data["index"]
    pri = pri[pri["uber"] > 1.0]
    pri = pri[pri["lyft"] > 1.0]
    pri = map.merge(pri, on="geoid", how="left")

    # Plot parameter and setting.
    bar_cons = 0.66
    bar_mv = 0.27
    for i in [5, 6, 7, 8]:
        ax[i].set_xlim([-74.055, -73.88])
        ax[i].set_ylim([40.64, 40.90])
        ax[i].set_axis_off()
        ax[i].xaxis.set_major_locator(plt.NullLocator())
        ax[i].yaxis.set_major_locator(plt.NullLocator())
        coc.plot(ax=ax[i], linewidth=0.5, alpha=0)
    app_list = ["uber", "lyft"]
    cmap = "RdYlGn"

    f = 0
    for i in [5, 6]:
        sup["plot"] = sup[app_list[i - 5]]
        vmin = sup["plot"].min()
        if i == 5:
            vmax = 7  #sup["plot"].quantile(0.9)
        else:
            vmax = 5
        # plot
        sup.plot(ax=ax[i],
                 linewidth=0,
                 column="plot",
                 cmap=cmap,
                 alpha=alpha,
                 k=10,
                 vmin=vmin,
                 vmax=vmax)
        ax[i].set_title(upperfirst(app_list[i - 5]) + " Supply",
                        fontproperties=font)
        fig = ax[i].get_figure()
        cax = fig.add_axes([0.132 + 0.087 * i, 0.07, 0.07, 0.02])
        sm = plt.cm.ScalarMappable(cmap=cmap,
                                   norm=plt.Normalize(vmin=vmin, vmax=vmax))
        sm._A = []
        fig.colorbar(sm,
                     cax=cax,
                     alpha=alpha2,
                     extend="both",
                     orientation="horizontal")

    cmap = "RdYlGn_r"
    f = 2
    for i in [7, 8]:
        pri["plot"] = (pri[app_list[i - 3 - 4]] - 1) * 100
        knn = neighbors.KNeighborsRegressor(k, "distance")  # Fill empty area.
        train_x = pri[["plot", "bg_lat",
                       "bg_lng"]].dropna()[["bg_lat", "bg_lng"]].values
        train_y = pri["plot"].dropna().values
        predict_x = pri[["bg_lat", "bg_lng"]].values
        pri["plot"] = knn.fit(train_x, train_y).predict(predict_x)
        vmin = 0
        if i == 7:
            vmax = 2.5  #sup["plot"].quantile(0.9)
        else:
            vmax = 7
        print(pri["plot"].max() - pri["plot"].min())
        print(pri["plot"].std())
        # plot
        pri.plot(ax=ax[i],
                 linewidth=0,
                 column="plot",
                 cmap=cmap,
                 alpha=alpha,
                 k=10,
                 vmin=vmin,
                 vmax=vmax)
        ax[i].set_title(upperfirst(app_list[i - 3 - 4]) + " Price",
                        fontproperties=font)
        fig = ax[i].get_figure()
        cax = fig.add_axes([0.132 + 0.087 * i, 0.07, 0.07, 0.02])
        sm = plt.cm.ScalarMappable(cmap=cmap,
                                   norm=plt.Normalize(vmin=vmin, vmax=vmax))
        sm._A = []
        fig.colorbar(sm,
                     cax=cax,
                     alpha=alpha2,
                     extend="both",
                     orientation="horizontal")
Exemple #20
0
    ## stack data to be able to plot them with imshow
    stacked_temps = np.stack((temps, temps))

    #min and max values for the colormap. Note: this deviates from the original warming stripes, where a standard deviation of +/-2.6 was chosen
    vmin = -1.7 * std
    vmax = 1.7 * std

    ## plotting
    fig = plt.figure(
        figsize=(16, 9))  #adjust figsize, for example for cubic figure
    #plot the image, with manual color bar defined above in custom_div_cmap function
    cmap = custom_div_cmap()
    img = plt.imshow(stacked_temps,
                     cmap=cmap,
                     aspect='auto',
                     vmin=vmin,
                     vmax=vmax,
                     interpolation='none')
    #this just turns all labels, axis etc off so that there are only the stripes
    plt.gca().set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    #save in your desired directory
    plt.savefig("stripes_" + temps.name + '_' + month + '_' +
                str(data['Jahr'].min()) + '-' + str(data['Jahr'].max()) +
                ".jpg",
                bbox_inches='tight',
                pad_inches=0,
                dpi=300)
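custom_div_cmap() is called but not defined in this snippet; a plausible stand-in using matplotlib's LinearSegmentedColormap is sketched below (the blue-white-red ramp and its endpoint colors are assumptions, the snippet only needs some diverging colormap back):

# Assumed helper sketch: a diverging blue-white-red colormap.
from matplotlib.colors import LinearSegmentedColormap

def custom_div_cmap(numcolors=11, name='custom_div_cmap',
                    mincol='#08306b', midcol='#ffffff', maxcol='#67000d'):
    return LinearSegmentedColormap.from_list(name, [mincol, midcol, maxcol],
                                             N=numcolors)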
Exemple #21
0
def plot_state_hinton(state, title='', figsize=None, ax_real=None, ax_imag=None, *, rho=None):
    """Plot a hinton diagram for the density matrix of a quantum state.

    Args:
        state (Statevector or DensityMatrix or ndarray): An N-qubit quantum state.
        title (str): a string that represents the plot title
        figsize (tuple): Figure size in inches.
        ax_real (matplotlib.axes.Axes): An optional Axes object to be used for
            the visualization output. If none is specified a new matplotlib
            Figure will be created and used. If this is specified without an
            ax_imag only the real component plot will be generated.
            Additionally, if specified there will be no returned Figure since
            it is redundant.
        ax_imag (matplotlib.axes.Axes): An optional Axes object to be used for
            the visualization output. If none is specified a new matplotlib
            Figure will be created and used. If this is specified without an
            ax_real only the imaginary component plot will be generated.
            Additionally, if specified there will be no returned Figure since
            it is redundant.

    Returns:
         matplotlib.Figure:
            The matplotlib.Figure of the visualization if
            neither ax_real nor ax_imag is set.

    Raises:
        ImportError: Requires matplotlib.
        VisualizationError: if input is not a valid N-qubit state.

    Example:
        .. jupyter-execute::

            from qiskit import QuantumCircuit
            from qiskit.quantum_info import DensityMatrix
            from qiskit.visualization import plot_state_hinton
            %matplotlib inline

            qc = QuantumCircuit(2)
            qc.h(0)
            qc.cx(0, 1)

            state = DensityMatrix.from_instruction(qc)
            plot_state_hinton(state, title="New Hinton Plot")
    """
    if not HAS_MATPLOTLIB:
        raise ImportError('Must have Matplotlib installed. To install, run '
                          '"pip install matplotlib".')
    # Figure data
    rho = DensityMatrix(state)
    num = rho.num_qubits
    if num is None:
        raise VisualizationError("Input is not a multi-qubit quantum state.")
    max_weight = 2 ** np.ceil(np.log(np.abs(rho.data).max()) / np.log(2))
    datareal = np.real(rho.data)
    dataimag = np.imag(rho.data)

    if figsize is None:
        figsize = (8, 5)
    if not ax_real and not ax_imag:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    else:
        if ax_real:
            fig = ax_real.get_figure()
        else:
            fig = ax_imag.get_figure()
        ax1 = ax_real
        ax2 = ax_imag
    column_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    row_names = [bin(i)[2:].zfill(num) for i in range(2**num)]
    ly, lx = datareal.shape
    # Real
    if ax1:
        ax1.patch.set_facecolor('gray')
        ax1.set_aspect('equal', 'box')
        ax1.xaxis.set_major_locator(plt.NullLocator())
        ax1.yaxis.set_major_locator(plt.NullLocator())

        for (x, y), w in np.ndenumerate(datareal):
            color = 'white' if w > 0 else 'black'
            size = np.sqrt(np.abs(w) / max_weight)
            rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
                                 facecolor=color, edgecolor=color)
            ax1.add_patch(rect)

        ax1.set_xticks(np.arange(0, lx+0.5, 1))
        ax1.set_yticks(np.arange(0, ly+0.5, 1))
        row_names.append('')
        ax1.set_yticklabels(row_names, fontsize=14)
        column_names.append('')
        ax1.set_xticklabels(column_names, fontsize=14, rotation=90)
        ax1.autoscale_view()
        ax1.invert_yaxis()
        ax1.set_title('Re[$\\rho$]', fontsize=14)
    # Imaginary
    if ax2:
        ax2.patch.set_facecolor('gray')
        ax2.set_aspect('equal', 'box')
        ax2.xaxis.set_major_locator(plt.NullLocator())
        ax2.yaxis.set_major_locator(plt.NullLocator())

        for (x, y), w in np.ndenumerate(dataimag):
            color = 'white' if w > 0 else 'black'
            size = np.sqrt(np.abs(w) / max_weight)
            rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
                                 facecolor=color, edgecolor=color)
            ax2.add_patch(rect)

        ax2.set_xticks(np.arange(0, lx+0.5, 1))
        ax2.set_yticks(np.arange(0, ly+0.5, 1))
        ax2.set_yticklabels(row_names, fontsize=14)
        ax2.set_xticklabels(column_names, fontsize=14, rotation=90)

        ax2.autoscale_view()
        ax2.invert_yaxis()
        ax2.set_title('Im[$\\rho$]', fontsize=14)
    if title:
        fig.suptitle(title, fontsize=16)
    if ax_real is None and ax_imag is None:
        if get_backend() in ['module://ipykernel.pylab.backend_inline',
                             'nbAgg']:
            plt.close(fig)
        return fig
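As the docstring notes, supplying ax_real/ax_imag embeds the hinton panels in a caller-owned figure and suppresses the returned Figure; a short usage sketch:

# Usage sketch: draw both hinton panels into an existing figure.
import matplotlib.pyplot as plt
from qiskit import QuantumCircuit
from qiskit.quantum_info import DensityMatrix

qc = QuantumCircuit(2)
qc.h(0)
qc.cx(0, 1)
state = DensityMatrix.from_instruction(qc)

fig, (ax_re, ax_im) = plt.subplots(1, 2, figsize=(10, 5))
plot_state_hinton(state, ax_real=ax_re, ax_imag=ax_im)
fig.savefig('hinton.png')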
                    label = r'$\sim n^{-2}$', elinewidth = 1, color='0.5')
        ax1.set_xlabel('Number of Proposals $N$')# \n (Step Size = %1.3f)' %StepSize)
    
        # Make the y-axis label, ticks and tick labels match the line color.
        ax1.set_ylabel(r'Variance', color='k')
#        ax1.tick_params('y', colors='k')
        ax1.set_xscale("log")
        ax1.set_yscale("log")
        x1_ticks_labels = [4,8,16,32,64,128,256,512,1024]
        ax1.set_xticks(np.array([4,8,16,32,64,128,256,512,1024]))
#        x1_ticks_labels = [5,10,25,50,100,250,500,1000] #[5,10,20,50,100] 
#        ax1.set_xticks(np.array([5,10,25,50,100,250,500,1000])) # #[5,10,20,50,100] 
        ax1.set_xticklabels(x1_ticks_labels, fontsize=11)
        
#        ax1.tick_params(axis='x',reset=False,which='x')#,length=8,width=2)
        ax1.xaxis.set_minor_locator(plt.NullLocator()) #plt.FixedLocator([4,8,16,32,64,128,256,512,1024]))


        ax1.legend(loc='best', fontsize=8.75)
        ax1.grid(True,which="major",axis='both',linewidth=0.75,color=lighten_color('grey', 0.25))
#        ax1.grid(True,which="major",axis='y',linewidth=0.5,color=lighten_color('grey', 0.25))
        
        ax2 = ax1.twiny()
        ax2.set_xscale("log")
        ax2.set_yscale("log")
        ax2.errorbar(N_Array*NumOfIter, 1.*1e1*(N_Array*NumOfIter)**(-1.), fmt='--', \
                    label = r'$\sim n^{-1}$', elinewidth = 1, color='0.5')
        ax2.errorbar(N_Array*NumOfIter, 1*1e3*(N_Array*NumOfIter)**(-2.), fmt=':', \
                    label = r'$\sim n^{-2}$', elinewidth = 1, color='0.5')
        x2_ticks_labels = np.array([2000, 10000, 100000, 500000])
        ax2.set_xticks(x2_ticks_labels)
Exemple #23
0
def images2d(images=None, second=10, saveable=True, name='images', dtype=None,
                                                            fig_idx=3119362):
    """Display a group of RGB or Greyscale images.

    Parameters
    ----------
    images : numpy.array
        The images.
    second : int
        The display second(s) for the image(s), if saveable is False.
    saveable : boolean
        Save or plot the figure.
    name : a string
        A name to save the image, if saveable is True.
    dtype : None or numpy data type
        The data type for displaying the images.
    fig_idx : int
        matplotlib figure index.

    Examples
    --------
    >>> X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False)
    >>> tl.visualize.images2d(X_train[0:100,:,:,:], second=10, saveable=False, name='cifar10', dtype=np.uint8, fig_idx=20212)
    """
    # print(images.shape)    # (50000, 32, 32, 3)
    # exit()
    if dtype:
        images = np.asarray(images, dtype=dtype)
    n_mask = images.shape[0]
    n_row = images.shape[1]
    n_col = images.shape[2]
    n_color = images.shape[3]
    row = int(np.sqrt(n_mask))
    col = int(np.ceil(n_mask/row))
    plt.ion()   # active mode
    fig = plt.figure(fig_idx)
    count = 1
    for ir in range(1, row+1):
        for ic in range(1, col+1):
            if count > n_mask:
                break
            a = fig.add_subplot(col, row, count)
            # print(images[:,:,:,count-1].shape, n_row, n_col)   # (5, 1, 32) 5 5
            # plt.imshow(
            #         np.reshape(images[count-1,:,:,:], (n_row, n_col)),
            #         cmap='gray', interpolation="nearest")     # theano
            if n_color == 1:
                plt.imshow(
                        np.reshape(images[count-1,:,:], (n_row, n_col)),
                        cmap='gray', interpolation="nearest")
            elif n_color == 3:
                plt.imshow(images[count-1,:,:],
                        cmap='gray', interpolation="nearest")
            else:
                raise Exception("Unknown n_color")
            plt.gca().xaxis.set_major_locator(plt.NullLocator())    # disable ticks
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            count = count + 1
    if saveable:
        plt.savefig(name+'.pdf',format='pdf')
    else:
        plt.draw()
        plt.pause(second)
Exemple #24
0
def CNN2d(CNN=None, second=10, saveable=True, name='cnn', fig_idx=3119362):
    """Display a group of RGB or Greyscale CNN masks.

    Parameters
    ----------
    CNN : numpy.array
        The image. e.g: 64 5x5 RGB images can be (5, 5, 3, 64).
    second : int
        The display second(s) for the image(s), if saveable is False.
    saveable : boolean
        Save or plot the figure.
    name : str
        A name to save the image, if saveable is True.
    fig_idx : int
        The matplotlib figure index.

    Examples
    --------
    >>> tl.visualize.CNN2d(network.all_params[0].eval(), second=10, saveable=True, name='cnn1_mnist', fig_idx=2012)

    """
    import matplotlib.pyplot as plt
    # logging.info(CNN.shape)    # (5, 5, 3, 64)
    # exit()
    n_mask = CNN.shape[3]
    n_row = CNN.shape[0]
    n_col = CNN.shape[1]
    n_color = CNN.shape[2]
    row = int(np.sqrt(n_mask))
    col = int(np.ceil(n_mask / row))
    plt.ion()  # active mode
    fig = plt.figure(fig_idx)
    count = 1
    for _ir in range(1, row + 1):
        for _ic in range(1, col + 1):
            if count > n_mask:
                break
            fig.add_subplot(col, row, count)
            # logging.info(CNN[:,:,:,count-1].shape, n_row, n_col)   # (5, 1, 32) 5 5
            # exit()
            # plt.imshow(
            #         np.reshape(CNN[count-1,:,:,:], (n_row, n_col)),
            #         cmap='gray', interpolation="nearest")     # theano
            if n_color == 1:
                plt.imshow(np.reshape(CNN[:, :, :, count - 1], (n_row, n_col)),
                           cmap='gray',
                           interpolation="nearest")
            elif n_color == 3:
                plt.imshow(np.reshape(CNN[:, :, :, count - 1],
                                      (n_row, n_col, n_color)),
                           cmap='gray',
                           interpolation="nearest")
            else:
                raise Exception("Unknown n_color")
            plt.gca().xaxis.set_major_locator(
                plt.NullLocator())  # disable ticks
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            count = count + 1
    if saveable:
        plt.savefig(name + '.pdf', format='pdf')
    else:
        plt.draw()
        plt.pause(second)
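CNN2d only needs a 4-D (height, width, channels, n_filters) array, so it can be exercised without a trained network; a hypothetical smoke test with random filters:

# Hypothetical smoke test: 64 random 5x5 RGB "filters" in the layout CNN2d expects.
import numpy as np

CNN2d(np.random.rand(5, 5, 3, 64), saveable=True, name='cnn_demo', fig_idx=1)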
Exemple #25
0
def generate_cover(title, subtitle, file_name):
    sha = hashlib.sha3_512()
    sha.update((title + subtitle).encode("utf8"))
    rem = int.from_bytes(sha.digest(), "little")
    rg = np.random.default_rng(rem)
    r = np.sqrt(1 - 0.5**2)
    x = [0, 0, r, 2 * r, 2 * r, r, 0]
    y = [1, 0, -0.5, 0, 1, 1.5, 1]
    x_data = np.array(x)
    y_data = np.array(y)
    fig = plt.figure(figsize=(66, 36))
    x_orig = x_data.copy()

    width = 37
    height = 25
    for j in range(height):
        x_data = x_orig.copy()
        shift = (j % 2) == 0
        x_data += shift * r
        for i in range(width):
            plt.fill(x_data, y_data, color=color[rg.integers(0, len(color))])
            x_data += 2 * r
        y_data -= 1 + 0.5

    fig.savefig(f"{file_name}-back.png",
                pad_inches=0,
                bbox_inches="tight",
                dpi=108)
    ax = fig.axes[0]
    ylim = ax.get_ylim()
    h = ylim[1] - ylim[0]
    img = Image.open(f"{file_name}-back.png")
    vadj = int(img.size[1] / h // 1)
    px = img.size[0] // (width + 1)
    box = [px // 2, vadj, img.size[0] - px // 2, img.size[1] - vadj]
    crop = img.crop(box)
    sz = crop.size
    if sz[0] / sz[1] > 16 / 9:
        new_width = int(np.ceil(sz[0] * (16 / 9) / (sz[0] / sz[1])))
        loss = sz[0] - new_width
        box = [loss // 2, 0, sz[0] - loss // 2, crop.size[1]]
        crop = crop.crop(box)
    crop.save(f"{file_name}-back.png")

    plt.fill(
        [0, (width + r / 2) * 2 * r, (width + r / 2) * 2 * r, 0],
        [-3 + 0.5 - 1.5, -3 + 0.5 - 1.5, -8 - 2 - 1.5, -8 - 2 - 1.5],
        color="#ffffff",
    )

    plt.subplots_adjust(0, 0, 1, 1, 0, 0)
    for ax in fig.axes:
        ax.set_axis_off()
        ax.axis("off")
        ax.margins(0, 0)
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
    ax.text(
        5 * r,
        -6 - 1.5,
        title.replace("`", ""),
        fontsize=3 * 60,
        color=color[1],
        fontweight="normal",
        fontname="Roboto Condensed",
    )
    if subtitle:
        ax.text(
            5 * r,
            -8.5 - 1.5,
            subtitle.replace("`", ""),
            fontsize=3 * 40,
            color=color[3],
            fontname="Roboto Condensed",
            fontweight="light",
        )

    fig.savefig(f"{file_name}-cover.png",
                pad_inches=0,
                bbox_inches="tight",
                dpi=108)

    ylim = ax.get_ylim()
    h = ylim[1] - ylim[0]
    img = Image.open(f"{file_name}-cover.png")
    vadj = int(img.size[1] / h // 1)
    px = img.size[0] // (width + 1)
    box = [px // 2, vadj, img.size[0] - px // 2, img.size[1] - vadj]
    crop = img.crop(box)
    sz = crop.size
    if sz[0] / sz[1] > 16 / 9:
        new_width = int(np.ceil(sz[0] * (16 / 9) / (sz[0] / sz[1])))
        loss = sz[0] - new_width
        box = [loss // 2, 0, sz[0] - loss // 2, crop.size[1]]
        crop = crop.crop(box)
    crop.save(f"{file_name}-cover.png")
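The crop-to-at-most-16:9 step appears twice in generate_cover; a sketch of it pulled into one helper, using the same Pillow and NumPy calls:

# Sketch only: center-crop a Pillow image so its aspect ratio is at most 16:9,
# mirroring the repeated block in generate_cover.
import numpy as np

def crop_to_16x9(img):
    w, h = img.size
    if w / h <= 16 / 9:
        return img
    new_width = int(np.ceil(w * (16 / 9) / (w / h)))
    loss = w - new_width
    return img.crop([loss // 2, 0, w - loss // 2, h])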
Exemple #26
0
def draw_weights(W=None,
                 second=10,
                 saveable=True,
                 shape=None,
                 name='mnist',
                 fig_idx=2396512):
    """Visualize every columns of the weight matrix to a group of Greyscale img.

    Parameters
    ----------
    W : numpy.array
        The weight matrix
    second : int
        The display second(s) for the image(s), if saveable is False.
    saveable : boolean
        Save or plot the figure.
    shape : a list with 2 int or None
        The shape of feature image, MNIST is [28, 28].
    name : a string
        A name to save the image, if saveable is True.
    fig_idx : int
        matplotlib figure index.

    Examples
    --------
    >>> tl.visualize.draw_weights(network.all_params[0].eval(), second=10, saveable=True, name='weight_of_1st_layer', fig_idx=2012)

    """
    if shape is None:
        shape = [28, 28]

    import matplotlib.pyplot as plt
    if saveable is False:
        plt.ion()
    fig = plt.figure(fig_idx)  # show all feature images
    n_units = W.shape[1]

    num_r = int(np.sqrt(n_units))  # images shown per row; e.g. 25 hidden units -> 5 per row
    num_c = int(np.ceil(n_units / num_r))
    count = int(1)
    for _row in range(1, num_r + 1):
        for _col in range(1, num_c + 1):
            if count > n_units:
                break
            fig.add_subplot(num_r, num_c, count)
            # ------------------------------------------------------------
            # plt.imshow(np.reshape(W[:,count-1],(28,28)), cmap='gray')
            # ------------------------------------------------------------
            feature = W[:, count - 1] / np.sqrt((W[:, count - 1]**2).sum())
            # feature[feature<0.0001] = 0   # value threshold
            # if count == 1 or count == 2:
            #     print(np.mean(feature))
            # if np.std(feature) < 0.03:      # condition threshold
            #     feature = np.zeros_like(feature)
            # if np.mean(feature) < -0.015:      # condition threshold
            #     feature = np.zeros_like(feature)
            plt.imshow(np.reshape(feature, (shape[0], shape[1])),
                       cmap='gray',
                       interpolation="nearest"
                       )  #, vmin=np.min(feature), vmax=np.max(feature))
            # plt.title(name)
            # ------------------------------------------------------------
            # plt.imshow(np.reshape(W[:,count-1] ,(np.sqrt(size),np.sqrt(size))), cmap='gray', interpolation="nearest")
            plt.gca().xaxis.set_major_locator(
                plt.NullLocator())  # disable ticks
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            count = count + 1
    if saveable:
        plt.savefig(name + '.pdf', format='pdf')
    else:
        plt.draw()
        plt.pause(second)
Exemple #27
0
def drawQueryGalary(galaryList, scoreList):

    newObj = open("newgalary.txt", "a+")
    scoreObj = open("score.txt", "a+")

    outputname = os.path.basename(galaryList[0]).split(".")[0] + "_20galary.jpg"

    newGalaryList = []
    for galary in galaryList:
        if galary.startswith("file://"):
            galary = galary.lstrip("file:")

            newGalaryList.append(galary)

        elif galary.startswith("http:"):
            if galary.endswith(".jpg"):
                basename = os.path.basename(galary)
            else:
                basename = os.path.basename(galary) + ".jpg"
            try:

                if not os.path.exists(basename):
                    urllib.urlretrieve(galary.strip(), basename)

                galary = os.getcwd() + "/" + basename

                newGalaryList.append(galary)
                print('galary:', galary)
            except:
                newGalaryList.append(None)
        else:
            if 'black' in galary:
                galary = "images/carfront/" + os.path.basename(galary)
            elif 'vehicle200w' in galary:
                galary = "images/carfront200w/" + os.path.basename(galary)
            else:
                galary = "images/carfront/" + os.path.basename(galary) + ".jpg"
            #galary = "/data/zhouping/algoSdk/code/dgreid/degreid_evaluate/frommatrix/black" + "/" +  galary
            #galary = "/data/zhouping/scripts/reidPc/black/images/carback" + "/" + galary + ".jpg"
            #galary = "/data/zhouping/scripts/reidPc/black/newblack/person/hard" + "/" + re.split("[-_\.]", galary)[0] + "/" + galary + ".jpg"
            #galary = "/data/zhouping/scripts/reidPc/black/images/car" + "/" + re.split("[-_\.]", galary)[0] + "/" + galary + ".jpg"
            #galary = "/data/zhouping/algoSdk/code/dgreid/degreid_evaluate/frommatrix/scripts/images/carfront" + "/" + re.split("[-_\.]", galary)[0] + "/" + galary + ".jpg"
            newGalaryList.append(galary)
            print('galary:', galary)

    #print('newGalaryList:', newGalaryList)
    newObj.write("%s" % newGalaryList)
    newObj.write(os.linesep)
    newObj.close()

    newScoreList = [item for item in scoreList if item is not None]

    scoreObj.write("%s" % newScoreList)
    scoreObj.write(os.linesep)
    scoreObj.close()

    fig, ax = plt.subplots(3, 7, figsize=(20, 20))
    fig.suptitle('totalNum:' + str(totalNums[os.path.basename(newGalaryList[0]).split(".jpg")[0].split("_")[0]]), fontsize=20)
    #ax.set(title='totalNum:' + str(os.path.basename(newGalaryList[0]).split(".jpg")[0].split("_")[0]))
    #ax.set_title('totalNum:' + str(os.path.basename(newGalaryList[0]).split(".jpg")[0].split("_")[0]),fontsize=12,color='r')
    #fig.subplots_adjust(hspace=0, wspace=0)
    #fig.tight_layout()

    for i in range(3):
        for j in range(7):

            qBasename = os.path.basename(newGalaryList[0]).split(".jpg")[0].split("_")[0]
            index = 7*i+j
            img = cv2.imread(newGalaryList[index])
            basename = os.path.basename(newGalaryList[index]).split(".jpg")[0].split("_")[0]
            score = scoreList[index]
            if img is None:
                continue

            img2 = img[:, :, ::-1]  # OpenCV loads BGR; flip to RGB for matplotlib
            ax[i, j].xaxis.set_major_locator(plt.NullLocator())
            ax[i, j].yaxis.set_major_locator(plt.NullLocator())

            if score is None:
                ax[i, j].text(0, 0, 'Query')
            else:
                score = round(float(score), 3)
                if qBasename != basename:
                    ax[i, j].text(0, 0, str(score), fontsize=20)
                else:
                    ax[i, j].text(0, 0, str(score)+"_"+"y", fontsize=20)
            ax[i,j].imshow(img2,cmap="bone")

    plt.savefig("output/%s" % outputname)
Exemple #28
0
def plot_optimal_policy(pi, feature_mat, filename=None):
    plt.figure()

    ax = plt.axes()
    count = 0
    rows, cols = len(pi), len(pi[0])
    for line in pi:
        for el in line:
            #print("optimal action", el)
            # could be a stochastic policy with more than one optimal action
            for char in el:
                #print(char)
                if char is "^":
                    plot_arrow(count, cols, ax, "up")
                elif char is "v":
                    plot_arrow(count, cols, ax, "down")
                elif char is ">":
                    plot_arrow(count, cols, ax, "right")
                elif char is "<":
                    plot_arrow(count, cols, ax, "left")
                elif char is ".":
                    plot_dot(count, cols, ax)
                elif el is "w":
                    #wall
                    pass
            count += 1

    #use for wall states
    #if walls:
    mat = [[0 if fvec is None else fvec.index(1) + 1 for fvec in row]
           for row in feature_mat]

    #mat =[[0,0],[2,2]]
    feature_set = set()
    for mrow in mat:
        for m in mrow:
            feature_set.add(m)
    num_features = len(feature_set)
    print(mat)
    all_colors = [
        'black', 'white', 'tab:red', 'tab:blue', 'tab:green', 'tab:purple',
        'tab:orange', 'tab:gray', 'tab:cyan'
    ]
    colors_to_use = []
    for f in range(9):  #hard coded to only have 9 features right now
        if f in feature_set:
            colors_to_use.append(all_colors[f])
    cmap = colors.ListedColormap(colors_to_use)
    # else:
    #     mat = [[fvec.index(1) for fvec in row] for row in feature_mat]
    #     cmap = colors.ListedColormap(['white','tab:red','tab:blue','tab:green','tab:purple', 'tab:orange', 'tab:gray', 'tab:cyan'])

    #input()

    #convert feature_mat into colors
    #heatmap =  plt.imshow(mat, cmap="Reds", interpolation='none', aspect='equal')

    im = plt.imshow(mat, cmap=cmap, interpolation='none', aspect='equal')

    ax = plt.gca()

    ax.set_xticks(np.arange(-.5, cols, 1), minor=True)
    ax.set_yticks(np.arange(-.5, rows, 1), minor=True)
    #ax.grid(which='minor', axis='both', linestyle='-', linewidth=5, color='k')
    # Gridlines based on minor ticks
    ax.grid(which='minor', color='k', linestyle='-', linewidth=5)
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_locator(plt.NullLocator())
    ax.xaxis.set_major_locator(plt.NullLocator())
    #cbar = plt.colorbar(heatmap)
    #cbar.ax.tick_params(labelsize=20)
    plt.tight_layout()
    if not filename:
        plt.show()
    else:
        plt.savefig(filename)
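
The gridworld background in plot_optimal_policy comes from a ListedColormap over integer cell labels, with minor ticks placed at -0.5 offsets purely to draw cell borders while the major ticks are suppressed. A minimal sketch of that trick, with a made-up 2x3 label matrix and colour list:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors

mat = [[0, 1, 2], [2, 1, 0]]                      # integer cell labels
cmap = colors.ListedColormap(['black', 'white', 'tab:red'])
plt.imshow(mat, cmap=cmap, interpolation='none', aspect='equal')
ax = plt.gca()
ax.set_xticks(np.arange(-.5, len(mat[0]), 1), minor=True)
ax.set_yticks(np.arange(-.5, len(mat), 1), minor=True)
ax.grid(which='minor', color='k', linestyle='-', linewidth=5)  # cell borders
ax.xaxis.set_major_locator(plt.NullLocator())                  # hide major ticks
ax.yaxis.set_major_locator(plt.NullLocator())
plt.show()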
Exemple #29
0
def draw_PIL_image_1(image,
                     ref_boxes,
                     boxes,
                     ref_labels,
                     labels,
                     scores,
                     pm,
                     name,
                     no=None):
    if type(image) != PIL.Image.Image:
        image = F.to_pil_image(image)
    plt.imshow(image)
    plt.axis('off')
    plt.gca().set_axis_off()
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    i = 0
    if no is not None:
        for n in no:
            if i < 1:
                color = 'green'
                x, y = ref_boxes[n][0], ref_boxes[n][1]
                w, h = ref_boxes[n][2] - ref_boxes[n][0], ref_boxes[n][
                    3] - ref_boxes[n][1]
                plt.text(x,
                         y,
                         '{}={}'.format(voc_labels[ref_labels[n] - 1],
                                        round(scores[n].item(), 2)),
                         color='white',
                         verticalalignment='bottom',
                         bbox={
                             'facecolor': color,
                             'alpha': 1.0
                         },
                         fontsize=24)
            else:
                color = 'red'
                x, y = boxes[n][0], boxes[n][1]
                w, h = boxes[n][2] - boxes[n][0], boxes[n][3] - boxes[n][1]
                plt.text(x,
                         y + h,
                         '{}={}'.format(voc_labels[labels[n] - 1],
                                        round(pm[n].item(), 2)),
                         color='white',
                         verticalalignment='bottom',
                         bbox={
                             'facecolor': color,
                             'alpha': 1.0
                         },
                         fontsize=24)
            i += 1
            plt.gca().add_patch(
                plt.Rectangle((x, y),
                              w,
                              h,
                              fill=False,
                              edgecolor=color,
                              linewidth=2.5))
    plt.savefig('fig/{}'.format(name),
                dpi=256,
                bbox_inches='tight',
                pad_inches=0)
    # plt.show()
    plt.cla()
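
Stripped of the detection-specific inputs, draw_PIL_image_1 boils down to a text label plus plt.Rectangle drawn over an image that is then saved with every margin removed. A minimal sketch with a blank placeholder image and one hard-coded box (the 'car=0.97' label is made up):

import numpy as np
import matplotlib.pyplot as plt

image = np.zeros((200, 300, 3), dtype=np.uint8)   # placeholder image
x, y, w, h = 60, 40, 120, 90                      # made-up box: (x, y, width, height)

plt.imshow(image)
plt.axis('off')
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.text(x, y, 'car=0.97', color='white',
         verticalalignment='bottom', bbox={'facecolor': 'green', 'alpha': 1.0})
plt.gca().add_patch(
    plt.Rectangle((x, y), w, h, fill=False, edgecolor='green', linewidth=2.5))
plt.savefig('box_demo.png', dpi=256, bbox_inches='tight', pad_inches=0)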
Exemple #30
0
def main():
    parser = argparse.ArgumentParser(description='Visualize eigenvalues and overlaps')
    parser.add_argument('resultsPath', help='Absolute path of the results folders')
    parser.add_argument('outputPath', help='outputPath')
    parser.add_argument('-title', help='title of the plot')
    parser.add_argument('-fileToLookFor_overlap', help='Specify the file with the overlap information')
    parser.add_argument('-fileToLookFor_differencesInRank', help='Specify the file with the differencesInRank information')    
    parser.add_argument('-modes', help='Specify how many modes to plot')
    parser.add_argument('-upperOverlapLimit', help='Upper overlap limit, force manually')
     
    if len(sys.argv)==1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()    
    
    if args.modes:
        modes = int(args.modes)
    else:
        modes = 4
        
    if args.title:
        title = args.title
    else:
        title = ""
        
    if args.outputPath:
        outputPath = args.outputPath
    else:
        outputPath = ""        
        
    fileToLookFor_overlap = "singleModeOverlapsFromSuperset.txt"
    fileToLookFor_differencesInRank = "differencesInRank.txt"
    if args.fileToLookFor_overlap:
        fileToLookFor_overlap = args.fileToLookFor_overlap
    if args.fileToLookFor_differencesInRank:
        fileToLookFor_differencesInRank = args.fileToLookFor_differencesInRank        
        
    assert os.path.isdir(args.resultsPath)
    assert os.path.isdir(args.outputPath)


    all340proteinsPaths = glob.glob(args.resultsPath+"*/")
    
    difficults = np.loadtxt("/home/oliwa/workspace/TNMA1/src/BenchmarkAssessmentsOfDifficulty/allinterfaceSuperposed/difficult.txt", dtype="string")
    difficults = set(difficults)    
    
    dataToPlot_overlaps = []
    dataToPlot_differencesInRank = []
    proteins = []
    counter = 0
    
    for proteinPath in sorted(all340proteinsPaths):
        proteinPath = makeStringEndWith(proteinPath, "/")
        
        protein = makeStringNotEndWith(os.path.basename(os.path.normpath(proteinPath)), "/")
        
        if protein not in difficults:
            continue
        counter += 1

        try:
            # load overlap
            overlap = np.loadtxt(proteinPath+fileToLookFor_overlap)
            overlap = overlap[:modes]
            overlap = abs(np.array(overlap))
            overlap = list(overlap)
            if args.upperOverlapLimit:
                for i in range(0, len(overlap)):
                    if overlap[i] > float(args.upperOverlapLimit):
                        overlap[i] = float(args.upperOverlapLimit)
            dataToPlot_overlaps.append(overlap)
            protein = os.path.basename(os.path.normpath(proteinPath))
            proteins.append(protein)
            # load ranking differences
            differenceInRank = np.loadtxt(proteinPath+fileToLookFor_differencesInRank, dtype="int")
            differenceInRank = list(differenceInRank)
            dataToPlot_differencesInRank.append(differenceInRank[:modes])

        except IOError as err:
            print "IOError occurred, probably there is no such file at the path: ", err
            print traceback.format_exc()           
            
            
    print proteins
            
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    #x, y = np.random.rand(2, 100) * 4
    y = range(1, len(proteins)+1)
    x = range(1, modes+1)
    
    xpos, ypos = np.meshgrid(x, y)
    x = xpos.flatten()
    y = ypos.flatten()
    
    colors = []
    print "overlaps len: ", len(dataToPlot_overlaps)
    print "overlaps: ", dataToPlot_overlaps    
    
    dataToPlot_overlaps_flattened = np.array(dataToPlot_overlaps).flatten()
    maxOverlap = max(dataToPlot_overlaps_flattened)
    print "maxOverlap:", maxOverlap
    
    for element in dataToPlot_overlaps_flattened:
        colors.append(plt.cm.jet(element/maxOverlap))
        #print plt.cm.jet(element/maxOverlap)

    print "x", len(x)
    print "y", len(y)
    #print "colors", len(colors)
    
    print "dataToPlot_differencesInRank len: ",dataToPlot_differencesInRank
    dataToPlot_differencesInRank = np.array(dataToPlot_differencesInRank).flatten() + 0.0001
    print "dataToPlot_differencesInRank len: ", len(dataToPlot_differencesInRank.flatten())
    
    dx=np.ones(len(x))*0.5
    dy=dx
    
    p = ax.bar3d(x-0.25, y-0.25, np.zeros(len(x)), dx, dy, dataToPlot_differencesInRank, color=colors, zsort='average')
    
    ax.set_zlim([min(dataToPlot_differencesInRank), max(dataToPlot_differencesInRank)])
    #ax.set_title(title)
    
    # x label for the ascending modes
    #ax.set_xticklabels(range(1, modes+1), minor=False)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    ax.set_xlabel("ascending lambda^R modes") 
    # y label for the proteins
    #ax.set_yticklabels(proteins, minor=False)
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    ax.set_ylabel("proteins")
# #     dataToPlot_overlaps = np.array(dataToPlot_overlaps)      
# #             
# #     fig, ax = plt.subplots(1)
# #     ax.set_yticklabels(proteins, minor=False)
# #     ax.xaxis.tick_top()    
# #     
# #     p = ax.pcolormesh(dataToPlot_overlaps, cmap="bone")
# #     fig.colorbar(p)  
# #     
# #     # put the major ticks at the middle of each cell, notice "reverse" use of dimension
# #     ax.set_yticks(np.arange(dataToPlot_overlaps.shape[0])+0.5, minor=False)
# #     ax.set_xticks(np.arange(dataToPlot_overlaps.shape[1])+0.5, minor=False)   
# #     
# #     # want a more natural, table-like display (sorting)
# #     ax.invert_yaxis()
# #     ax.xaxis.tick_top()
# #     
# #     ax.set_xticklabels(range(1, modes+1), minor=False)
# #     ax.set_yticklabels(proteins, minor=False) 
# #             
# #     if args.title:
# #         plt.title(args.title+"\n\n")
        
    # output
    #outputPath = makeStringEndWith(args.outputPath, "/")+"eigenVis"
    
    #mkdir_p(outputPath)
    plt.savefig(outputPath+'/eigenVis_'+title+'.eps', bbox_inches='tight')
    plt.savefig(outputPath+'/eigenVis_'+title+'.pdf', bbox_inches='tight') 
    #plt.show()
    # close and reset the plot 
    plt.clf()
    plt.cla()
    plt.close()         
    print "total proteins: ", counter