def _tcep_all_hyperparams_lineplots_withdir():
    """Show one hyper-parameter line plot per (test, x-axis) combination.

    Builds a two-column ``GridDisplay`` with one cell for every pairing of a
    test name with a hyper-parameter name, fills each cell through
    ``_tcep_hyperparams_lineplots_withdir`` and finally calls ``plt.show()``.
    """
    tests = ["mmd-gamma", "c2st-knn", "c2st-nn", "test_loss"]
    x_vals = ["num_hiddens", "max_iter_factor"]
    display = GridDisplay(num_items=len(x_vals) * len(tests),
                          nrows=-1,
                          ncols=2)
    for t, xv in product(tests, x_vals):
        # Bind the loop variables as lambda defaults: a plain closure would
        # look them up when called, so a callback invoked after the loop has
        # advanced would plot the wrong (test, x) pair.
        display.add_plot(
            callback=lambda ax, xv=xv, t=t:
            _tcep_hyperparams_lineplots_withdir(x=xv, test=t, ax=ax))
    plt.show()
def plot_training_samples(self, num_displays=6):
    """Train the network and plot snapshots of the samples along the way.

    The net is reset, then optimised for ``self.num_iters`` steps on the loss
    ``self._L(self._XY, self._XY_hat)``. At ``num_displays`` iterations,
    spread evenly over ``[0, self.max_iter_factor]`` in units of ``self.lr``,
    the target samples and the current generated samples are drawn into one
    cell of a grid figure. The figure is shown at the end.

    Parameters:
        num_displays (int, default = 6): number of snapshots (grid cells).
            Fewer than 10 snapshots use 3 columns, 10 or more use 5.
    """
    assert self.data_is_set
    self.reset_net()

    # The original test `num_displays > 2 & num_displays < 10` parsed as
    # `num_displays > (2 & num_displays) < 10` because `&` binds tighter than
    # comparisons; it also left `ncols` unbound for some values.
    ncols = 5 if num_displays >= 10 else 3

    display_its = {
        int(t / self.lr)
        for t in torch.linspace(0, self.max_iter_factor, num_displays)
    }
    display = GridDisplay(num_items=num_displays, nrows=-1, ncols=ncols)

    # Colour each sample by its position so that individual samples can be
    # told apart in the scatter plots.
    colors = (10 * self._XY[:, 0]).cos() * (10 * self._XY[:, 1]).cos()
    colors = colors.detach().cpu().numpy()

    i = 0  # keeps the title below well-defined when num_iters == 0
    for i in range(self.num_iters):
        # set grads to zero, run the forward pass
        self._opt.zero_grad()
        self.forward()
        # compute loss & backward pass
        loss = self._L(self._XY, self._XY_hat)
        loss.backward()
        self._opt.step()

        if i in display_its:
            print(f'loop display iter #{i}')

            # `step=i` freezes the iteration number for the subplot title
            # even if the callback runs after the loop has moved on.
            def callback(ax, step=i):
                plt.set_cmap("hsv")
                display_samples(ax, self._XY, [(.55, .55, .95)])
                display_samples(ax, self._XY_hat, colors)
                ax.set_title("t = {:1.2f}".format(self.lr * step))
                plt.xticks([], [])
                plt.yticks([], [])
                plt.tight_layout()

            display.add_plot(callback=callback)

    if self.loss == "sinkhorn":
        L_info = {
            "loss": self.loss,
            "p": self.p,
            "blur": self.blur,
            "scaling": self.scaling
        }
    else:
        L_info = {"loss": self.loss, "blur": self.blur}

    display.fig.suptitle(
        (f"Gradient flows with loss {L_info};" +
         f"\n until T = {self.lr*i} (steps of {int(1e03/self.lr)/1e03})"),
        fontsize=10)
    display.fig.tight_layout(rect=[0, 0.03, 1, 0.93])
    plt.show()
def viz_confounded(save=True):
    """Plot example confounded pairs for every sampler configuration.

    Iterates over all combinations of ANM flag, confounder type, the two base
    noises and the two mechanisms. For each combination, five datasets are
    drawn from ``ConfoundedDatasetSampler`` and scattered in a 1x5 grid, with
    the sampler's ``pSampler.x_sample`` / ``pSampler.y_sample`` overlaid in
    red.

    Parameters:
        save (bool, default = True): if true, each figure is written through
            ``_write_nested`` under ``./tests/data/fcm_examples/pairs/``;
            otherwise it is shown interactively.
    """
    SEED = 1020
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    causes = ['gmm', 'subgmm', 'supgmm', 'subsupgmm', 'uniform', 'mixtunif']
    base_noises = ['normal', 'student', 'triangular', 'uniform', 'beta']
    mechanisms = ['spline', 'sigmoidam', 'tanhsum', 'rbfgp']
    anms = [False, True]

    for anm, c, bn_x, bn_y, m_x, m_y in product(anms, causes, base_noises,
                                                base_noises, mechanisms,
                                                mechanisms):
        print(
            f'anm? {anm}, cause: {c}, base_noise: {bn_x,bn_y}, mechanism: {m_x,m_y}'
        )
        DtSpl = ConfoundedDatasetSampler(N=5,
                                         n=1000,
                                         anm=anm,
                                         base_noise=[bn_x, bn_y],
                                         confounder_type=c,
                                         mechanism_type=[m_x, m_y],
                                         with_labels=False)
        display = GridDisplay(num_items=5, nrows=-1, ncols=5)

        def callback(ax, pair, sampler=DtSpl):
            ax.scatter(pair[0],
                       pair[1],
                       s=10,
                       facecolor='none',
                       edgecolor='k')
            # Sampler state is read when the callback runs, as before.
            idx = np.argsort(sampler.pSampler.x_sample)
            ax.scatter(sampler.pSampler.x_sample[idx],
                       sampler.pSampler.y_sample[idx],
                       facecolor='r',
                       s=14,
                       alpha=0.7)

        for pair in DtSpl:
            # Bind `pair` now; a bare closure would see whichever pair the
            # loop variable holds at call time.
            display.add_plot(callback=lambda ax, pair=pair: callback(ax, pair))

        display.fig.suptitle(
            f'Confounded: anm? {anm}, cause: {c}, base_noise: {bn_x,bn_y}, mechanism: {m_x,m_y}',
            fontsize=20)
        display.fig.tight_layout(rect=[0, 0.03, 1, 0.93])
        if save:
            _write_nested(
                f'./tests/data/fcm_examples/pairs/cdf_anm_{anm}_c_{c}_bn_{bn_x}+{bn_y}_m_{m_x}+{m_y}',
                callback=lambda fp: plt.savefig(fp, dpi=70))
        else:
            plt.show()
def viz_pair(save=True):
    """Plot example cause-effect pairs for every sampler configuration.

    Iterates over all combinations of ANM flag, cause type, base noise and
    mechanism. For each combination, five datasets are drawn from
    ``DatasetSampler`` and scattered in a 1x5 grid, each with a
    ``UnivariateSpline`` fit of y against sorted x drawn as a dashed red line.

    Parameters:
        save (bool, default = True): if true, each figure is written through
            ``_write_nested`` under ``./tests/data/fcm_examples/pairs/``;
            otherwise it is shown interactively.
    """
    SEED = 1020
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    causes = ['gmm', 'subgmm', 'supgmm', 'subsupgmm', 'uniform', 'mixtunif']
    base_noises = ['normal', 'student', 'triangular', 'uniform', 'beta']
    mechanisms = ['spline', 'sigmoidam', 'tanhsum', 'rbfgp']
    anms = [False, True]

    def callback(ax, pair):
        ax.scatter(pair[0], pair[1], s=10, facecolor='none', edgecolor='k')
        # Sort by x before fitting the spline.
        idx = np.argsort(pair[0])
        x, y = pair[0][idx], pair[1][idx]
        spl = UnivariateSpline(x, y)
        x_display = np.linspace(x.min(), x.max(), 1000)
        ax.plot(x_display, spl(x_display), 'r--')

    for anm, c, bn, m in product(anms, causes, base_noises, mechanisms):
        print(f'anm? {anm}, cause: {c}, base_noise: {bn}, mechanism: {m}')
        DtSpl = DatasetSampler(N=5,
                               n=1000,
                               anm=anm,
                               base_noise=bn,
                               cause_type=c,
                               mechanism_type=m,
                               with_labels=False)
        display = GridDisplay(num_items=5, nrows=-1, ncols=5)
        for pair in DtSpl:
            # Bind `pair` now; a bare closure would see whichever pair the
            # loop variable holds at call time.
            display.add_plot(callback=lambda ax, pair=pair: callback(ax, pair))

        display.fig.suptitle(
            f'anm? {anm}, cause: {c}, base_noise: {bn}, mechanism: {m}',
            fontsize=20)
        display.fig.tight_layout(rect=[0, 0.03, 1, 0.93])
        if save:
            _write_nested(
                f'./tests/data/fcm_examples/pairs/anm_{anm}_c_{c}_bn_{bn}_m_{m}',
                callback=lambda fp: plt.savefig(fp, dpi=70))
        else:
            plt.show()
def _grid_plotter_onefunc(x, y, e, slope_item):
    """Plot, per basis function, the single-function fit against the data.

    For every function index ``i`` in ``1 .. slope_item._nfuncs - 1`` one grid
    cell shows: the noiseless curve ``y`` (black dashed), the output of
    ``slope_item._forward(x)`` (blue dash-dot), the output of
    ``slope_item._forward_index(x, i)`` after ``slope_item._fit_index`` was
    called on the noisy data (green), and the noisy observations ``y + e``
    (red circles).

    Parameters:
        x, y, e: inputs, noiseless outputs and noise; ``x`` and ``y + e`` must
            have the same type. Plots are only drawn when ``x`` is a
            ``torch.Tensor`` or a ``np.ndarray``; any other type yields an
            empty grid.
        slope_item: object providing ``_nfuncs``, ``_forward``, ``_fit_index``
            and ``_forward_index``.
    """
    assert type(x) == type(y + e)
    display = GridDisplay(num_items=slope_item._nfuncs, nrows=-1, ncols=3)

    is_tensor = isinstance(x, torch.Tensor)
    if is_tensor or isinstance(x, np.ndarray):
        # Sort once up front instead of re-sorting for every curve drawn.
        if is_tensor:
            sort_result = x.sort()
            x_sorted, order = sort_result.values, sort_result.indices
        else:
            order = np.argsort(x)
            x_sorted = x[order]

        def callback(ax, i):
            ax.plot(x_sorted, y[order], 'k--', lw=2)
            y_slope = slope_item._forward(x)
            ax.plot(x_sorted, y_slope[order], 'b-.', lw=2)
            slope_item._fit_index(x, y + e, i)
            y_i = slope_item._forward_index(x, i)
            ax.plot(x_sorted, y_i[order], 'g-1', lw=1.3, alpha=0.7)
            ax.scatter(x, y + e, facecolor='none', edgecolor='r')
            ax.set_title(
                _index_to_function(slope_item._nfuncs, del_nan=True)[i])

        for i in range(1, slope_item._nfuncs):
            if is_tensor:
                # Progress output was only emitted for tensor inputs.
                print(f'func #{i}')
            # Bind `i` now so each cell keeps its own function index.
            display.add_plot(callback=lambda ax, i=i: callback(ax, i))

    display.fig.suptitle(r'Slope on ANM data $Y= f(X)+N$', fontsize=20)
    display.fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()
def _all_synthetic_boxplots():
    """Draw one benchmark boxplot per file of the synthetic results folder.

    Every entry of ``./tests/data/geom_ot/data_lengths/synthetic`` is loaded
    with ``load_array`` and drawn with ``_basic_benchmark_boxplot`` into a
    10-column grid. Only the last subplot is drawn with a legend; its handles
    are promoted to a figure-level legend and the subplot legend is removed.
    """
    dirname = './tests/data/geom_ot/data_lengths/synthetic'
    # List the directory once so the grid size and the loop always agree.
    filenames = listdir(dirname)
    num_items = len(filenames)
    display = GridDisplay(num_items=num_items, nrows=-1, ncols=10)

    def callback(ax, data, f, legend):
        _basic_benchmark_boxplot(data, ax=ax, legend=legend, font_scale=0.5)
        ax.set_title('-'.join(anm_name_parser(f)), fontsize=8)

    # for now can't remove legend from every subplot...
    for display_idx, f in enumerate(filenames, start=1):
        data = load_array(dirname + '/' + f)
        is_last = display_idx == num_items
        # Bind the per-file values now rather than closing over loop names.
        display.add_plot(
            callback=lambda ax, data=data, f=f, legend=is_last: callback(
                ax, data, f, legend=legend))
        if is_last:
            handles, labels = display.last_ax.get_legend_handles_labels()
            display.fig.legend(handles,
                               labels,
                               loc='lower right',
                               fontsize=14)
            display.last_ax.get_legend().remove()
            print('update legend')

    display.fig.tight_layout(pad=3)
    plt.show()
def viz_cause(num=10):
    """Show density plots of samples from each cause distribution family.

    For every ``CauseSampler`` family (uniform, uniform mixture, Gaussian
    mixture, sub-/super-/sub&super-Gaussian mixture) a separate figure is
    shown, containing ``num`` independent samples of size 1000, each drawn
    with ``sns.distplot`` in its own grid cell.

    Parameters:
        num (int, default = 10): number of samples (grid cells) per family.
    """

    def callback(ax, X, i):
        hist_vals, _ = np.histogram(X, bins='auto', density=True)
        sns.distplot(X, ax=ax, color=f'C{i}')
        # Pad the x-range by one standard deviation and leave 7% headroom
        # above the tallest histogram bar.
        low_x, up_x = X.min() - X.std(), X.max() + X.std()
        low_y, up_y = 0, hist_vals.max() * 1.07
        plt.axis([low_x, up_x, low_y, up_y])
        plt.xticks([], [])
        plt.yticks([], [])
        plt.tight_layout()

    # (CauseSampler method name, figure title), in display order.
    families = [
        ('uniform', 'Uniform'),
        ('uniform_mixture', 'Uniform Mixture'),
        ('gaussian_mixture', 'Gaussian Mixture'),
        ('subgaussian_mixture', 'Sub Gaussian Mixture'),
        ('supergaussian_mixture', 'Super Gaussian Mixture'),
        ('subsupgaussian_mixture', 'Sub & Super Gaussian Mixture'),
    ]

    for method_name, title in families:
        # Size the grid from `num` (previously hard-coded to 10).
        display = GridDisplay(num_items=num, nrows=-1, ncols=5)
        for i in range(num):
            s = CauseSampler(sample_size=1000)
            X = getattr(s, method_name)()
            # Bind this iteration's sample and colour index to the callback.
            display.add_plot(
                callback=lambda ax, X=X, i=i: callback(ax, X, i))
        display.fig.suptitle(title, fontsize=20)
        display.fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()
def _grid_plotter_mixed(x, y, e, slope_item):
    """Plot the mixed fit for every non-empty combination of basis functions.

    Each integer ``i`` in ``1 .. 2**slope_item._nfuncs - 1`` is turned into a
    boolean selection with ``_bin_int_as_array``; ``slope_item._fit_mixed`` is
    fitted on the noisy data for that selection. One grid cell then shows the
    noiseless curve ``y`` (black dashed), ``slope_item._forward(x)`` (blue
    dash-dot), ``slope_item._forward_mixed(x, bool_idx)`` (green) and the
    noisy observations ``y + e`` (red circles).

    Parameters:
        x, y, e: inputs, noiseless outputs and noise; ``x`` and ``y + e`` must
            have the same type. Plots are only drawn when ``x`` is a
            ``torch.Tensor`` or a ``np.ndarray``.
        slope_item: object providing ``_nfuncs``, ``_forward``, ``_fit_mixed``
            and ``_forward_mixed``.
    """
    assert type(x) == type(y + e)
    # The body used to reference an out-of-scope `slope_f`; everything now
    # goes through the `slope_item` argument.
    nfuncs = slope_item._nfuncs
    combinations = range(1, 2**nfuncs)
    display = GridDisplay(num_items=len(combinations),
                          nrows=-1,
                          ncols=8,
                          rowsize=4,
                          colsize=4)

    is_tensor = isinstance(x, torch.Tensor)
    if is_tensor or isinstance(x, np.ndarray):
        if is_tensor:
            sort_result = x.sort()
            x_sorted, order = sort_result.values, sort_result.indices
            # Styling differed between the two input types; kept as it was.
            scatter_kwargs = {}
            title_size = 8
        else:
            order = np.argsort(x)
            x_sorted = x[order]
            scatter_kwargs = {'alpha': 0.7}
            title_size = 3

        def callback(ax, i):
            bool_idx = _bin_int_as_array(i, nfuncs)
            _res_mixed = slope_item._fit_mixed(x, y + e, bool_idx)
            # true func
            ax.plot(x_sorted, y[order], 'k--', lw=2)
            # generic fit
            y_slope = slope_item._forward(x)
            ax.plot(x_sorted, y_slope[order], 'b-.', lw=2)
            # fit restricted to the selected functions
            y_mixed = slope_item._forward_mixed(x, bool_idx)
            ax.plot(x_sorted, y_mixed[order], 'g-1', lw=1.3, alpha=0.7)
            ax.scatter(x,
                       y + e,
                       facecolor='none',
                       edgecolor='r',
                       **scatter_kwargs)
            ax.set_title('+'.join(_res_mixed['str_idx']),
                         fontsize=title_size)

        for i in combinations:
            # Bind `i` now so each cell keeps its own combination index.
            display.add_plot(callback=lambda ax, i=i: callback(ax, i))

    display.fig.suptitle(r'Slope-Mixed on ANM data $Y= f(X)+N$', fontsize=15)
    display.fig.tight_layout(rect=[0, 0.03, 1, 0.92])
    plt.subplots_adjust(left=None,
                        bottom=None,
                        right=None,
                        top=None,
                        wspace=0.2,
                        hspace=0.2)
    plt.show()
import seaborn as sns


def callback(ax, x, y, name):
    """Scatter one (x, y) pair on ``ax`` with ticks hidden.

    ``name`` is accepted for the caller's convenience but is currently
    unused.
    """
    ax.scatter(x, y, s=15, alpha=0.4, c='k')
    plt.xticks([], [])
    plt.yticks([], [])
    plt.tight_layout()


def process(row):
    """Return the ``'A'`` and ``'B'`` entries of ``row``, each passed through ``scale``."""
    return scale(row['A']), scale(row['B'])


data, labels = load_dataset('tuebingen', shuffle=False)
labels = labels.values
# NOTE(review): the return value is discarded, so this presumably trims
# `data` in place -- confirm against cut_num_pairs.
cut_num_pairs(data, num_max=5000)

display = GridDisplay(num_items=len(data), ncols=10, nrows=-1)
for i, row in data.iterrows():
    x, y = process(row)
    # Bind this row's values now; a bare closure would read the loop
    # variables at call time.
    display.add_plot(
        callback=lambda ax, x=x, y=y, i=i: callback(ax, x, y, name=i))
display.fig.tight_layout()
plt.show()
# display.savefig(f'./tests/data/tcep/tcep_pairs')
# Overall value range of the data.
low, up = data.min(), data.max()
x = data.type(dtype)

# model: mixture fitted to `x`, optimised with Adam
# NOTE(review): the first argument (30) is presumably the number of mixture
# components -- confirm against GaussianMixture.
model = GaussianMixture(30, sparsity=1, D=dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-01)
num_iters = 500
# One slot per iteration.
loss = np.zeros(num_iters)

# display utilities: iterations at which a snapshot is meant to be drawn
display_its = [0, 10, 50, 100, 150, 250, 350, 499]
# default num cols is 4; three columns are requested explicitly here
display = GridDisplay(num_items=len(display_its), nrows=-1, ncols=3)


def callback(ax, iter, dim, low, up):
    # Draws the current model density on `ax` and titles it with `iter`.
    # `low` and `up` are accepted but not used in the visible body.
    plt.pause(.05)
    model.plot(x, ax=ax)
    ax.set_title('Density, iter ' + str(iter))
    #plt.axis("equal")
    if dim == 2:
        if anm_flag:
            # Overlay the mechanism evaluated on [0, 1]: standardise it along
            # axis 0, then min-max rescale to [0, 1] before plotting.
            linsp = np.linspace(0, 1, 1000)
            y_linsp = mech(linsp)
            y_linsp = (y_linsp - y_linsp.mean(0)) / y_linsp.std(0)
            y_linsp = (y_linsp - y_linsp.min()) / (y_linsp.max() - y_linsp.min())
            ax.plot(linsp, y_linsp, 'k--')
def gradient_flow(loss, lr=.05, loss_info=None):
    """Move the samples ``X_i`` towards ``Y_j`` by minimising ``loss``.

    A clone of the module-level ``X_i`` is optimised with Adam on
    ``loss(x_i, y_j)``, where ``y_j`` is a clone of the module-level ``Y_j``.
    Fifteen snapshots, evenly spaced over the time span ``[0, 5]`` (time =
    ``lr`` * step), are drawn into a 5-column grid.

    Besides ``X_i`` and ``Y_j`` the function reads the module-level names
    ``c``, ``m`` and ``bn``, which only appear in the figure title.

    Parameters:
        loss ((x_i,y_j) -> torch float number): Real-valued loss function.
        lr (float, default = .05): Learning rate, i.e. time step.
        loss_info (any, default = None): description of the loss, inserted
            verbatim in the figure title.
    """
    # Parameters for the gradient descent
    total_ = 5
    num_displays = 15
    Nsteps = int(total_ / lr) + 1
    display_its = [
        int(t / lr) for t in torch.linspace(0, total_, num_displays)
    ]
    print(f'display its: {display_its}')

    # Use colors to identify the particles
    colors = (10 * X_i[:, 0]).cos() * (10 * X_i[:, 1]).cos()
    colors = colors.detach().cpu().numpy()

    # Make sure that we won't modify the reference samples
    x_i, y_j = X_i.clone(), Y_j.clone()

    # Gradient descent on the loss w.r.t. the positions x_i, stepped by Adam
    # rather than by an explicit Euler update.
    x_i.requires_grad = True
    optim = torch.optim.Adam([x_i], lr=lr)

    t_0 = time.time()
    display = GridDisplay(num_items=num_displays, nrows=-1, ncols=5)
    for i in range(Nsteps):
        # Compute cost and gradient
        loss_value = loss(x_i, y_j)
        optim.zero_grad()
        loss_value.backward()

        if i in display_its:
            print(f'loop iter {i}')

            # `step=i` freezes the iteration number for the subplot title
            # even if the callback runs after the loop has moved on.
            def callback(ax, x_i, y_j, colors, step=i):
                plt.set_cmap("hsv")
                display_samples(ax, y_j, [(.55, .55, .95)])
                display_samples(ax, x_i, colors)
                ax.set_title("t = {:1.2f}".format(lr * step))
                plt.xticks([], [])
                plt.yticks([], [])
                plt.tight_layout()

            display.add_plot(
                callback=(lambda ax: callback(ax, x_i, y_j, colors)))

        # The snapshot above is taken before this step's update is applied.
        optim.step()

    display.fig.suptitle((
        f"Gradient flows with loss {loss_info};" +
        f"\n T = {lr*i}, elapsed time: {int(1e03*(time.time() - t_0)/Nsteps)/1e03}s/it"
        + f"\n on ANM data {c,m,bn}"),
                         fontsize=10)
    display.fig.tight_layout(rect=[0, 0.03, 1, 0.93])