import torch
from torch_geometric.nn import DeepGraphInfomax


class DGILearner:
    def __init__(self, inp_dim, out_dim, device):
        # DGIEncoderNet is assumed to be defined elsewhere (a sketch follows this class).
        self.encoder = DGIEncoderNet(inp_dim, out_dim)
        self.dgi = DeepGraphInfomax(out_dim, encoder=self.encoder,
                                    summary=self.readout, corruption=self.corrupt)
        self.dgi = self.dgi.to(device)
        self.optimizer = torch.optim.Adam(self.dgi.parameters())

    def embed(self, data):
        pos_z, _, _ = self.dgi(data.x, data.edge_index, data.edge_attr, msk=None)
        return pos_z

    def readout(self, z, x, edge_index, edge_attr, msk=None):
        # Summary vector: sigmoid of the (optionally masked) mean of node embeddings.
        if msk is None:
            return torch.sigmoid(torch.mean(z, 0))
        else:
            return torch.sigmoid(torch.sum(z[msk], 0) / torch.sum(msk))

    def corrupt(self, x, edge_index, edge_attr, msk=None):
        # Corruption: shuffle node features while keeping the graph structure.
        shuffled_rows = torch.randperm(len(x))
        shuffled_x = x[shuffled_rows, :]
        return shuffled_x, edge_index, edge_attr

    def evaluate_loss(self, data, mode):
        # Use masking for loss evaluation.
        pos_z_train, neg_z_train, summ_train = self.dgi(
            data.x, data.edge_index, data.edge_attr, msk=data.train_mask)
        pos_z_test, neg_z_test, summ_test = self.dgi(
            data.x, data.edge_index, data.edge_attr, msk=data.test_mask)
        if mode == 'train':
            return self.dgi.loss(pos_z_train, neg_z_train, summ_train)
        else:
            return self.dgi.loss(pos_z_test, neg_z_test, summ_test)

    def train(self, data):
        # Single training step.
        self.dgi.train()
        self.optimizer.zero_grad()
        loss = self.evaluate_loss(data, mode='train')
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def test(self, data):
        # Evaluation on the test mask.
        self.dgi.eval()
        return self.evaluate_loss(data, mode='test').item()
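# DGIEncoderNet is referenced above but not defined in this snippet. The block
# below is only a minimal sketch of a compatible encoder (a single GCNConv
# followed by a PReLU); the layer choice and class body are assumptions for
# illustration, not the original implementation.
from torch.nn import PReLU
from torch_geometric.nn import GCNConv


class DGIEncoderNet(torch.nn.Module):
    def __init__(self, inp_dim, out_dim):
        super().__init__()
        self.conv = GCNConv(inp_dim, out_dim)
        self.act = PReLU(out_dim)

    def forward(self, x, edge_index, edge_attr=None, msk=None):
        # edge_attr and msk are accepted only to match the call signature used
        # by DGILearner; a one-dimensional edge_attr is treated as edge weights.
        z = self.conv(x, edge_index, edge_weight=edge_attr)
        return self.act(z)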
# A second, mini-batch (inductive) DGI example. `dataset`, `data`, and
# `train_loader` are assumed to come from earlier in the script (see the
# NeighborSampler sketch after this snippet). The original snippet begins
# partway through Encoder.forward, so the class skeleton and the end of the
# training step are reconstructed below.
class Encoder(torch.nn.Module):
    # __init__ (not shown in the original snippet) is assumed to build
    # `self.convs` (a ModuleList of bipartite conv layers, e.g. SAGEConv) and
    # matching `self.activations`.
    def forward(self, x, adjs):
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # target nodes are listed first
            x = self.convs[i]((x, x_target), edge_index)
            x = self.activations[i](x)
        return x


def corruption(x, edge_index):
    # Standard DGI corruption: shuffle node features, keep the graph structure.
    return x[torch.randperm(x.size(0))], edge_index


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DeepGraphInfomax(
    hidden_channels=512, encoder=Encoder(dataset.num_features, 512),
    summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
    corruption=corruption).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
x, y = data.x.to(device), data.y.to(device)


def train(epoch):
    model.train()
    total_loss = total_examples = 0
    for batch_size, n_id, adjs in tqdm(train_loader, desc=f'Epoch {epoch:02d}'):
        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
        adjs = [adj.to(device) for adj in adjs]
        optimizer.zero_grad()
        # The original snippet is truncated here; the remainder of the step is
        # a standard DGI mini-batch update, reconstructed for completeness.
        pos_z, neg_z, summary = model(x[n_id], adjs)
        loss = model.loss(pos_z, neg_z, summary)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * pos_z.size(0)
        total_examples += pos_z.size(0)
    return total_loss / total_examples
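# `train_loader` above is not constructed in this snippet. A typical setup,
# assuming PyG's NeighborSampler over a large graph, is sketched below; the
# fan-out sizes, batch size, and worker count are illustrative assumptions.
# (NeighborSampler lives in torch_geometric.loader in recent releases and in
# torch_geometric.data in older ones.)
from torch_geometric.loader import NeighborSampler

train_loader = NeighborSampler(data.edge_index, node_idx=None,
                               sizes=[10, 10, 25], batch_size=256,
                               shuffle=True, num_workers=4)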
# A third example: full-batch DGI training followed by a logistic-regression
# probe on the learned embeddings. `Encoder`, `corruption`, `Result`, and the
# `data` object (with split_train_valid() and mask_* attributes) are assumed to
# be defined elsewhere in the project; minimal sketches of `corruption` and
# `Result` follow this function.
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from torch_geometric.nn import DeepGraphInfomax


def main_model_dgi(data, hidden, if_all=False):
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = DeepGraphInfomax(
        hidden_channels=hidden,
        encoder=Encoder(hidden, data),
        summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
        corruption=corruption)
    data.split_train_valid()
    model = model.to(device)
    data = data.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    best_acc_valid = 0
    for epoch in range(10):
        model.train()
        optimizer.zero_grad()
        pos_z, neg_z, summary = model(data.x, data.edge_index)

        # Probe the current embeddings with a logistic-regression classifier
        # and keep the embeddings that score best on the validation split.
        lr = LogisticRegression().fit(
            pos_z[data.mask_train].detach().cpu().numpy().reshape(-1, hidden),
            data.y[data.mask_train].cpu().numpy())
        valid_pred = lr.predict(
            pos_z[data.mask_valid].detach().cpu().numpy().reshape(-1, hidden))
        acc_valid = accuracy_score(data.y[data.mask_valid].cpu().numpy(), valid_pred)
        if acc_valid > best_acc_valid:
            best_acc_valid = acc_valid
            result = pos_z

        loss = model.loss(pos_z.to(device), neg_z.to(device), summary.to(device))
        loss.backward()
        optimizer.step()

    # Fit the final classifier on the best embeddings found during training.
    lr = LogisticRegression().fit(
        result[data.mask_train].detach().cpu().numpy().reshape(-1, hidden),
        data.y[data.mask_train].cpu().numpy())
    train_pred = lr.predict(
        result[data.mask_train].detach().cpu().numpy().reshape(-1, hidden))
    all_pred = lr.predict(result.detach().cpu().numpy().reshape(-1, hidden))

    if if_all:
        return Result(
            result=torch.tensor(np.eye(data.num_class)[all_pred]).float().cpu(),
            loss_train=-1,
            loss_valid=-1,
            acc_train=accuracy_score(data.y[data.mask_train].cpu().numpy(), train_pred),
            acc_valid=best_acc_valid,
            epoch=10,
        )
    else:
        return Result(
            # Index the numpy predictions with a CPU mask so this also works
            # when `data` lives on the GPU.
            result=all_pred[data.mask_test.cpu().numpy()],
            loss_train=-1,
            loss_valid=-1,
            acc_train=accuracy_score(data.y[data.mask_train].cpu().numpy(), train_pred),
            acc_valid=best_acc_valid,
            epoch=10,
        )
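# `corruption` and `Result` are used above but not defined in this snippet.
# Minimal sketches compatible with the calls above are given below; the exact
# originals may differ (the field names simply follow the keyword arguments
# used in main_model_dgi).
from dataclasses import dataclass
from typing import Any


def corruption(x, edge_index):
    # Standard DGI corruption: permute node features, keep the edge structure.
    return x[torch.randperm(x.size(0))], edge_index


@dataclass
class Result:
    result: Any
    loss_train: float
    loss_valid: float
    acc_train: float
    acc_valid: float
    epoch: int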