def _tcep_all_hyperparams_lineplots_withdir():
    """Draw one line plot per (test, hyper-parameter) pair on a 2-column grid.

    Each combination of ``tests`` and ``x_vals`` is delegated to
    ``_tcep_hyperparams_lineplots_withdir``, which receives the axis handed
    out by the grid. The figure is shown with ``plt.show()`` at the end.
    """
    tests = ["mmd-gamma", "c2st-knn", "c2st-nn", "test_loss"]
    x_vals = ["num_hiddens", "max_iter_factor"]
    display = GridDisplay(num_items=len(x_vals) * len(tests),
                          nrows=-1,
                          ncols=2)

    for t, xv in product(tests, x_vals):
        # Bind the loop variables as lambda defaults: a plain closure would
        # see only the *last* (t, xv) pair if the grid ever deferred the call.
        display.add_plot(
            callback=lambda ax, xv=xv, t=t:
            _tcep_hyperparams_lineplots_withdir(x=xv, test=t, ax=ax))

    plt.show()
    def plot_training_samples(self, num_displays=6):
        """Train the model and plot snapshots of the samples during training.

        Runs ``self.num_iters`` optimisation steps (zero grads, forward,
        loss, backward, optimiser step) and adds one subplot for each
        iteration listed in ``display_its``.

        Parameters:
            num_displays (int, default = 6):
                Number of snapshots requested; also drives the number of
                grid columns (5 above ten snapshots, 3 for three to ten,
                otherwise one column per snapshot).
        """
        assert self.data_is_set
        self.reset_net()

        # The original test was `num_displays > 2 & num_displays < 10`; `&`
        # binds tighter than the comparisons, so it never selected the
        # 5-column layout and left `ncols` unbound for 0 or 2 displays.
        if num_displays > 10:
            ncols = 5
        elif num_displays > 2:
            ncols = 3
        else:
            ncols = max(num_displays, 1)

        # Iterations at which a snapshot is drawn (set => O(1) membership).
        display_its = {
            int(t / self.lr)
            for t in torch.linspace(0, self.max_iter_factor, num_displays)
        }
        display = GridDisplay(num_items=num_displays, nrows=-1, ncols=ncols)

        colors = (10 * self._XY[:, 0]).cos() * (10 * self._XY[:, 1]).cos()
        colors = colors.detach().cpu().numpy()

        def draw_snapshot(ax, iteration):
            """Plot reference and generated samples for one iteration."""
            plt.set_cmap("hsv")
            display_samples(ax, self._XY, [(.55, .55, .95)])
            display_samples(ax, self._XY_hat, colors)
            ax.set_title("t = {:1.2f}".format(self.lr * iteration))

            plt.xticks([], [])
            plt.yticks([], [])
            plt.tight_layout()

        i = 0  # keeps the figure title well-defined when num_iters == 0
        for i in range(self.num_iters):
            # set grads to zero, then run the forward pass
            self._opt.zero_grad()
            self.forward()
            # compute loss & backward passes
            loss = self._L(self._XY, self._XY_hat)
            loss.backward()
            self._opt.step()

            if i in display_its:  # display
                print(f'loop display iter #{i}')
                # `iteration=i` freezes the current step for the title.
                display.add_plot(
                    callback=lambda ax, iteration=i: draw_snapshot(
                        ax, iteration))

        if self.loss == "sinkhorn":
            L_info = {
                "loss": self.loss,
                "p": self.p,
                "blur": self.blur,
                "scaling": self.scaling
            }
        else:
            L_info = {"loss": self.loss, "blur": self.blur}
        # NOTE(review): `int(1e03/self.lr)/1e03` looks like it was meant to
        # be the step size rounded to 3 decimals (`int(1e03*self.lr)/1e03`);
        # left unchanged, please confirm.
        display.fig.suptitle(
            (f"Gradient flows with loss {L_info};" +
             f"\n until T = {self.lr*i} (steps of {int(1e03/self.lr)/1e03})"),
            fontsize=10)
        display.fig.tight_layout(rect=[0, 0.03, 1, 0.93])
        plt.show()
# Example #3
# 0
def viz_confounded(save=True):
    """Plot sampled confounded pairs for every generator configuration.

    Iterates over the cartesian product of ANM flag, confounder type, the two
    base noises and the two mechanisms. For each configuration a
    ``ConfoundedDatasetSampler`` yields 5 pairs, which are drawn side by side
    together with the sampler's ``pSampler.x_sample`` / ``y_sample`` points
    (in red).

    Parameters:
        save (bool, default = True):
            If true, each figure is written through ``_write_nested`` and
            then closed; otherwise it is shown with ``plt.show()``.
    """
    SEED = 1020
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    causes = ['gmm', 'subgmm', 'supgmm', 'subsupgmm', 'uniform', 'mixtunif']
    base_noises = ['normal', 'student', 'triangular', 'uniform', 'beta']
    mechanisms = ['spline', 'sigmoidam', 'tanhsum', 'rbfgp']
    anms = [False, True]
    num_pairs = 5  # pairs per configuration == subplots per figure

    def draw_pair(ax, pair, sampler):
        """Scatter one pair plus the sampler's sorted (x, y) sample."""
        ax.scatter(pair[0],
                   pair[1],
                   s=10,
                   facecolor='none',
                   edgecolor='k')
        idx = np.argsort(sampler.pSampler.x_sample)
        ax.scatter(sampler.pSampler.x_sample[idx],
                   sampler.pSampler.y_sample[idx],
                   facecolor='r',
                   s=14,
                   alpha=0.7)

    for anm, c, bn_x, bn_y, m_x, m_y in product(anms, causes, base_noises,
                                                base_noises, mechanisms,
                                                mechanisms):
        print(
            f'anm? {anm}, cause: {c}, base_noise: {bn_x,bn_y}, mechanism: {m_x,m_y}'
        )
        DtSpl = ConfoundedDatasetSampler(N=num_pairs,
                                         n=1000,
                                         anm=anm,
                                         base_noise=[bn_x, bn_y],
                                         confounder_type=c,
                                         mechanism_type=[m_x, m_y],
                                         with_labels=False)

        display = GridDisplay(num_items=num_pairs, nrows=-1, ncols=num_pairs)
        for pair in DtSpl:
            # Defaults pin the current pair/sampler to this callback.
            display.add_plot(
                callback=lambda ax, pair=pair, sampler=DtSpl: draw_pair(
                    ax, pair, sampler))
        display.fig.suptitle(
            f'Confounded: anm? {anm}, cause: {c}, base_noise: {bn_x,bn_y}, mechanism: {m_x,m_y}',
            fontsize=20)
        display.fig.tight_layout(rect=[0, 0.03, 1, 0.93])
        if save:
            _write_nested(
                f'./tests/data/fcm_examples/pairs/cdf_anm_{anm}_c_{c}_bn_{bn_x}+{bn_y}_m_{m_x}+{m_y}',
                callback=lambda fp: plt.savefig(fp, dpi=70))
            # Thousands of configurations are rendered; release each figure
            # once it is on disk so they do not pile up in memory.
            plt.close(display.fig)
        else:
            plt.show()
# Example #4
# 0
def viz_pair(save=True):
    """Plot sampled cause/effect pairs for every generator configuration.

    Iterates over the cartesian product of ANM flag, cause type, base noise
    and mechanism. For each configuration a ``DatasetSampler`` yields 5
    pairs; each is scattered and overlaid with a ``UnivariateSpline`` fit
    (red dashed line).

    Parameters:
        save (bool, default = True):
            If true, each figure is written through ``_write_nested`` and
            then closed; otherwise it is shown with ``plt.show()``.
    """
    SEED = 1020
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    causes = ['gmm', 'subgmm', 'supgmm', 'subsupgmm', 'uniform', 'mixtunif']
    base_noises = ['normal', 'student', 'triangular', 'uniform', 'beta']
    mechanisms = ['spline', 'sigmoidam', 'tanhsum', 'rbfgp']
    anms = [False, True]
    num_pairs = 5  # pairs per configuration == subplots per figure

    def draw_pair(ax, pair):
        """Scatter one pair and overlay a spline fitted on the sorted data."""
        ax.scatter(pair[0],
                   pair[1],
                   s=10,
                   facecolor='none',
                   edgecolor='k')
        idx = np.argsort(pair[0])
        x, y = pair[0][idx], pair[1][idx]
        spl = UnivariateSpline(x, y)
        x_display = np.linspace(x.min(), x.max(), 1000)
        ax.plot(x_display, spl(x_display), 'r--')

    for anm, c, bn, m in product(anms, causes, base_noises, mechanisms):
        print(f'anm? {anm}, cause: {c}, base_noise: {bn}, mechanism: {m}')
        DtSpl = DatasetSampler(N=num_pairs,
                               n=1000,
                               anm=anm,
                               base_noise=bn,
                               cause_type=c,
                               mechanism_type=m,
                               with_labels=False)

        display = GridDisplay(num_items=num_pairs, nrows=-1, ncols=num_pairs)
        for pair in DtSpl:
            # The default pins the current pair to this callback.
            display.add_plot(
                callback=lambda ax, pair=pair: draw_pair(ax, pair))
        display.fig.suptitle(
            f'anm? {anm}, cause: {c}, base_noise: {bn}, mechanism: {m}',
            fontsize=20)
        display.fig.tight_layout(rect=[0, 0.03, 1, 0.93])
        if save:
            _write_nested(
                f'./tests/data/fcm_examples/pairs/anm_{anm}_c_{c}_bn_{bn}_m_{m}',
                callback=lambda fp: plt.savefig(fp, dpi=70))
            # Hundreds of figures are rendered; close each one once it is
            # written so they do not accumulate in memory.
            plt.close(display.fig)
        else:
            plt.show()
def _grid_plotter_onefunc(x, y, e, slope_item):
    """Plot, per basis-function index, the single-function Slope fit.

    For every index ``i`` in ``1 .. slope_item._nfuncs - 1`` one subplot
    shows the noiseless curve ``y`` (black dashed), the generic fit
    ``slope_item._forward(x)`` (blue dash-dot), the fit restricted to index
    ``i`` (green) and the noisy observations ``y + e`` (red circles).

    Parameters:
        x, y, e: inputs, noiseless outputs and noise; ``x`` and ``y + e``
            must have the same type. Plots are drawn only when ``x`` is a
            ``torch.Tensor`` or a ``numpy.ndarray``.
        slope_item: object providing ``_nfuncs``, ``_forward``,
            ``_fit_index`` and ``_forward_index``.
    """
    assert type(x) == type(y + e)
    display = GridDisplay(num_items=slope_item._nfuncs, nrows=-1, ncols=3)

    # Sort once up front instead of re-sorting for every plotted curve.
    if isinstance(x, torch.Tensor):
        x_sorted, order = x.sort()
        announce = True  # only the tensor path logged progress originally
    elif isinstance(x, np.ndarray):
        order = np.argsort(x)
        x_sorted = x[order]
        announce = False
    else:
        order = None  # unsupported type: no subplot is drawn

    def draw(ax, i):
        """Draw the true curve, the generic fit and the index-``i`` fit."""
        ax.plot(x_sorted, y[order], 'k--', lw=2)
        y_slope = slope_item._forward(x)
        ax.plot(x_sorted, y_slope[order], 'b-.', lw=2)

        # Refit with function #i only, then plot that restricted fit.
        slope_item._fit_index(x, y + e, i)
        y_i = slope_item._forward_index(x, i)
        ax.plot(x_sorted, y_i[order], 'g-1', lw=1.3, alpha=0.7)
        ax.scatter(x, y + e, facecolor='none', edgecolor='r')
        ax.set_title(
            _index_to_function(slope_item._nfuncs, del_nan=True)[i])

    if order is not None:
        for i in range(1, slope_item._nfuncs):
            if announce:
                print(f'func #{i}')
            # `i=i` pins the current index to this callback.
            display.add_plot(callback=lambda ax, i=i: draw(ax, i))

    display.fig.suptitle(r'Slope on ANM data $Y= f(X)+N$', fontsize=20)
    display.fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()
def _all_synthetic_boxplots():
    """Draw one benchmark boxplot per file of the synthetic results folder.

    Every file in ``dirname`` is loaded with ``load_array`` and rendered by
    ``_basic_benchmark_boxplot`` on a 10-column grid. Only the last subplot
    is drawn with a legend; its handles are promoted to a figure-level
    legend and the per-axes legend is removed.
    """
    dirname = './tests/data/geom_ot/data_lengths/synthetic'
    # List the directory once so the grid size and the loop always agree.
    filenames = list(listdir(dirname))
    num_items = len(filenames)
    display = GridDisplay(num_items=num_items, nrows=-1, ncols=10)

    def draw(ax, data, fname, legend):
        """Render one boxplot and title it from the parsed file name."""
        _basic_benchmark_boxplot(data, ax=ax, legend=legend, font_scale=0.5)
        ax.set_title('-'.join(anm_name_parser(fname)), fontsize=8)

    # for now can't remove legend from every subplot...
    for position, fname in enumerate(filenames, start=1):
        data = load_array(dirname + '/' + fname)
        is_last = position == num_items
        # Defaults pin this iteration's data/file name to the callback.
        display.add_plot(
            callback=lambda ax, data=data, fname=fname, legend=is_last: draw(
                ax, data, fname, legend=legend))
        if is_last:
            handles, labels = display.last_ax.get_legend_handles_labels()
            display.fig.legend(handles, labels, loc='lower right', fontsize=14)
            display.last_ax.get_legend().remove()
            print('update legend')

    display.fig.tight_layout(pad=3)
    plt.show()
# Example #7
# 0
def viz_cause(num=10):
    """Show histograms of samples drawn from each cause distribution.

    For each of the six ``CauseSampler`` generators, ``num`` independent
    samples of size 1000 are drawn and plotted on a 5-column grid; one figure
    is shown per generator.

    Parameters:
        num (int, default = 10):
            Number of samples (subplots) per figure. The grid is now sized
            from this value; it was previously fixed at 10.
    """
    sample_size = 1000
    # (CauseSampler method name, figure title), in display order.
    panels = [
        ('uniform', 'Uniform'),
        ('uniform_mixture', 'Uniform Mixture'),
        ('gaussian_mixture', 'Gaussian Mixture'),
        ('subgaussian_mixture', 'Sub Gaussian Mixture'),
        ('supergaussian_mixture', 'Super Gaussian Mixture'),
        ('subsupgaussian_mixture', 'Sub & Super Gaussian Mixture'),
    ]

    def callback(ax, X, i):
        """Plot the distribution of ``X`` on ``ax`` using colour ``C{i}``."""
        hist_vals, _ = np.histogram(X, bins='auto', density=True)
        sns.distplot(X, ax=ax, color=f'C{i}')
        # Pad the x-range by one std and the y-range by 7% of the peak.
        low_x, up_x = X.min() - X.std(), X.max() + X.std()
        low_y, up_y = 0, hist_vals.max() * 1.07
        plt.axis([low_x, up_x, low_y, up_y])
        plt.xticks([], [])
        plt.yticks([], [])
        plt.tight_layout()

    for method_name, title in panels:
        display = GridDisplay(num_items=num, nrows=-1, ncols=5)
        for i in range(num):
            s = CauseSampler(sample_size=sample_size)
            X = getattr(s, method_name)()
            # Defaults pin this iteration's sample and colour index.
            display.add_plot(
                callback=lambda ax, X=X, i=i: callback(ax, X, i))

        display.fig.suptitle(title, fontsize=20)
        display.fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()
def _grid_plotter_mixed(x, y, e, slope_item):
    """Plot the Slope fit for every non-empty combination of basis functions.

    For each integer ``i`` in ``1 .. 2**slope_item._nfuncs - 1`` the bit
    pattern of ``i`` selects a subset of functions; the subplot shows the
    noiseless curve (black dashed), the generic fit (blue dash-dot), the
    mixed fit for that subset (green) and the noisy observations (red).

    The original body read an undefined name ``slope_f`` in several places;
    the ``slope_item`` argument is now used throughout.

    Parameters:
        x, y, e: inputs, noiseless outputs and noise; ``x`` and ``y + e``
            must have the same type. Plots are drawn only when ``x`` is a
            ``torch.Tensor`` or a ``numpy.ndarray``.
        slope_item: object providing ``_nfuncs``, ``_forward``,
            ``_fit_mixed`` and ``_forward_mixed``.
    """
    assert type(x) == type(y + e)

    combinations = range(1, 2**(slope_item._nfuncs))
    display = GridDisplay(num_items=len(combinations),
                          nrows=-1,
                          ncols=8,
                          rowsize=4,
                          colsize=4)

    # Sort once up front; the per-type styling differences of the original
    # two branches (title size, scatter alpha) are preserved as-is.
    if isinstance(x, torch.Tensor):
        x_sorted, order = x.sort()
        title_fontsize = 8
        scatter_style = dict(facecolor='none', edgecolor='r')
    elif isinstance(x, np.ndarray):
        order = np.argsort(x)
        x_sorted = x[order]
        title_fontsize = 3
        scatter_style = dict(facecolor='none', edgecolor='r', alpha=0.7)
    else:
        order = None  # unsupported type: no subplot is drawn

    def draw(ax, i):
        """Fit the subset encoded by ``i`` and draw all curves on ``ax``."""
        bool_idx = _bin_int_as_array(i, slope_item._nfuncs)
        _res_mixed = slope_item._fit_mixed(x, y + e, bool_idx)

        ax.plot(x_sorted, y[order], 'k--', lw=2)  # true func
        y_slope = slope_item._forward(x)
        ax.plot(x_sorted, y_slope[order], 'b-.', lw=2)  # generic fit

        y_mixed = slope_item._forward_mixed(x, bool_idx)
        ax.plot(x_sorted, y_mixed[order], 'g-1', lw=1.3, alpha=0.7)
        ax.scatter(x, y + e, **scatter_style)
        ax.set_title('+'.join(_res_mixed['str_idx']),
                     fontsize=title_fontsize)

    if order is not None:
        for i in combinations:
            # `i=i` pins the current combination to this callback.
            display.add_plot(callback=lambda ax, i=i: draw(ax, i))

    display.fig.suptitle(r'Slope-Mixed on ANM data $Y= f(X)+N$', fontsize=15)
    display.fig.tight_layout(rect=[0, 0.03, 1, 0.92])
    plt.subplots_adjust(left=None,
                        bottom=None,
                        right=None,
                        top=None,
                        wspace=0.2,
                        hspace=0.2)
    plt.show()
# Example #9
# 0
import seaborn as sns


def callback(ax, x, y, name):
    """Scatter one ``(x, y)`` pair on ``ax`` and strip the tick marks.

    ``name`` is accepted for the caller's convenience but is not used.
    """
    marker_style = dict(s=15, alpha=0.4, c='k')
    ax.scatter(x, y, **marker_style)
    # Clear tick locations and labels on both axes via pyplot.
    for clear_ticks in (plt.xticks, plt.yticks):
        clear_ticks([], [])
    plt.tight_layout()


def process(row):
    """Return ``(scale(row['A']), scale(row['B']))`` for one dataset row."""
    return tuple(scale(row[column]) for column in ('A', 'B'))


# Script: draw every pair of the 'tuebingen' dataset on one 10-column grid.
data, labels = load_dataset('tuebingen', shuffle=False)
labels = labels.values
# Return value is discarded; presumably trims `data` in place — TODO confirm.
cut_num_pairs(data, num_max=5000)
display = GridDisplay(num_items=len(data), ncols=10, nrows=-1)

for i, row in data.iterrows():
    x, y = process(row)
    # NOTE(review): the lambda closes over x, y and i by reference; this is
    # only correct if add_plot invokes the callback immediately — confirm.
    display.add_plot(callback=(lambda ax: callback(ax, x, y, name=i)))

display.fig.tight_layout()
plt.show()
# display.savefig(f'./tests/data/tcep/tcep_pairs')
# Script fragment: set-up for fitting a Gaussian mixture to `data`.
# `data`, `dtype` and `dim` are defined outside the visible part of the file.
low, up = data.min(), data.max()
x = data.type(dtype)

# model

# 30 is the first positional argument of GaussianMixture
# (presumably the number of components — TODO confirm).
model = GaussianMixture(30, sparsity=1, D=dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-01)

num_iters = 500

# One slot per iteration, initialised to zero.
loss = np.zeros(num_iters)

# display utilities
# Iterations at which a snapshot is meant to be drawn (8 entries -> 8 panels).
display_its = [0, 10, 50, 100, 150, 250, 350, 499]
# default num cols is 4
display = GridDisplay(num_items=len(display_its), nrows=-1, ncols=3)


def callback(ax, iter, dim, low, up):
    """Draw the current model density on ``ax`` for iteration ``iter``.

    When ``dim == 2`` and the module-level ``anm_flag`` is truthy, the
    mechanism ``mech`` is evaluated on 1000 points of [0, 1], standardised,
    min-max rescaled and overlaid as a black dashed line. ``low`` and ``up``
    are accepted but not used by the visible body.
    """
    plt.pause(.05)
    model.plot(x, ax=ax)
    ax.set_title(f'Density, iter {iter!s}')

    if dim == 2 and anm_flag:
        grid = np.linspace(0, 1, 1000)
        curve = mech(grid)
        # Standardise along axis 0, then squash into [0, 1].
        centred = curve - curve.mean(0)
        curve = centred / curve.std(0)
        curve_min = curve.min()
        curve_span = curve.max() - curve.min()
        curve = (curve - curve_min) / curve_span
        ax.plot(grid, curve, 'k--')
# Example #11
# 0
def gradient_flow(loss, lr=.05, loss_info=None):
    """Move the samples ``X_i`` along the gradient of ``loss`` and plot snapshots.

    The particle positions are a clone of the module-level ``X_i``; the
    reference samples are a clone of the module-level ``Y_j``. Although the
    original description mentions a simple Euler scheme, the update actually
    performed is a ``torch.optim.Adam`` step (the explicit Euler update is
    commented out below).

    Parameters:
        loss ((x_i,y_j) -> torch float number):
            Real-valued loss function.
        lr (float, default = .05):
            Learning rate, i.e. time step.
        loss_info (default = None):
            Only interpolated into the figure title.

    Also reads the module-level names ``c``, ``m`` and ``bn`` for the title.
    """

    # Parameters for the gradient descent
    total_ = 5
    num_displays = 15
    Nsteps = int(total_ / lr) + 1  # base was 5/lr
    # Iteration indices at which a snapshot is drawn.
    display_its = [
        int(t / lr) for t in torch.linspace(0, total_, num_displays)
    ]
    print(f'display its: {display_its}')
    # Use colors to identify the particles
    colors = (10 * X_i[:, 0]).cos() * (10 * X_i[:, 1]).cos()
    colors = colors.detach().cpu().numpy()

    # Make sure that we won't modify the reference samples
    x_i, y_j = X_i.clone(), Y_j.clone()

    # We're going to perform gradient descent on Loss(α, β)
    # wrt. the positions x_i of the diracs masses that make up α:
    x_i.requires_grad = True
    # try using Adam
    optim = torch.optim.Adam([x_i], lr=lr)

    t_0 = time.time()
    display = GridDisplay(num_items=num_displays, nrows=-1, ncols=5)
    for i in range(Nsteps):  # Euler scheme ===============
        # Compute cost and gradient
        L_αβ = loss(x_i, y_j)
        optim.zero_grad()
        L_αβ.backward()
        #[g]  = torch.autograd.grad(L_αβ, [x_i])
        if i in display_its:  # display
            print(f'loop iter {i}')

            #print(f'check gradient and loss magnitudes: {g.norm().data, L_αβ.data}')
            #print(f'check gradient and loss magnitudes: {x_i.grad.data, L_αβ.data}')
            # NOTE(review): the title reads `i` from the enclosing scope and
            # x_i is updated in place by optim.step(); the snapshot is only
            # faithful if add_plot runs the callback immediately — confirm.
            def callback(ax, x_i, y_j, colors):
                plt.set_cmap("hsv")
                display_samples(ax, y_j, [(.55, .55, .95)])
                display_samples(ax, x_i, colors)
                ax.set_title("t = {:1.2f}".format(lr * i))

                #plt.axis([0,1,0,1])
                #plt.gca().set_aspect('equal', adjustable='box')
                plt.xticks([], [])
                plt.yticks([], [])
                plt.tight_layout()

            display.add_plot(
                callback=(lambda ax: callback(ax, x_i, y_j, colors)))

        # in-place modification of the tensor's values
        #x_i.data -= lr * len(x_i) * g
        optim.step()
    # Title: loss description, final time, and mean wall-clock seconds per
    # iteration truncated to 3 decimals.
    display.fig.suptitle((
        f"Gradient flows with loss {loss_info};" +
        f"\n T = {lr*i}, elapsed time: {int(1e03*(time.time() - t_0)/Nsteps)/1e03}s/it"
        + f"\n on ANM data {c,m,bn}"),
                         fontsize=10)
    display.fig.tight_layout(rect=[0, 0.03, 1, 0.93])