コード例 #1
0
def plot_isocontours_expected(ax, model, data, xlimits=[-6, 6], ylimits=[-6, 6],
                              numticks=101, cmap=None, alpha=1., legend=False,
                              n_samps=10):
    """Contour-plot the encoder density averaged over several inputs.

    For each of the first ``n_samps`` items of ``data`` the model's encoder
    returns ``(mean, logvar)``. ``exp(lognormal4(z, mean, logvar))`` is then
    evaluated on a regular 2-D grid of points ``z``. The per-item grids are
    averaged and drawn with ``ax.contour``.

    Args:
        ax: matplotlib axes that receives the contours, legend and styling.
        model: object with an ``encode(x)`` method returning ``(mean, logvar)``
            whose ``.data`` attributes are read.
        data: indexable collection; ``data[i]`` is unsqueezed to a batch of
            one before encoding.
        xlimits, ylimits: ``[low, high]`` extent of the grid on each axis.
        numticks: number of grid points per axis.
        cmap, alpha: forwarded to ``ax.contour``.
        legend: if True, add a small legend built from the contour levels.
        n_samps: maximum number of items to average over, clamped to
            ``len(data)``. Defaults to 10, the previously hard-coded value.

    Returns:
        numpy array of shape ``(numticks, numticks)`` holding the averaged grid.

    Raises:
        ValueError: if no items would be averaged (empty ``data`` or
            ``n_samps < 1``). Previously this surfaced as an unbound-local
            error.
    """
    n_samps = min(n_samps, len(data))
    if n_samps < 1:
        raise ValueError('need at least one sample to average, got n_samps=%d'
                         % n_samps)

    x = np.linspace(*xlimits, num=numticks)
    y = np.linspace(*ylimits, num=numticks)
    X, Y = np.meshgrid(x, y)
    # (numticks**2, 2) float tensor of grid points, one (x, y) pair per row.
    grid = torch.from_numpy(
        np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T
    ).type(torch.FloatTensor)

    sum_of_all = None
    for samp_i in range(n_samps):
        if samp_i % 1000 == 0:
            print(samp_i)  # progress indicator; print(x) works on Python 2 and 3
        mean, logvar = model.encode(Variable(torch.unsqueeze(data[samp_i], 0)))
        log_q = lognormal4(torch.Tensor(grid), torch.squeeze(mean.data),
                           torch.squeeze(logvar.data))
        density = torch.exp(log_q)
        sum_of_all = density if sum_of_all is None else sum_of_all + density

    Z = (sum_of_all / n_samps).view(X.shape).numpy()

    # Draw on the axes we were given rather than whatever plt considers current.
    cs = ax.contour(X, Y, Z, cmap=cmap, alpha=alpha)
    if legend:
        nm, lbl = cs.legend_elements()
        ax.legend(nm, lbl, fontsize=4)

    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_aspect('equal', adjustable='box')

    return Z
コード例 #2
0
def plot_isocontours_expected_norm_ind(ax, model, data, xlimits=[-6, 6], ylimits=[-6, 6],
                                       numticks=101, cmap=None, alpha=1., legend=False,
                                       n_samps=10, cs_to_use=None):
    """Draw one grid-normalised encoder-density contour set per input.

    For each of the first ``n_samps`` items of ``data`` the encoder's
    ``(mean, logvar)`` is fed to ``lognormal4`` on a regular 2-D grid. The
    log-values are normalised so their exponentials sum to 1 over the grid,
    and the result is contoured. Every item gets its own contour set on ``ax``.

    Args:
        ax: matplotlib axes that receives the contours and styling.
        model: object with an ``encode(x)`` method returning ``(mean, logvar)``.
        data: indexable collection; ``data[i]`` is unsqueezed to a batch of one.
        xlimits, ylimits: ``[low, high]`` extent of the grid on each axis.
        numticks: number of grid points per axis.
        cmap, alpha: forwarded to ``ax.contour``.
        legend: currently unused; kept for signature compatibility.
        n_samps: maximum number of items to plot, clamped to ``len(data)``.
        cs_to_use: optional contour set whose ``levels`` are reused, so several
            plots share the same contour levels.

    Returns:
        ``(Z, cs)``: the normalised grid of the LAST plotted item, shaped
        ``(numticks, numticks)``, and its contour set.

    Raises:
        ValueError: if no item would be plotted (empty ``data`` or
            ``n_samps < 1``). Previously ``Z``/``cs`` were left unbound.
    """
    n_samps = min(n_samps, len(data))
    if n_samps < 1:
        raise ValueError('need at least one sample to plot, got n_samps=%d'
                         % n_samps)

    x = np.linspace(*xlimits, num=numticks)
    y = np.linspace(*ylimits, num=numticks)
    X, Y = np.meshgrid(x, y)
    # (numticks**2, 2) float tensor of grid points, one (x, y) pair per row.
    grid = torch.from_numpy(
        np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T
    ).type(torch.FloatTensor)

    for samp_i in range(n_samps):
        if samp_i % 1000 == 0:
            print(samp_i)  # progress indicator; print(x) works on Python 2 and 3
        mean, logvar = model.encode(Variable(torch.unsqueeze(data[samp_i], 0)))
        log_q = lognormal4(torch.Tensor(grid), torch.squeeze(mean.data),
                           torch.squeeze(logvar.data)).numpy()

        # Normalise over the grid in log space. The max shift (log-sum-exp)
        # keeps np.exp from overflowing or underflowing.
        max_ = np.max(log_q)
        log_norm = np.log(np.sum(np.exp(log_q - max_))) + max_
        Z = np.exp(log_q - log_norm).reshape(X.shape)

        if cs_to_use is not None:
            cs = ax.contour(X, Y, Z, cmap=cmap, alpha=alpha, levels=cs_to_use.levels)
        else:
            cs = ax.contour(X, Y, Z, cmap=cmap, alpha=alpha)

    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_aspect('equal', adjustable='box')

    return Z, cs
コード例 #3
0
        # plot_isocontours2_exp_norm(ax, func, cmap='Greys', legend=legend,xlimits=xlimits,ylimits=ylimits,alpha=.2)
        plot_isocontours2_exp_norm(ax,
                                   func,
                                   cmap='Blues',
                                   legend=legend,
                                   xlimits=xlimits,
                                   ylimits=ylimits,
                                   alpha=1.)

        # plot_scatter(ax, samps=z ,xlimits=xlimits,ylimits=ylimits)
        # plot_kde(ax,samps=z,xlimits=xlimits,ylimits=ylimits,cmap='Blues')
        # plot_kde(ax,samps=z,xlimits=xlimits,ylimits=ylimits,cmap='Greens')

        mean, logvar = model.q_dist.get_mean_logvar(samp_torch)
        func = lambda zs: lognormal4(torch.Tensor(zs),
                                     torch.squeeze(mean.data.cpu()),
                                     torch.squeeze(logvar.data.cpu()))
        plot_isocontours(ax,
                         func,
                         cmap='Greens',
                         xlimits=xlimits,
                         ylimits=ylimits)

        # func = lambda zs: lognormal4(torch.Tensor(zs), torch.zeros(2), torch.zeros(2))
        # plot_isocontours(ax, func, cmap='Blues', alpha=.3,xlimits=xlimits,ylimits=ylimits)

        # #Plot prob
        # col +=1
        # ax = plt.subplot2grid((rows,cols), (samp_i,col), frameon=False)
        # Ws, logpW, logqW = model.sample_W()  #_ , [1], [1]
        # # func = lambda zs: lognormal4(torch.Tensor(zs), torch.squeeze(mean.data), torch.squeeze(logvar.data))
コード例 #4
0

        #Plot prob
        row +=1
        ax = plt.subplot2grid((rows,cols), (row, samp_i+1), frameon=False)

        func = lambda zs: model.logposterior_func(samp_torch,zs)
        # plot_isocontours2_exp_norm(ax, func, cmap='Greys', legend=legend,xlimits=xlimits,ylimits=ylimits,alpha=.2)
        plot_isocontours2_exp_norm(ax, func, cmap='Blues', legend=legend,xlimits=xlimits,ylimits=ylimits,alpha=1.)

        # plot_scatter(ax, samps=z ,xlimits=xlimits,ylimits=ylimits)
        # plot_kde(ax,samps=z,xlimits=xlimits,ylimits=ylimits,cmap='Blues')
        # plot_kde(ax,samps=z,xlimits=xlimits,ylimits=ylimits,cmap='Greens')

        mean, logvar = model.q_dist.get_mean_logvar(samp_torch)
        func = lambda zs: lognormal4(torch.Tensor(zs), torch.squeeze(mean.data.cpu()), torch.squeeze(logvar.data.cpu()))
        plot_isocontours(ax, func, cmap='Greens',xlimits=xlimits,ylimits=ylimits)



        # func = lambda zs: lognormal4(torch.Tensor(zs), torch.zeros(2), torch.zeros(2))
        # plot_isocontours(ax, func, cmap='Blues', alpha=.3,xlimits=xlimits,ylimits=ylimits)




        # #Plot prob
        # col +=1
        # ax = plt.subplot2grid((rows,cols), (samp_i,col), frameon=False)
        # Ws, logpW, logqW = model.sample_W()  #_ , [1], [1]   
        # # func = lambda zs: lognormal4(torch.Tensor(zs), torch.squeeze(mean.data), torch.squeeze(logvar.data))
コード例 #5
0
ファイル: bvae_pytorch3.py プロジェクト: ywa136/Other_Code
        model.load_state_dict(
            torch.load(path_to_save_variables, lambda storage, loc: storage))
        print 'loaded variables ' + path_to_save_variables

        rows = 1
        cols = 2

        legend = False
        # legend=True

        n_samps = 10000
        alpha = .3

        fig = plt.figure(figsize=(4 + cols, 4 + rows), facecolor='white')

        prior_func = lambda zs: lognormal4(torch.Tensor(zs), torch.zeros(2),
                                           torch.zeros(2))

        for samp_i in range(rows):

            #Get a sample
            samp = train_x[samp_i]
            # print samp.shape
            col = 0

            # #Plot sample
            # ax = plt.subplot2grid((rows,cols), (samp_i,col), frameon=False)
            # ax.imshow(samp.numpy().reshape(28, 28), vmin=0, vmax=1, cmap="gray")
            # ax.set_yticks([])
            # ax.set_xticks([])
            # if samp_i==0:  ax.annotate('Sample', xytext=(.3, 1.1), xy=(0, 1), textcoords='axes fraction')
コード例 #6
0
            if samp_i==0:  ax.annotate('Sample', xytext=(.3, 1.1), xy=(0, 1), textcoords='axes fraction')


            # #Plot prior
            # col +=1
            # ax = plt.subplot2grid((rows,cols), (samp_i,col), frameon=False)
            # func = lambda zs: lognormal4(torch.Tensor(zs), torch.zeros(2), torch.zeros(2))
            # plot_isocontours(ax, func, cmap='Blues')
            # if samp_i==0:  ax.annotate('Prior p(z)', xytext=(.3, 1.1), xy=(0, 1), textcoords='axes fraction')

            #Plot q
            col +=1
            val = 3
            ax = plt.subplot2grid((rows,cols), (samp_i,col), frameon=False)
            mean, logvar = model.encode(Variable(torch.unsqueeze(samp,0)))
            func = lambda zs: lognormal4(torch.Tensor(zs), torch.squeeze(mean.data), torch.squeeze(logvar.data))
            plot_isocontours(ax, func, cmap='Reds', xlimits=[-val, val], ylimits=[-val, val])
            if samp_i==0:  ax.annotate('p(z)\nq(z|x)\np(z|x)', xytext=(.3, 1.1), xy=(0, 1), textcoords='axes fraction')
            func = lambda zs: lognormal4(torch.Tensor(zs), torch.zeros(2), torch.zeros(2))
            plot_isocontours(ax, func, cmap='Blues', alpha=.3, xlimits=[-val, val], ylimits=[-val, val])
            
            Ws, logpW, logqW = model.sample_W()  #_ , [1], [1]   
            func = lambda zs: log_bernoulli(model.decode(Ws, Variable(torch.unsqueeze(zs,1))), Variable(torch.unsqueeze(samp,0)))+ Variable(torch.unsqueeze(lognormal4(torch.Tensor(zs), torch.zeros(2), torch.zeros(2)), 1))
            plot_isocontours2_exp_norm(ax, func, cmap='Greens', legend=legend, xlimits=[-val, val], ylimits=[-val, val])

            # #Plot logprior
            # col +=1
            # ax = plt.subplot2grid((rows,cols), (samp_i,col), frameon=False)
            # func = lambda zs: lognormal4(torch.Tensor(zs), torch.zeros(2), torch.zeros(2))
            # plot_isocontoursNoExp(ax, func, cmap='Blues', legend=legend)
            # if samp_i==0:  ax.annotate('Prior\nlogp(z)', xytext=(.3, 1.1), xy=(0, 1), textcoords='axes fraction')
コード例 #7
0
def plot_isocontours_expected_true_posterior_ind(ax,
                                                 model,
                                                 data,
                                                 xlimits=[-6, 6],
                                                 ylimits=[-6, 6],
                                                 numticks=101,
                                                 cmap=None,
                                                 alpha=1.,
                                                 legend=False,
                                                 n_samps=10,
                                                 cs_to_use=None,
                                                 n_Ws=1):
    """Draw one grid-normalised log-joint contour set per (input, weight sample).

    For each of the first ``n_samps`` items ``samp`` of ``data`` and each of
    ``n_Ws`` weight samples ``Ws`` from ``model.sample_W()``, the score
    ``log_bernoulli(model.decode(Ws, z), samp) + lognormal4(z, 0, 0)`` is
    evaluated at every grid point ``z``. The score is normalised so its
    exponential sums to 1 over the grid, then contoured.
    Presumably this is the posterior p(z|x) for one W draw under a
    standard-normal prior; confirm against the model definition.

    Args:
        ax: matplotlib axes that receives the contours and styling.
        model: object exposing ``sample_W()`` (first return value used) and
            ``decode(Ws, z)``.
        data: indexable collection; ``data[i]`` is unsqueezed to a batch of one.
        xlimits, ylimits: ``[low, high]`` extent of the grid on each axis.
        numticks: number of grid points per axis.
        cmap, alpha: forwarded to ``ax.contour``.
        legend: currently unused; kept for signature compatibility.
        n_samps: maximum number of items to plot, clamped to ``len(data)``.
        cs_to_use: optional contour set whose ``levels`` are reused.
        n_Ws: weight samples drawn per item. Defaults to 1, the previously
            hard-coded value.

    Returns:
        The normalised grid of the LAST plotted (item, W) pair, shaped
        ``(numticks, numticks)``.

    Raises:
        ValueError: if nothing would be plotted (empty ``data``,
            ``n_samps < 1`` or ``n_Ws < 1``). Previously ``Z`` was left
            unbound.
    """
    n_samps = min(n_samps, len(data))
    if n_samps < 1:
        raise ValueError('need at least one sample to plot, got n_samps=%d' %
                         n_samps)
    if n_Ws < 1:
        raise ValueError('need at least one weight sample, got n_Ws=%d' % n_Ws)

    x = np.linspace(*xlimits, num=numticks)
    y = np.linspace(*ylimits, num=numticks)
    X, Y = np.meshgrid(x, y)
    # (numticks**2, 2) float tensor of grid points, one (x, y) pair per row.
    grid = torch.from_numpy(
        np.concatenate([np.atleast_2d(X.ravel()),
                        np.atleast_2d(Y.ravel())]).T).type(torch.FloatTensor)

    for samp_i in range(n_samps):
        samp = data[samp_i]
        for _ in range(n_Ws):
            Ws, _logpW, _logqW = model.sample_W()
            log_lik = log_bernoulli(
                model.decode(Ws, Variable(torch.unsqueeze(grid, 1))),
                Variable(torch.unsqueeze(samp, 0)))
            # lognormal4 with zero mean and zero log-variance; presumably a
            # standard-normal prior term.
            log_prior = Variable(
                torch.unsqueeze(
                    lognormal4(torch.Tensor(grid), torch.zeros(2),
                               torch.zeros(2)), 1))
            log_joint = (log_lik + log_prior).data.numpy()

            # Log-sum-exp normalisation over the grid, max-shifted for stability.
            max_ = np.max(log_joint)
            log_norm = np.log(np.sum(np.exp(log_joint - max_))) + max_
            Z = np.exp(log_joint - log_norm).reshape(X.shape)

            if cs_to_use is not None:
                ax.contour(X,
                           Y,
                           Z,
                           cmap=cmap,
                           alpha=alpha,
                           levels=cs_to_use.levels)
            else:
                ax.contour(X, Y, Z, cmap=cmap, alpha=alpha)

    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_aspect('equal', adjustable='box')

    return Z
コード例 #8
0
def plot_isocontours_expected_W(ax,
                                model,
                                samp,
                                xlimits=[-6, 6],
                                ylimits=[-6, 6],
                                numticks=101,
                                cmap=None,
                                alpha=1.,
                                legend=False,
                                n_Ws=10):
    """Contour-plot a grid-normalised log-joint averaged over weight samples.

    For each of ``n_Ws`` weight samples ``Ws`` from ``model.sample_W()``, the
    score ``log_bernoulli(model.decode(Ws, z), samp) + lognormal4(z, 0, 0)``
    is evaluated at every grid point ``z``. It is exponentiated and
    normalised to sum to 1 over the grid. The ``n_Ws`` normalised grids are
    averaged and contoured. Presumably this approximates p(z|x) with W
    marginalised out; confirm against the model definition.

    Args:
        ax: matplotlib axes that receives the contours, legend and styling.
        model: object exposing ``sample_W()`` (first return value used) and
            ``decode(Ws, z)``.
        samp: a single input; unsqueezed to a batch of one.
        xlimits, ylimits: ``[low, high]`` extent of the grid on each axis.
        numticks: number of grid points per axis.
        cmap, alpha: forwarded to ``ax.contour``.
        legend: if True, add a small legend built from the contour levels.
        n_Ws: number of weight samples to average. Defaults to 10, the
            previously hard-coded value.

    Returns:
        The averaged grid, shaped ``(numticks, numticks)``. Previously the
        function returned None; callers that ignored the result are unaffected.

    Raises:
        ValueError: if ``n_Ws < 1``.
    """
    if n_Ws < 1:
        raise ValueError('need at least one weight sample, got n_Ws=%d' % n_Ws)

    x = np.linspace(*xlimits, num=numticks)
    y = np.linspace(*ylimits, num=numticks)
    X, Y = np.meshgrid(x, y)
    # (numticks**2, 2) float tensor of grid points, one (x, y) pair per row.
    grid = torch.from_numpy(
        np.concatenate([np.atleast_2d(X.ravel()),
                        np.atleast_2d(Y.ravel())]).T).type(torch.FloatTensor)

    sum_of_all = None
    for _ in range(n_Ws):
        Ws, _logpW, _logqW = model.sample_W()
        log_lik = log_bernoulli(
            model.decode(Ws, Variable(torch.unsqueeze(grid, 1))),
            Variable(torch.unsqueeze(samp, 0)))
        # lognormal4 with zero mean and zero log-variance; presumably a
        # standard-normal prior term.
        log_prior = Variable(
            torch.unsqueeze(
                lognormal4(torch.Tensor(grid), torch.zeros(2), torch.zeros(2)),
                1))
        log_joint = (log_lik + log_prior).data.numpy()

        # Log-sum-exp normalisation over the grid, max-shifted for stability.
        max_ = np.max(log_joint)
        log_norm = np.log(np.sum(np.exp(log_joint - max_))) + max_
        density = np.exp(log_joint - log_norm)

        sum_of_all = density if sum_of_all is None else sum_of_all + density

    Z = (sum_of_all / n_Ws).reshape(X.shape)

    cs = ax.contour(X, Y, Z, cmap=cmap, alpha=alpha)
    if legend:
        nm, lbl = cs.legend_elements()
        ax.legend(nm, lbl, fontsize=4)

    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_aspect('equal', adjustable='box')

    return Z
コード例 #9
0
def plot_isocontours_expected_true_posterior_ind(ax, model, data, xlimits=[-6, 6], ylimits=[-6, 6],
                                                 numticks=101, cmap=None, alpha=1., legend=False,
                                                 n_samps=10, cs_to_use=None, n_Ws=1):
    """Draw one grid-normalised log-joint contour set per (input, weight sample).

    For each of the first ``n_samps`` items ``samp`` of ``data`` and each of
    ``n_Ws`` weight samples ``Ws`` from ``model.sample_W()``, the score
    ``log_bernoulli(model.decode(Ws, z), samp) + lognormal4(z, 0, 0)`` is
    evaluated at every grid point ``z``. The score is normalised so its
    exponential sums to 1 over the grid, then contoured.
    Presumably this is the posterior p(z|x) for one W draw under a
    standard-normal prior; confirm against the model definition.

    Args:
        ax: matplotlib axes that receives the contours and styling.
        model: object exposing ``sample_W()`` (first return value used) and
            ``decode(Ws, z)``.
        data: indexable collection; ``data[i]`` is unsqueezed to a batch of one.
        xlimits, ylimits: ``[low, high]`` extent of the grid on each axis.
        numticks: number of grid points per axis.
        cmap, alpha: forwarded to ``ax.contour``.
        legend: currently unused; kept for signature compatibility.
        n_samps: maximum number of items to plot, clamped to ``len(data)``.
        cs_to_use: optional contour set whose ``levels`` are reused.
        n_Ws: weight samples drawn per item. Defaults to 1, the previously
            hard-coded value.

    Returns:
        The normalised grid of the LAST plotted (item, W) pair, shaped
        ``(numticks, numticks)``.

    Raises:
        ValueError: if nothing would be plotted (empty ``data``,
            ``n_samps < 1`` or ``n_Ws < 1``). Previously ``Z`` was left
            unbound.
    """
    n_samps = min(n_samps, len(data))
    if n_samps < 1:
        raise ValueError('need at least one sample to plot, got n_samps=%d' % n_samps)
    if n_Ws < 1:
        raise ValueError('need at least one weight sample, got n_Ws=%d' % n_Ws)

    x = np.linspace(*xlimits, num=numticks)
    y = np.linspace(*ylimits, num=numticks)
    X, Y = np.meshgrid(x, y)
    # (numticks**2, 2) float tensor of grid points, one (x, y) pair per row.
    grid = torch.from_numpy(
        np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T
    ).type(torch.FloatTensor)

    for samp_i in range(n_samps):
        if samp_i % 100 == 0:
            print(samp_i)  # progress indicator; print(x) works on Python 2 and 3

        samp = data[samp_i]
        for _ in range(n_Ws):
            Ws, _logpW, _logqW = model.sample_W()
            log_lik = log_bernoulli(model.decode(Ws, Variable(torch.unsqueeze(grid, 1))),
                                    Variable(torch.unsqueeze(samp, 0)))
            # lognormal4 with zero mean and zero log-variance; presumably a
            # standard-normal prior term.
            log_prior = Variable(torch.unsqueeze(
                lognormal4(torch.Tensor(grid), torch.zeros(2), torch.zeros(2)), 1))
            log_joint = (log_lik + log_prior).data.numpy()

            # Log-sum-exp normalisation over the grid, max-shifted for stability.
            max_ = np.max(log_joint)
            log_norm = np.log(np.sum(np.exp(log_joint - max_))) + max_
            Z = np.exp(log_joint - log_norm).reshape(X.shape)

            if cs_to_use is not None:
                ax.contour(X, Y, Z, cmap=cmap, alpha=alpha, levels=cs_to_use.levels)
            else:
                ax.contour(X, Y, Z, cmap=cmap, alpha=alpha)

    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_aspect('equal', adjustable='box')

    return Z
コード例 #10
0
def plot_isocontours_expected_W(ax, model, samp, xlimits=[-6, 6], ylimits=[-6, 6],
                                numticks=101, cmap=None, alpha=1., legend=False, n_Ws=10):
    """Contour-plot a grid-normalised log-joint averaged over weight samples.

    For each of ``n_Ws`` weight samples ``Ws`` from ``model.sample_W()``, the
    score ``log_bernoulli(model.decode(Ws, z), samp) + lognormal4(z, 0, 0)``
    is evaluated at every grid point ``z``. It is exponentiated and
    normalised to sum to 1 over the grid. The ``n_Ws`` normalised grids are
    averaged and contoured. Presumably this approximates p(z|x) with W
    marginalised out; confirm against the model definition.

    Args:
        ax: matplotlib axes that receives the contours, legend and styling.
        model: object exposing ``sample_W()`` (first return value used) and
            ``decode(Ws, z)``.
        samp: a single input; unsqueezed to a batch of one.
        xlimits, ylimits: ``[low, high]`` extent of the grid on each axis.
        numticks: number of grid points per axis.
        cmap, alpha: forwarded to ``ax.contour``.
        legend: if True, add a small legend built from the contour levels.
        n_Ws: number of weight samples to average. Defaults to 10, the
            previously hard-coded value.

    Returns:
        The averaged grid, shaped ``(numticks, numticks)``. Previously the
        function returned None; callers that ignored the result are unaffected.

    Raises:
        ValueError: if ``n_Ws < 1``.
    """
    if n_Ws < 1:
        raise ValueError('need at least one weight sample, got n_Ws=%d' % n_Ws)

    x = np.linspace(*xlimits, num=numticks)
    y = np.linspace(*ylimits, num=numticks)
    X, Y = np.meshgrid(x, y)
    # (numticks**2, 2) float tensor of grid points, one (x, y) pair per row.
    grid = torch.from_numpy(
        np.concatenate([np.atleast_2d(X.ravel()), np.atleast_2d(Y.ravel())]).T
    ).type(torch.FloatTensor)

    sum_of_all = None
    for i in range(n_Ws):
        if i % 10 == 0:
            print(i)  # progress indicator; print(x) works on Python 2 and 3

        Ws, _logpW, _logqW = model.sample_W()
        log_lik = log_bernoulli(model.decode(Ws, Variable(torch.unsqueeze(grid, 1))),
                                Variable(torch.unsqueeze(samp, 0)))
        # lognormal4 with zero mean and zero log-variance; presumably a
        # standard-normal prior term.
        log_prior = Variable(torch.unsqueeze(
            lognormal4(torch.Tensor(grid), torch.zeros(2), torch.zeros(2)), 1))
        log_joint = (log_lik + log_prior).data.numpy()

        # Log-sum-exp normalisation over the grid, max-shifted for stability.
        max_ = np.max(log_joint)
        log_norm = np.log(np.sum(np.exp(log_joint - max_))) + max_
        density = np.exp(log_joint - log_norm)

        sum_of_all = density if sum_of_all is None else sum_of_all + density

    Z = (sum_of_all / n_Ws).reshape(X.shape)

    cs = ax.contour(X, Y, Z, cmap=cmap, alpha=alpha)
    if legend:
        nm, lbl = cs.legend_elements()
        ax.legend(nm, lbl, fontsize=4)

    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_aspect('equal', adjustable='box')

    return Z
コード例 #11
0
            ax.set_xticks([])
            if samp_i==0:  ax.annotate('Sample', xytext=(.3, 1.1), xy=(0, 1), textcoords='axes fraction')


            # #Plot prior
            # col +=1
            # ax = plt.subplot2grid((rows,cols), (samp_i,col), frameon=False)
            # func = lambda zs: lognormal4(torch.Tensor(zs), torch.zeros(2), torch.zeros(2))
            # plot_isocontours(ax, func, cmap='Blues')
            # if samp_i==0:  ax.annotate('Prior p(z)', xytext=(.3, 1.1), xy=(0, 1), textcoords='axes fraction')

            #Plot q
            col +=1
            ax = plt.subplot2grid((rows,cols), (samp_i,col), frameon=False)
            mean, logvar = model.encode(Variable(torch.unsqueeze(samp,0)))
            func = lambda zs: lognormal4(torch.Tensor(zs), torch.squeeze(mean.data), torch.squeeze(logvar.data))
            plot_isocontours(ax, func, cmap='Reds',xlimits=xlimits,ylimits=ylimits)
            if samp_i==0:  ax.annotate('p(z)\nq(z|x)', xytext=(.3, 1.1), xy=(0, 1), textcoords='axes fraction')
            func = lambda zs: lognormal4(torch.Tensor(zs), torch.zeros(2), torch.zeros(2))
            plot_isocontours(ax, func, cmap='Blues', alpha=.3,xlimits=xlimits,ylimits=ylimits)


            # #Plot logprior
            # col +=1
            # ax = plt.subplot2grid((rows,cols), (samp_i,col), frameon=False)
            # func = lambda zs: lognormal4(torch.Tensor(zs), torch.zeros(2), torch.zeros(2))
            # plot_isocontoursNoExp(ax, func, cmap='Blues', legend=legend)
            # if samp_i==0:  ax.annotate('Prior\nlogp(z)', xytext=(.3, 1.1), xy=(0, 1), textcoords='axes fraction')

            # #Plot logq
            # col +=1
コード例 #12
0
        col +=1
        ax = plt.subplot2grid((rows,cols), (samp_i,col), frameon=False)
        z = model.sample_q(x=samp_torch, k=n_samps)# [P,B,Z]
        z = z.view(-1,z_size)
        z = z.data.cpu().numpy()
        # print (z)

        center_val_x = z[0][0]
        center_val_y = z[0][1]
        xlimits=[center_val_x-lim_val, center_val_x+lim_val]
        ylimits=[center_val_y-lim_val, center_val_y+lim_val]

        plot_scatter(ax, samps=z ,xlimits=xlimits,ylimits=ylimits)


        func = lambda zs: lognormal4(torch.Tensor(zs), torch.zeros(2), torch.zeros(2))
        plot_isocontours(ax, func, cmap='Blues', alpha=.3,xlimits=xlimits,ylimits=ylimits)




        # #Plot q
        # col +=1
        # ax = plt.subplot2grid((rows,cols), (samp_i,col), frameon=False)

        # # # mean, logvar = model.encode(Variable(torch.unsqueeze(samp,0)))


        # samp_torch = Variable(torch.from_numpy(np.array([samp]))).type(model.dtype)

        # # # [P,B,Z]
コード例 #13
0
                            xytext=(.3, 1.1),
                            xy=(0, 1),
                            textcoords='axes fraction')

            # #Plot prior
            # col +=1
            # ax = plt.subplot2grid((rows,cols), (samp_i,col), frameon=False)
            # func = lambda zs: lognormal4(torch.Tensor(zs), torch.zeros(2), torch.zeros(2))
            # plot_isocontours(ax, func, cmap='Blues')
            # if samp_i==0:  ax.annotate('Prior p(z)', xytext=(.3, 1.1), xy=(0, 1), textcoords='axes fraction')

            #Plot q
            col += 1
            ax = plt.subplot2grid((rows, cols), (samp_i, col), frameon=False)
            mean, logvar = model.encode(Variable(torch.unsqueeze(samp, 0)))
            func = lambda zs: lognormal4(torch.Tensor(
                zs), torch.squeeze(mean.data), torch.squeeze(logvar.data))
            plot_isocontours(ax,
                             func,
                             cmap='Reds',
                             xlimits=xlimits,
                             ylimits=ylimits)
            if samp_i == 0:
                ax.annotate('p(z)\nq(z|x)',
                            xytext=(.3, 1.1),
                            xy=(0, 1),
                            textcoords='axes fraction')
            func = lambda zs: lognormal4(torch.Tensor(zs), torch.zeros(2),
                                         torch.zeros(2))
            plot_isocontours(ax,
                             func,
                             cmap='Blues',