Example no. 1
 def forward(self, g, pos=None):
     return dgl.mean_nodes(g, 'h')
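This readout assumes an earlier stage has already written the node field 'h'. A minimal, self-contained sketch of that contract (the toy graph and feature sizes are made up):

 import dgl
 import torch

 g = dgl.graph(([0, 1, 2], [1, 2, 0]))   # toy 3-node graph
 g.ndata['h'] = torch.randn(3, 16)       # 'h' must be set before the readout
 out = dgl.mean_nodes(g, 'h')            # shape (1, 16); one row per graph in a batch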
Example no. 2
 def forward(self, g, in_feat):
     h = self.conv1(g, in_feat)
     h = F.relu(h)
     h = self.conv2(g, h)
     g.ndata['h'] = h
     return dgl.mean_nodes(g, 'h')
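For context, a sketch of an enclosing module consistent with this forward; the GraphConv layer type is an assumption, since the snippet only shows the forward pass:

 import dgl
 import dgl.nn as dglnn
 import torch.nn as nn
 import torch.nn.functional as F

 class GCN(nn.Module):
     """Two graph-conv layers followed by a mean readout (assumed layer types)."""
     def __init__(self, in_feats, hidden_feats, out_feats):
         super().__init__()
         self.conv1 = dglnn.GraphConv(in_feats, hidden_feats)
         self.conv2 = dglnn.GraphConv(hidden_feats, out_feats)

     def forward(self, g, in_feat):
         h = self.conv1(g, in_feat)
         h = F.relu(h)
         h = self.conv2(g, h)
         g.ndata['h'] = h
         return dgl.mean_nodes(g, 'h')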
Example no. 3
    def forward(self, data):
        g = data
        # self.gnn is assumed to populate g.ndata['repr'] with the stacked
        # per-layer node representations used below
        g.ndata['h'] = self.gnn(g)

        # graph-level readout over the per-layer node representations
        g_out = mean_nodes(g, 'repr')

        # head/tail entity nodes are tagged with id 1 and 2 respectively
        head_ids = (g.ndata['id'] == 1).nonzero().squeeze(1)
        head_embs = g.ndata['repr'][head_ids]

        tail_ids = (g.ndata['id'] == 2).nonzero().squeeze(1)
        tail_embs = g.ndata['repr'][tail_ids]

        # look up precomputed drug features for the head/tail entities
        head_feat = self.drugfeat[g.ndata['idx'][head_ids]]
        tail_feat = self.drugfeat[g.ndata['idx'][tail_ids]]

        if self.params.add_feat_emb:
            # project each drug feature through a shared two-layer MLP
            fuse_feat1 = self.mp_layer2(
                self.relu(self.dropout(self.mp_layer1(head_feat))))
            fuse_feat2 = self.mp_layer2(
                self.relu(self.dropout(self.mp_layer1(tail_feat))))
            fuse_feat = torch.cat([fuse_feat1, fuse_feat2], dim=1)

        if self.params.add_ht_emb and self.params.add_sb_emb:
            if self.params.add_feat_emb and self.params.add_transe_emb:
                g_rep = torch.cat([
                    g_out.view(-1, (1 + self.params.num_gcn_layers) * self.params.emb_dim),
                    head_embs.view(-1, (1 + self.params.num_gcn_layers) * self.params.emb_dim),
                    tail_embs.view(-1, (1 + self.params.num_gcn_layers) * self.params.emb_dim),
                    fuse_feat.view(-1, 2 * self.params.emb_dim),
                ], dim=1)
            elif self.params.add_feat_emb:
                g_rep = torch.cat([
                    g_out.view(-1, self.params.num_gcn_layers * self.params.emb_dim),
                    head_embs.view(-1, self.params.num_gcn_layers * self.params.emb_dim),
                    tail_embs.view(-1, self.params.num_gcn_layers * self.params.emb_dim),
                    fuse_feat.view(-1, 2 * self.params.emb_dim),
                ], dim=1)
            else:
                g_rep = torch.cat([
                    g_out.view(-1, (1 + self.params.num_gcn_layers) * self.params.emb_dim),
                    head_embs.view(-1, (1 + self.params.num_gcn_layers) * self.params.emb_dim),
                    tail_embs.view(-1, (1 + self.params.num_gcn_layers) * self.params.emb_dim),
                ], dim=1)
        elif self.params.add_ht_emb:
            g_rep = torch.cat([
                head_embs.view(-1, (1 + self.params.num_gcn_layers) * self.params.emb_dim),
                tail_embs.view(-1, (1 + self.params.num_gcn_layers) * self.params.emb_dim),
            ], dim=1)
        else:
            g_rep = g_out.view(-1, self.params.num_gcn_layers * self.params.emb_dim)

        # gate dropout on self.training so it is disabled in eval mode
        output = self.fc_layer(F.dropout(g_rep, p=0.3, training=self.training))
        return output
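The `.view` widths above encode the layout: `g_out`, `head_embs`, and `tail_embs` each appear to stack the input embedding plus one representation per GCN layer when `add_transe_emb` is on (hence the `1 + num_gcn_layers` factor), and `fuse_feat` concatenates two `emb_dim`-wide projections. A quick check of the width arithmetic under hypothetical hyperparameters:

 # Hypothetical hyperparameters, purely to illustrate the width arithmetic.
 num_gcn_layers, emb_dim = 3, 32
 block = (1 + num_gcn_layers) * emb_dim      # 128 per g_out/head/tail block
 g_rep_width = 3 * block + 2 * emb_dim       # width when all flags are enabled
 print(g_rep_width)                          # -> 448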
Example no. 4
    def forward(self, g, h, e, h_pos_enc=None):

        g = g.local_var()
        gs = dgl.unbatch(g)
        _gs = []
        for gse in gs:
            # add a self-loop to every node, tagged with the dedicated self-edge id
            nodeids = torch.arange(gse.number_of_nodes(),
                                   dtype=h.dtype,
                                   device=h.device)
            gse.add_edges(nodeids,
                          nodeids,
                          data={
                              "feat":
                              torch.ones_like(nodeids, dtype=torch.long) *
                              self.self_edge_id
                          })
            _gs.append(gse)
        g = dgl.batch(_gs)
        # input embedding
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        if self.pos_enc:
            h_pos_enc = self.embedding_pos_enc(h_pos_enc.float())
            h = h + h_pos_enc
        if not self.edge_feat:  # the all-ones edge-feature fallback is not supported here
            raise Exception("not implemented yet")

        g.ndata["h"] = h
        g.edata["id"] = g.edata["feat"]

        # convnets
        for layer in self.layers:
            for _ in range(self.numrepsperlayer):
                g = layer(g)

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        elif self.readout == "set2set":
            hg = self.set_to_set(g)
        else:
            hg = dgl.mean_nodes(g, 'h')  # default readout is mean nodes

        hg = self.dropout(hg)

        return self.MLP_layer(hg)
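For context, a hypothetical invocation of this forward; `net`, `g1`, and `g2` are placeholder names, and the integer 'feat' fields feed the `embedding_h` and edge-id lookups inside:

 # Hypothetical call; g1/g2 are DGLGraphs carrying integer 'feat'
 # fields on both nodes and edges.
 bg = dgl.batch([g1, g2])
 scores = net(bg, bg.ndata['feat'], bg.edata['feat'])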
Example no. 5
    def forward(self, g, X, E, snorm_n, snorm_e):
        # input embedding
        H = self.embedding_h(X)
        E = self.embedding_e(E)

        # graph convnet layers
        for GGCN_layer in self.NormalGCN_layers:
            H, E = GGCN_layer(g, H, E, snorm_n, snorm_e)

        # MLP classifier
        g.ndata['H'] = H
        y = dgl.mean_nodes(g, 'H')
        y = self.MLP_layer(y)

        return y


# %%

# # instantiate network
# model = NormalGCN(input_dim=1, hidden_dim=150, output_dim=8, L=2)
# print(model)

# %% md

## Define a few helper functions

# %%

# Collate function to prepare graphs
#
# def collate(samples):
#     graphs, labels = map(list, zip(*samples))  # samples is a list of pairs (graph, label)
#     labels = torch.tensor(labels)
#     sizes_n = [graph.number_of_nodes() for graph in graphs]  # graph sizes
#     snorm_n = [torch.FloatTensor(size, 1).fill_(1 / size) for size in sizes_n]
#     snorm_n = torch.cat(snorm_n).sqrt()  # graph size normalization
#     sizes_e = [graph.number_of_edges() for graph in graphs]  # nb of edges
#     snorm_e = [torch.FloatTensor(size, 1).fill_(1 / size) for size in sizes_e]
#     snorm_e = torch.cat(snorm_e).sqrt()  # graph size normalization
#     batched_graph = dgl.batch(graphs)  # batch graphs
#     return batched_graph, labels, snorm_n, snorm_e
#

# %%

# Compute accuracy
#
# def accuracy(logits, targets):
#     preds = logits.detach().argmax(dim=1)
#     acc = (preds == targets).sum().item()
#     return acc

# %% md

## Test forward pass

# %%
#
# # Define DataLoader and get first graph batch
#
# train_loader = DataLoader(trainset, batch_size=10, shuffle=True, collate_fn=collate)
# batch_graphs, batch_labels, batch_snorm_n, batch_snorm_e = next(iter(train_loader))
# batch_X = batch_graphs.ndata['feat']
# batch_E = batch_graphs.edata['feat']
#
# # %%
#
# # Checking some sizes
#
# print('batch_graphs:', batch_graphs)
# print('batch_labels:', batch_labels)
# print('batch_X size:', batch_X.size())
# print('batch_E size:', batch_E.size())
#
# # %%
#
# batch_scores = model(batch_graphs, batch_X, batch_E, batch_snorm_n, batch_snorm_e)
# print(batch_scores.size())
#
# print(f'accuracy: {accuracy(batch_scores, batch_labels)}')

# %% md

## Test backward pass

# %%
#
# # Loss
# J = nn.CrossEntropyLoss()(batch_scores, batch_labels.long())
#
# # Backward pass
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# optimizer.zero_grad()
# J.backward()
# optimizer.step()
#

# %% md

## Train one epoch

# %%
#
# def train(model, data_loader, loss, optimizer):
#     model.train()
#     epoch_loss = 0
#     epoch_train_acc = 0
#     nb_data = 0
#     gpu_mem = 0
#
#     for iter, (batch_graphs, batch_labels, batch_snorm_n, batch_snorm_e) in enumerate(data_loader):
#         batch_X = batch_graphs.ndata['feat']
#         batch_E = batch_graphs.edata['feat']
#
#         batch_scores = model(batch_graphs, batch_X, batch_E, batch_snorm_n, batch_snorm_e)
#         J = loss(batch_scores, batch_labels.long())
#         optimizer.zero_grad()
#         J.backward()
#         optimizer.step()
#
#         epoch_loss += J.detach().item()
#         epoch_train_acc += accuracy(batch_scores, batch_labels)
#         nb_data += batch_labels.size(0)
#
#     epoch_loss /= (iter + 1)
#     epoch_train_acc /= nb_data
#
#     return epoch_loss, epoch_train_acc
#

# %% md

## Evaluation

# %%
#
# def evaluate(model, data_loader, loss):
#     model.eval()
#     epoch_test_loss = 0
#     epoch_test_acc = 0
#     nb_data = 0
#
#     with torch.no_grad():
#         for iter, (batch_graphs, batch_labels, batch_snorm_n, batch_snorm_e) in enumerate(data_loader):
#             batch_X = batch_graphs.ndata['feat']
#             batch_E = batch_graphs.edata['feat']
#
#             batch_scores = model(batch_graphs, batch_X, batch_E, batch_snorm_n, batch_snorm_e)
#             J = loss(batch_scores, batch_labels.long())
#
#             epoch_test_loss += J.detach().item()
#             epoch_test_acc += accuracy(batch_scores, batch_labels)
#             nb_data += batch_labels.size(0)
#
#         epoch_test_loss /= (iter + 1)
#         epoch_test_acc /= nb_data
#
#     return epoch_test_loss, epoch_test_acc

# %% md

# Train GNN
#

# %%
#
# sweep_config = {
#     'method': 'grid', #grid, random
#     'metric': {
#       'name': 'loss',
#       'goal': 'minimize'
#     },
#     'parameters': {
#         'epochs': {
#             'values': [20, 50]
#         },
#         'learning_rate': {
#             'values': [3e-4, 3e-5, 1e-5]
#         },
#         'hidden_dim':{
#             'values':[128,256]
#         },
#     }
# }

# sweep_id = wandb.sweep(sweep_config, project="test-sweep")
# %%
#
#
#
# def train_wandb():
#     config_defaults = {
#         'epochs': 25,
#         'learning_rate': 1e-3,
#         'hidden_dim': 128
#     }
#
#     wandb.init(config=config_defaults)
#
#     config = wandb.config
#
#     # datasets
#     train_loader = DataLoader(trainset, batch_size=50, shuffle=True, collate_fn=collate)
#     test_loader = DataLoader(testset, batch_size=50, shuffle=False, collate_fn=collate)
#
#     # Create model
#     model = GatedGCN(input_dim=1, hidden_dim=config.hidden_dim, output_dim=8, L=4)
#     loss = nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
#
#     for epoch in range(config.epochs):
#         start = time.time()
#         train_loss, train_acc = train(model, train_loader, loss, optimizer)
#
#         test_loss, test_acc = evaluate(model, test_loader, loss)
#
#         wandb.log({
#             "train_loss": train_loss,
#             "test_loss": test_loss,
#             "train_acc": train_acc,
#             "test_acc": test_acc,
#         })
#
#         print(f'Epoch {epoch}, train_loss: {train_loss:.4f}, test_loss: {test_loss:.4f}')
#         print(f'train_acc: {train_acc:.4f}, test_acc: {test_acc:.4f}')
# %%
# wandb.agent(sweep_id, train_wandb)
# %%
#
# plt.plot(epoch_train_losses, label="train_loss")
# plt.plot(epoch_test_losses, label="test_loss")
# plt.legend()
# plt.figure()
# plt.plot(epoch_train_accs, label="train_acc")
# plt.plot(epoch_test_accs, label="test_acc")
# plt.legend()

# %%
Example no. 6
	def forward(self, sent_vecs, graph):
		node_embed = self.graph_encoder(graph)
		# mean_nodes expects the graph, so store the embeddings on it first
		graph.ndata['h'] = node_embed
		graph_embed = dgl.mean_nodes(graph, 'h')
		concated = torch.cat((sent_vecs, graph_embed), 1)
		logits = self.mlp(concated)
		return logits
Example no. 7
 def readout_fn(graphs, h):
     hg_max = dgl.max_nodes(graphs, h)
     hg_mean = dgl.mean_nodes(graphs, h)
     hg = torch.cat([hg_max, hg_mean], dim=-1)
     return hg
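Note that `h` here is the name of a node field, not a tensor. A hypothetical call:

 # Hypothetical usage; 'h' must already exist in batched_graphs.ndata.
 hg = readout_fn(batched_graphs, 'h')   # shape (batch_size, 2 * feat_dim)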
Example no. 8
 def forward(self, g, pos):
     g.ndata['a'] = self.nonlinear(self.position_weights(pos))
     return dgl.mean_nodes(g, 'h', 'a')
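The third argument to `dgl.mean_nodes` names a per-node weight field, so this is a weighted mean readout over 'h', which is assumed to be populated upstream. A minimal sketch of that API with made-up data:

 import dgl
 import torch

 g = dgl.graph(([0, 1], [1, 2]), num_nodes=3)
 g.ndata['h'] = torch.randn(3, 4)
 g.ndata['a'] = torch.rand(3, 1)     # one scalar weight per node
 hg = dgl.mean_nodes(g, 'h', 'a')    # weighted mean readout, shape (1, 4)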