def update_mu(self, z, wik, lr_mu, tau_mu, g_index=-1, max_iter=150):
    N, D, M = z.shape + (wik.shape[-1],)
    # If there are many Gaussian components, compute their means in chunks of 40
    # to avoid expanding too large a tensor in memory.
    if M > 40:
        for g_index in range(M // 40 + (1 if (M % 40 != 0) else 0)):
            from_ = g_index * 40
            to_ = min((g_index + 1) * 40, M)
            zz = z.unsqueeze(1).expand(N, to_ - from_, D)
            self._mu[from_:to_] = pa.barycenter(zz, wik[:, from_:to_], lr_mu, tau_mu,
                                                max_iter=max_iter, verbose=True,
                                                normed=True).squeeze()
    else:
        if g_index >= 0:
            # Update a single component.
            self._mu[g_index] = pa.barycenter(z, wik[:, g_index], lr_mu, tau_mu,
                                              max_iter=max_iter, normed=True).squeeze()
        else:
            # Update all components in one batched call.
            self._mu = pa.barycenter(z.unsqueeze(1).expand(N, M, D), wik, lr_mu, tau_mu,
                                     max_iter=max_iter, normed=True).squeeze()
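What `pa.barycenter` computes here is a weighted Fréchet (Karcher) mean on the Poincaré ball: the point minimising the weighted sum of squared hyperbolic distances to the samples. Below is a minimal sketch of that computation under the assumption of curvature -1, using explicit exponential and logarithm maps. The helper names are illustrative only and not part of `rcome`, whose implementation may use a different optimisation scheme (the `lr_mu`/`tau_mu` arguments suggest a gradient method with a convergence threshold).

import torch

def mobius_add(x, y, eps=1e-9):
    # Moebius addition on the Poincare ball (curvature -1).
    xy = (x * y).sum(-1, keepdim=True)
    x2 = (x * x).sum(-1, keepdim=True)
    y2 = (y * y).sum(-1, keepdim=True)
    num = (1 + 2 * xy + y2) * x + (1 - x2) * y
    den = 1 + 2 * xy + x2 * y2
    return num / (den + eps)

def log_map(mu, z, eps=1e-9):
    # Logarithm map at mu: tangent vector pointing from mu towards z.
    w = mobius_add(-mu, z)
    w_norm = w.norm(2, -1, keepdim=True).clamp_min(eps)
    lam = 2 / (1 - (mu * mu).sum(-1, keepdim=True)).clamp_min(eps)
    return (2 / lam) * torch.atanh(w_norm.clamp(max=1 - 1e-6)) * w / w_norm

def exp_map(mu, v, eps=1e-9):
    # Exponential map at mu: follow the geodesic from mu along v.
    v_norm = v.norm(2, -1, keepdim=True).clamp_min(eps)
    lam = 2 / (1 - (mu * mu).sum(-1, keepdim=True)).clamp_min(eps)
    return mobius_add(mu, torch.tanh(lam * v_norm / 2) * v / v_norm)

def weighted_karcher_mean(z, w, n_iter=50):
    # z: (N, D) points inside the ball, w: (N,) non-negative weights.
    mu = (w.unsqueeze(-1) * z).sum(0) / w.sum()   # Euclidean start, stays inside the ball
    for _ in range(n_iter):
        # Average the data in the tangent space at mu, then map back to the ball.
        tangent = (w.unsqueeze(-1) * log_map(mu, z)).sum(0) / w.sum()
        mu = exp_map(mu, tangent)
    return mu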
def _fast_maximisation(self, x, indexes, batch_size=10):
    N, D = x.size(0), x.size(-1)
    start_time = time.time()
    centroids = x.new(self._n_c, x.size(-1))
    barycenter_time = 0
    mask_matrix = torch.zeros(x.size(0), self._n_c).to(x.device)
    for i in range(self._n_c):
        lx = x[indexes == i]
        if lx.shape[0] <= self._mec:
            # Cluster i has too few points: force a randomly drawn point into it
            # so that every cluster keeps at least one member.
            lx = x[random.randint(0, len(x) - 1)].unsqueeze(0)
            indexes[random.randint(0, len(x) - 1)] = i
        mask_matrix[indexes == i, i] = 1
    # Compute the centroids in batches: one weighted-barycenter call per batch,
    # using the one-hot columns of mask_matrix as weights.
    nb_batch = (self._n_c // batch_size) + (1 if (self._n_c % batch_size) != 0 else 0)
    for i in range(nb_batch):
        start_index = i * batch_size
        end_index = min((i + 1) * batch_size, self._n_c)
        weight = mask_matrix[:, start_index:end_index]
        barycenter_start_time = time.time()
        centroids[start_index:end_index] = pa.barycenter(
            x.unsqueeze(1).expand(N, end_index - start_index, D),
            wik=weight, normed=True, lr=5e-3, tau=1e-3, verbose=self.verbose)
        barycenter_end_time = time.time()
        barycenter_time += (barycenter_end_time - barycenter_start_time)
    end_time = time.time()
    if self.verbose:
        print("Fast Maximisation Time ", end_time - start_time)
        print("Cumulate barycenter Time ", barycenter_time)
    return centroids
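The `mask_matrix` built above is simply a one-hot encoding of the cluster assignments. A small illustrative equivalent, on assumed toy shapes (6 points, 3 clusters):

import torch

indexes = torch.tensor([0, 2, 1, 0, 2, 2])                     # hypothetical assignments
mask_matrix = torch.nn.functional.one_hot(indexes, 3).float()  # shape (6, 3)
# Column k is 1 exactly on the points assigned to cluster k, so a single weighted
# barycenter call over a slice of columns yields several centroids at once.
print(mask_matrix)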
def _maximisation(self, x, indexes):
    start_time = time.time()
    centroids = x.new(self._n_c, x.size(-1))
    barycenter_time = 0
    for i in range(self._n_c):
        lx = x[indexes == i]
        if lx.shape[0] <= self._mec:
            # Cluster i has too few points: fall back to a single random point.
            lx = x[random.randint(0, len(x) - 1)].unsqueeze(0)
        barycenter_start_time = time.time()
        centroids[i] = pa.barycenter(lx, normed=True)
        barycenter_end_time = time.time()
        barycenter_time += (barycenter_end_time - barycenter_start_time)
    end_time = time.time()
    if self.verbose:
        print("Maximisation Time ", end_time - start_time)
        print("Cumulate barycenter Time ", barycenter_time)
    return centroids
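This maximisation step expects cluster assignments (`indexes`) produced by an expectation/assignment step. Here is a minimal sketch of such a step using the closed-form distance of the Poincaré ball; the function names are illustrative and not the class's actual API:

import torch

def poincare_distance(x, y, eps=1e-9):
    # Closed-form distance on the Poincare ball (curvature -1).
    diff2 = ((x - y) ** 2).sum(-1)
    denom = (1 - (x * x).sum(-1)).clamp_min(eps) * (1 - (y * y).sum(-1)).clamp_min(eps)
    return torch.acosh((1 + 2 * diff2 / denom).clamp_min(1.0))

def assign_to_centroids(x, centroids):
    # x: (N, D) points, centroids: (K, D). Returns, for every point, the index of
    # the hyperbolically closest centroid (the assignment step of hyperbolic K-means).
    d = poincare_distance(x.unsqueeze(1), centroids.unsqueeze(0))  # (N, K)
    return d.argmin(-1)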
# Inside the training loop: accumulate the loss, take an optimiser step,
# then report the cumulated loss for the current iteration.
loss = tree_embedding_criterion(pe_x, pe_y, z=ne, manifold=manifold).sum()
tloss += loss.item()
loss.backward()
optimizer.step()
print('Loss value for iteration ', i, ' is ', tloss)

from rcome.function_tools import poincare_alg as pa

# Multi-label membership matrix: weights[i, y] = 1 if node i carries label y (13 classes).
weights = torch.Tensor([[1 if (y in dataset.Y[i]) else 0 for y in range(13)]
                        for i in range(len(X))]).cuda()

# Hyperbolic (Fréchet) barycenter of each class in the Poincaré disc.
barycenters = []
for i in range(13):
    barycenters.append(pa.barycenter(model.weight.data, weights[:, i], verbose=True).cpu())

plot_poincare_disc_embeddings(model.weight.data.cpu().numpy(),
                              labels=dataset.Y,
                              centroids=torch.cat(barycenters),
                              save_folder="LOG/mean",
                              file_name="LFR_hierachical.png")
print(barycenters)

# Euclidean mean of each class, for comparison.
barycenters = []
for i in range(13):
    barycenters.append(model.weight.data[weights[:, i] == 1].mean(0).cpu().unsqueeze(0))
print(barycenters)

plot_poincare_disc_embeddings(model.weight.data.cpu().numpy(),
                              labels=dataset.Y,
import torch
import os
from rcome.visualisation_tools.plot_tools import plot_poincare_disc_embeddings
from rcome.function_tools import poincare_alg as pa
from matplotlib import pyplot as plt

# Sample 50 points in 2D, shifted towards the border of the Poincaré disc.
data_point = (torch.randn(50, 2) / 5 + 0.6)
print(data_point)

# Project any point that falls outside the unit disc back to just inside its boundary.
norms = data_point.norm(2, -1)
data_point[norms >= 1.] = torch.einsum('ij, i -> ij',
                                       data_point[norms >= 1.],
                                       1 / (norms[norms >= 1.] + 1e-5))
data_point.norm(2, -1)  # norms are now all strictly below 1 (unused sanity check)

# Hyperbolic (Fréchet) barycenter on the Poincaré disc.
barycenter_hyperbolic = pa.barycenter(data_point, verbose=True)
print(barycenter_hyperbolic)

plot_poincare_disc_embeddings(data_point.numpy(), close=False,
                              save_folder="LOG/mean",
                              file_name="example_mean_hyperbolic.png")

# Euclidean mean, for comparison.
barycenter_euclidean = data_point.mean(0, keepdim=True)

plt.scatter(barycenter_hyperbolic[:, 0], barycenter_hyperbolic[:, 1],
            marker='D', s=300., c='red', label="Hyperbolic")
plt.scatter(barycenter_euclidean[:, 0],