Beispiel #1
0
# One rsgd.RSGD optimizer over both the node and the context embedding
# tables, used for this initialisation phase.
optimizer_init = rsgd.RSGD(
    list(node_embedding.parameters()) + list(context_embedding.parameters()),
    args.initial_lr,
    manifold=manifold,
)


# Initialise embedding
if args.verbose:
    print("Initialise embedding")

for i in range(20):

    # l2 pass (weighted by args.beta): node embedding of x vs. context
    # embedding of y, with args.n_negative negatives per pair sampled from
    # `distribution` and looked up in the context table. Negatives are
    # detached so no gradient flows into them.
    l2 = 0.
    for x, y in dataloader_l2:
        optimizer_init.zero_grad()
        pe_x = node_embedding(memory_transfer(x.long()))
        pe_y = context_embedding(memory_transfer(y.long()))
        neg_idx = distribution.sample(sample_shape=(len(x), args.n_negative))
        ne = context_embedding(memory_transfer(neg_idx)).detach()
        loss = args.beta * graph_embedding_criterion(pe_x, pe_y, z=ne, manifold=manifold).sum()
        l2 += loss.item()
        loss.backward()
        optimizer_init.step()

    # l1 pass (weighted by args.alpha): both endpoints come from the node
    # table, no negatives.
    l1 = 0.
    for x, y in dataloader_l1:

        optimizer_init.zero_grad()
        # Transfer the index tensors *before* the lookup, exactly as the l2
        # pass above does. The previous order (look up, then transfer) fed
        # untransferred indices to the embedding table, which breaks when
        # memory_transfer moves data to the device the tables live on
        # (presumably the training device — confirm against its definition).
        pe_x = node_embedding(memory_transfer(x.long()))
        pe_y = node_embedding(memory_transfer(y.long()))
        loss = args.alpha * graph_embedding_criterion(pe_x, pe_y, manifold=manifold).sum()
        l1 += loss.item()
        loss.backward()
        optimizer_init.step()
    # NOTE(review): the per-epoch totals l1 / l2 are accumulated but not
    # reported in the visible code; consider logging them when args.verbose.
                        batch_size=5,
                        shuffle=True,
                        collate_fn=collate_fn_simple)

# Train a 2-D embedding of the len(X) nodes on the Poincaré ball with
# rsgd.RSGD, sampling uniform random negatives, then plot the result.
N_NEGATIVE = 10  # negative node indices drawn per positive pair
N_EPOCHS = 50

# max_norm=0.999 renormalises looked-up rows to norm <= 0.999, keeping them
# strictly inside the unit ball.
model = nn.Embedding(len(X), 2, max_norm=0.999)
# Shrink the initial weights by 1e-2 so all points start close to the origin.
with torch.no_grad():
    model.weight.mul_(1e-2)

manifold = PoincareBallExact
optimizer = rsgd.RSGD(model.parameters(), 1e-1, manifold=manifold)
# NOTE(review): default_gt and criterion are not used in the visible code;
# kept in case other parts of the file reference them.
default_gt = torch.zeros(20).long()
criterion = nn.CrossEntropyLoss(reduction="sum")

for i in range(N_EPOCHS):
    tloss = 0.
    for x, y in dataloader:
        optimizer.zero_grad()
        pe_x = model(x.long())
        pe_y = model(y.long())
        # Uniform negative node indices in [0, len(X)). randint draws the
        # integers directly instead of truncating scaled float noise.
        neg_idx = torch.randint(len(X), (len(x), N_NEGATIVE))
        # Negatives are detached: only the positive pair receives gradients.
        ne = model(neg_idx).detach()
        loss = graph_embedding_criterion(pe_x, pe_y, z=ne,
                                         manifold=manifold).sum()
        tloss += loss.item()
        loss.backward()
        optimizer.step()
    print('Loss value for iteration ', i, ' is ', tloss)

plot_poincare_disc_embeddings(model.weight.data.numpy(),
                              labels=dataset.Y,
                              save_folder="LOG/second_order",
                              file_name="LFR_second_order_joint.png")