Example #1
0
    def update_mu(self, z, wik, lr_mu, tau_mu, g_index=-1, max_iter=150,
                  chunk_size=40):
        """Recompute the gaussian means ``self._mu`` as weighted barycenters.

        Each mean is obtained with ``pa.barycenter`` (``normed=True``) over the
        points ``z`` weighted by the matching column of ``wik``.

        Args:
            z: 2-D tensor of points, unpacked as ``N, D = z.shape``.
            wik: per-point weights; its last dimension is the number of
                gaussians ``M`` and ``wik[:, k]`` weights gaussian ``k``.
            lr_mu: forwarded to ``pa.barycenter`` as its third positional
                argument (presumably a learning rate -- confirm in ``pa``).
            tau_mu: forwarded as the fourth positional argument (presumably a
                stopping tolerance -- confirm in ``pa``).
            g_index: when ``>= 0``, only ``self._mu[g_index]`` is updated in
                place, whatever the number of gaussians; otherwise every mean
                is recomputed.
            max_iter: forwarded to ``pa.barycenter``.
            chunk_size: when ``M > chunk_size``, means are computed at most
                ``chunk_size`` gaussians at a time and written in place into
                the existing ``self._mu``. This keeps the expanded copy of
                ``z`` small.

        Raises:
            ValueError: if ``chunk_size < 1``.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be >= 1, got %r" % (chunk_size,))
        N, D = z.shape
        M = wik.shape[-1]

        # Single-gaussian update. It is handled before chunking so that the
        # caller's ``g_index`` is honoured even when there are many gaussians.
        if g_index >= 0:
            self._mu[g_index] = pa.barycenter(z,
                                              wik[:, g_index],
                                              lr_mu,
                                              tau_mu,
                                              max_iter=max_iter,
                                              normed=True).squeeze()
            return

        if M > chunk_size:
            # Too many gaussians to expand ``z`` for all of them at once, so
            # process them in slices of at most ``chunk_size`` gaussians.
            for from_ in range(0, M, chunk_size):
                to_ = min(from_ + chunk_size, M)
                n_chunk = to_ - from_
                zz = z.unsqueeze(1).expand(N, n_chunk, D)
                # NOTE(review): verbose=True is hard-coded on this path only;
                # kept as-is.
                # reshape (not squeeze) keeps the slice (n_chunk, D) when
                # n_chunk == 1 or D == 1. This assumes barycenter returns
                # n_chunk * D values -- confirm in ``pa``.
                self._mu[from_:to_] = pa.barycenter(
                    zz,
                    wik[:, from_:to_],
                    lr_mu,
                    tau_mu,
                    max_iter=max_iter,
                    verbose=True,
                    normed=True).reshape(n_chunk, D)
            return

        # reshape (not squeeze) keeps the gaussian axis when M == 1, so
        # ``self._mu`` stays (M, D), matching the per-gaussian indexing above.
        self._mu = pa.barycenter(z.unsqueeze(1).expand(N, M, D),
                                 wik,
                                 lr_mu,
                                 tau_mu,
                                 max_iter=max_iter,
                                 normed=True).reshape(M, D)
    def _fast_maximisation(self, x, indexes, batch_size=10):
        """Compute all cluster centroids with batched ``pa.barycenter`` calls.

        Builds a 0/1 membership matrix (one column per cluster) from
        ``indexes``. It then solves up to ``batch_size`` clusters per
        barycenter call, using each cluster's column as its weights.

        Args:
            x: tensor of points; ``x.size(0)`` points, ``x.size(-1)`` coords.
            indexes: per-point cluster assignment, compared against the
                cluster ids ``0 .. self._n_c - 1``. NOTE: it is mutated in
                place. A cluster with at most ``self._mec`` members gets one
                randomly chosen point reassigned to it.
            batch_size: number of clusters per ``pa.barycenter`` call.

        Returns:
            Tensor with ``self._n_c`` rows of ``x.size(-1)`` coordinates.
        """
        N, D = x.size(0), x.size(-1)
        start_time = time.time()
        centroids = x.new_empty((self._n_c, D))
        barycenter_time = 0
        # Allocate on x's device directly instead of on the CPU plus a copy.
        mask_matrix = torch.zeros(N, self._n_c, device=x.device)
        for i in range(self._n_c):
            if int((indexes == i).sum()) <= self._mec:
                # Too few members: pull one random point into cluster i so its
                # weight column is never all zeros.
                indexes[random.randint(0, N - 1)] = i
            # NOTE(review): a point moved here from an already-processed
            # cluster j < i keeps its weight in column j as well.
            mask_matrix[indexes == i, i] = 1

        for start_index in range(0, self._n_c, batch_size):
            end_index = min(start_index + batch_size, self._n_c)
            weight = mask_matrix[:, start_index:end_index]
            barycenter_start_time = time.time()
            centroids[start_index:end_index] = pa.barycenter(
                x.unsqueeze(1).expand(N, end_index - start_index, D),
                wik=weight,
                normed=True,
                lr=5e-3,
                tau=1e-3,
                verbose=self.verbose)
            barycenter_time += time.time() - barycenter_start_time
        end_time = time.time()
        if self.verbose:
            print("Fast Maximisation Time ", end_time - start_time)
            print("Cumulate barycenter Time ", barycenter_time)
        return centroids
 def _maximisation(self, x, indexes):
     """Return one centroid per cluster, each a hyperbolic barycenter.

     For cluster ``c`` the members are ``x[indexes == c]``. When there are
     at most ``self._mec`` of them, a single randomly drawn row of ``x``
     stands in for the cluster instead. Each centroid is
     ``pa.barycenter(members, normed=True)``. Wall-clock timings are printed
     when ``self.verbose`` is set.
     """
     t_begin = time.time()
     new_centroids = x.new(self._n_c, x.size(-1))
     n_points = len(x)
     spent_in_barycenter = 0
     for cluster in range(self._n_c):
         members = x[indexes == cluster]
         if members.shape[0] <= self._mec:
             # Too few members: fall back to one random point (kept 2-D).
             members = x[random.randint(0, n_points - 1)].unsqueeze(0)
         tick = time.time()
         new_centroids[cluster] = pa.barycenter(members, normed=True)
         spent_in_barycenter += time.time() - tick
     elapsed = time.time() - t_begin
     if self.verbose:
         print("Maximisation Time ", elapsed)
         print("Cumulate barycenter Time ", spent_in_barycenter)
     return new_centroids
Example #4
0
        loss = tree_embedding_criterion(pe_x, pe_y, z=ne,
                                        manifold=manifold).sum()
        tloss += loss.item()
        loss.backward()
        optimizer.step()
    print('Loss value for iteration ', i, ' is ', tloss)

from rcome.function_tools import poincare_alg as pa

# Label-indicator matrix: row i has a 1 in column y when label y is in
# dataset.Y[i] (0 otherwise), for i in range(len(X)). It is moved to the GPU
# with .cuda().
# NOTE(review): 13 is hard-coded -- presumably the number of distinct labels
# (communities) in this dataset; confirm against dataset.Y.
# NOTE(review): the variable is spelled "weigths"; later code may refer to it
# by that name, so it is not renamed here.
weigths = torch.Tensor([[1 if (y in dataset.Y[i]) else 0 for y in range(13)]
                        for i in range(len(X))]).cuda()

# Hyperbolic barycenter (pa.barycenter) of the embeddings for each label,
# weighted by that label's indicator column, then moved back to the CPU.
# Assumes the rows of model.weight.data line up with the rows of weigths
# (one per element of X) -- TODO confirm.
barycenters = []
for i in range(13):
    barycenters.append(
        pa.barycenter(model.weight.data, weigths[:, i], verbose=True).cpu())

# Plot the embeddings with the concatenated hyperbolic barycenters as
# centroids.
plot_poincare_disc_embeddings(model.weight.data.cpu().numpy(),
                              labels=dataset.Y,
                              centroids=torch.cat(barycenters),
                              save_folder="LOG/mean",
                              file_name="LFR_hierachical.png")
print(barycenters)
# For comparison: the plain Euclidean mean of the embeddings carrying each
# label, with a leading dimension of 1 added (unsqueeze(0)) per label.
# NOTE(review): torch returns NaN for the mean of a label with no members.
barycenters = []
for i in range(13):
    barycenters.append(
        model.weight.data[weigths[:, i] == 1].mean(0).cpu().unsqueeze(0))

print(barycenters)
plot_poincare_disc_embeddings(model.weight.data.cpu().numpy(),
                              labels=dataset.Y,
Example #5
0
import torch
import os

from rcome.visualisation_tools.plot_tools import plot_poincare_disc_embeddings
from rcome.function_tools import poincare_alg as pa
from matplotlib import pyplot as plt

# Demo: draw 50 random 2-D points centred near (0.6, 0.6) and project them into
# the open unit disc. Their hyperbolic barycenter (pa.barycenter) is then
# compared with the Euclidean mean.
data_point = (torch.randn(50, 2) / 5 + 0.6)
print(data_point)
norms = data_point.norm(2, -1)
# Points on or outside the unit circle are rescaled to norm
# ||p|| / (||p|| + 1e-5) < 1, so every point lies strictly inside the disc.
# The mask is computed once and reused.
outside_disc = norms >= 1.
data_point[outside_disc] = torch.einsum('ij, i -> ij', data_point[outside_disc],
                                        1 / (norms[outside_disc] + 1e-5))

barycenter_hyperbolic = pa.barycenter(data_point, verbose=True)

print(barycenter_hyperbolic)
plot_poincare_disc_embeddings(data_point.numpy(),
                              close=False,
                              save_folder="LOG/mean",
                              file_name="example_mean_hyperbolic.png")

# Euclidean mean for comparison, kept 2-D with shape (1, 2) by keepdim=True.
barycenter_euclidean = data_point.mean(0, keepdim=True)

plt.scatter(barycenter_hyperbolic[:, 0],
            barycenter_hyperbolic[:, 1],
            marker='D',
            s=300.,
            c='red',
            label="Hyperbolic")
plt.scatter(barycenter_euclidean[:, 0],