# --- Embedding tables -------------------------------------------------------
# Rescale the (previously created) node embedding towards the origin so the
# initial points start well inside the Poincare ball.
node_embedding.weight.data[:] = node_embedding.weight.data * 1e-2

# Context table, one row per node; max_norm keeps points strictly inside the
# unit ball.
context_embedding = nn.Embedding(len(X), args.dim, max_norm=0.999)
context_embedding.weight.data[:] = context_embedding.weight.data * 1e-2

# Device placement: memory_transfer maps input tensors onto the same device
# as the embedding tables.
if(args.cuda):
    node_embedding.cuda()
    context_embedding.cuda()
    memory_transfer = lambda x: x.cuda()
else:
    memory_transfer = lambda x: x

if(args.verbose):
    print("Optimisation and manifold intialisation")

# Riemannian SGD over both embedding tables on the exact Poincare-ball
# manifold.
manifold = PoincareBallExact
optimizer_init = rsgd.RSGD(
    list(node_embedding.parameters()) + list(context_embedding.parameters()),
    args.initial_lr,
    manifold=manifold,
)

# Initialise embedding
if(args.verbose):
    print("Initialise embedding")
for i in range(20):
    l2 = 0.
    for x, y in dataloader_l2:
        optimizer_init.zero_grad()
        pe_x = node_embedding(memory_transfer(x.long()))
        pe_y = context_embedding(memory_transfer(y.long()))
        # Negative context samples; detached so gradients only flow through
        # the positive pair and the node side.
        ne = context_embedding(
            memory_transfer(distribution.sample(sample_shape=(len(x), args.n_negative)))
        ).detach()
        loss = args.beta * graph_embedding_criterion(pe_x, pe_y, z=ne, manifold=manifold).sum()
from rcome.data_tools import collections
from rcome.data_tools.corpora import DatasetTuple
from rcome.visualisation_tools.plot_tools import plot_geodesic

# Load the mammal subtree of WordNet as a graph of (node, neighbour) tuples.
root_synset = 'mammal.n.01'
X, dictionary, values, tuple_neigbhor = collections.animals(root=root_synset)
print('Number of nodes', len(X))

dataset = DatasetTuple(X)
dataloader = DataLoader(dataset, batch_size=5)

# 2D embedding table (index 0 reserved, hence len(dictionary) + 1 rows).
# max_norm keeps every point strictly inside the unit ball; the 1e-3 rescale
# clusters the initial points near the origin.
model = nn.Embedding(len(dictionary) + 1, 2, max_norm=0.999)
model.weight.data[:] = model.weight.data * 1e-3

manifold = PoincareBallExact
optimizer = rsgd.RSGD(model.parameters(), 1e-2, manifold=manifold)

# Target class is always 0: the positive pair should be closer than every
# sampled negative.
default_gt = torch.zeros(20).long()
criterion = nn.CrossEntropyLoss(reduction="sum")

for i in range(5):
    tloss = 0.
    for x in dataloader:
        optimizer.zero_grad()
        pe = model(x.long())
        # Uniform negative sampling over node ids 1..len(dictionary).
        ne = model((torch.rand(len(x), 2) * len(dictionary)).long() + 1)
        # Distance of the positive pair vs. each negative.
        pd = manifold.distance(pe[:, 0, :], pe[:, 1, :]).unsqueeze(1)
        nd = manifold.distance(pe[:, 0, :].unsqueeze(1).expand_as(ne), ne)
        # Negated distances act as logits: smaller distance -> higher score.
        prediction = -torch.cat((pd, nd), 1)
        loss = criterion(prediction, default_gt[:len(prediction)])
# Random-walk context dataset: paths of length 10, context window of 3.
dataset = corpora.RandomContextSizeFlat(X, Y, precompute=2, path_len=10, context_size=3)

def collate_fn_simple(my_list):
    """Concatenate per-item pair tensors and split into (node, context) columns."""
    v = torch.cat(my_list, 0)
    return v[:, 0], v[:, 1]

dataloader = DataLoader(dataset, batch_size=5, shuffle=True, collate_fn=collate_fn_simple)

# Node and context embedding tables; the 1e-2 rescale starts points near the
# origin (max_norm bounds them inside the unit ball).
model = nn.Embedding(len(X), 2, max_norm=0.999)
model.weight.data[:] = model.weight.data * 1e-2
model_context = nn.Embedding(len(X), 2, max_norm=0.999)
# BUG FIX: the original scaled model.weight here (copy-paste slip), copying the
# node table into the context table instead of shrinking the context table's
# own random init. Scale model_context's own weights, consistent with the
# node/context initialisation used elsewhere in this file.
model_context.weight.data[:] = model_context.weight.data * 1e-2

manifold = Euclidean
optimizer = rsgd.RSGD(
    list(model.parameters()) + list(model_context.parameters()),
    1e-1,
    manifold=manifold,
)
default_gt = torch.zeros(20).long()

# Negative-sampling distribution: unigram frequency raised to the 3/4 power
# (word2vec-style smoothing), renormalised to a probability vector.
frequency = dataset.getFrequency()
idx = frequency[:, 0].sort()[1]
frequency = frequency[idx]**(3/4)
frequency[:, 1] /= frequency[:, 1].sum()
distribution = CategoricalDistributionSampler(frequency[:, 1])

for i in range(50):
    tloss = 0.
    for x, y in dataloader:
        optimizer.zero_grad()