class DGILearner:
    """Self-supervised node-embedding learner based on Deep Graph Infomax.

    Wraps a ``DeepGraphInfomax`` model around ``DGIEncoderNet`` and exposes
    simple train / test / embed entry points operating on a PyG ``Data``
    object (expects ``x``, ``edge_index``, ``edge_attr`` plus boolean
    ``train_mask`` / ``test_mask`` attributes).
    """

    def __init__(self, inp_dim, out_dim, device):
        """Build encoder, DGI head and optimizer, moving the model to *device*."""
        self.encoder = DGIEncoderNet(inp_dim, out_dim)
        self.dgi = DeepGraphInfomax(out_dim,
                                    encoder=self.encoder,
                                    summary=self.readout,
                                    corruption=self.corrupt)
        self.dgi = self.dgi.to(device)
        self.optimizer = torch.optim.Adam(self.dgi.parameters())

    def embed(self, data):
        """Return the positive (uncorrupted) node embeddings for *data*."""
        pos_z, _, _ = self.dgi(data.x, data.edge_index, data.edge_attr, msk=None)
        return pos_z

    def readout(self, z, x, edge_index, edge_attr, msk=None):
        """Summary function: sigmoid of the (optionally masked) mean embedding."""
        if msk is None:
            return torch.sigmoid(torch.mean(z, 0))
        # Masked mean: sum of selected rows divided by the number of selected rows.
        return torch.sigmoid(torch.sum(z[msk], 0) / torch.sum(msk))

    def corrupt(self, x, edge_index, edge_attr, msk=None):
        """Corruption function: shuffle node features across rows, keep graph fixed."""
        shuffled_rows = torch.randperm(len(x))
        shuffled_x = x[shuffled_rows, :]
        return shuffled_x, edge_index, edge_attr

    def evaluate_loss(self, data, mode):
        """Compute the DGI loss on the split selected by *mode* ('train' or test).

        Only the requested split's forward pass is executed; the previous
        implementation ran both train- and test-mask passes on every call and
        discarded one of them, doubling compute and autograd memory.
        """
        msk = data.train_mask if mode == 'train' else data.test_mask
        pos_z, neg_z, summ = self.dgi(data.x, data.edge_index, data.edge_attr, msk=msk)
        return self.dgi.loss(pos_z, neg_z, summ)

    def train(self, data):
        """Run one optimization step on the training mask; return the loss value."""
        self.dgi.train()
        self.optimizer.zero_grad()
        loss = self.evaluate_loss(data, mode='train')
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def test(self, data):
        """Evaluate the DGI loss on the test mask without updating parameters."""
        self.dgi.eval()
        return self.evaluate_loss(data, mode='test').item()
class Infomax(BaseModel):
    """Deep Graph Infomax node-classification model.

    Learns unsupervised node embeddings with DGI, then fits a logistic
    regression on the training nodes to produce class probabilities.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--num-classes", type=int)
        parser.add_argument("--hidden-size", type=int, default=512)
        # fmt: on

    @classmethod
    def build_model_from_args(cls, args):
        return cls(args.num_features, args.num_classes, args.hidden_size)

    def __init__(self, num_features, num_classes, hidden_size):
        super(Infomax, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        # Summary: sigmoid over the mean embedding; corruption shuffles features.
        self.model = DeepGraphInfomax(
            hidden_channels=hidden_size,
            encoder=Encoder(num_features, hidden_size),
            summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
            corruption=corruption,
        )

    def forward(self, x, edge_index):
        return self.model(x, edge_index)

    def node_classification_loss(self, data):
        """Return the DGI contrastive loss on the full graph."""
        pos_z, neg_z, summary = self.forward(data.x, data.edge_index)
        return self.model.loss(pos_z, neg_z, summary)

    def predict(self, data):
        """Fit a logistic regression on train-mask embeddings; return logits."""
        z, _, _ = self.forward(data.x, data.edge_index)
        train_feats = z[data.train_mask].detach().cpu().numpy()
        train_labels = data.y[data.train_mask].detach().cpu().numpy()
        clf = LogisticRegression(solver="lbfgs", multi_class="auto", max_iter=150)
        clf.fit(train_feats, train_labels)
        logits = torch.Tensor(clf.predict_proba(z.detach().cpu().numpy()))
        # Keep the prediction tensor on the same device as the embeddings.
        if z.is_cuda:
            logits = logits.cuda()
        return logits
class InfoMaxModel(BaseModel):
    """Lightning-style wrapper around a DeepGraphInfomax model.

    ``forward`` returns both the positive embeddings and the DGI loss for a
    batch (expects ``batch.x`` and ``batch.edge_index[0]``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        hidden = kwargs["hidden_channels"]
        self.model = DeepGraphInfomax(
            hidden_channels=hidden,
            corruption=corruption,
            encoder=Encoder(kwargs["num_features"], hidden),
            summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
        )

    def forward(self, batch):
        # NOTE(review): edge_index[0] suggests batch.edge_index is a
        # tuple/list wrapping the index tensor — confirm against the loader.
        pos_z, neg_z, summary = self.model(batch.x, batch.edge_index[0])
        dgi_loss = self.model.loss(pos_z, neg_z, summary)
        return pos_z, dgi_loss
def test_deep_graph_infomax():
    """Smoke-test DeepGraphInfomax with identity encoder and shift corruption."""

    def corruption(x):
        # Build "negative" samples by shifting every feature by one.
        return x + 1

    # Fix: the original defined this corruption function but never used it,
    # passing a duplicate ``lambda x: x + 1`` instead — use the named function.
    model = DeepGraphInfomax(
        hidden_channels=16,
        encoder=lambda x: x,
        summary=lambda z, *args: z.mean(dim=0),
        corruption=corruption,
    )
    assert model.__repr__() == 'DeepGraphInfomax(16)'

    x = torch.ones(20, 16)
    pos_z, neg_z, summary = model(x)
    assert pos_z.size() == (20, 16) and neg_z.size() == (20, 16)
    assert summary.size() == (16, )

    loss = model.loss(pos_z, neg_z, summary)
    assert 0 <= loss.item()

    acc = model.test(torch.ones(20, 16), torch.randint(10, (20, )),
                     torch.ones(20, 16), torch.randint(10, (20, )))
    assert 0 <= acc and acc <= 1
def train():
    """Train a DGI model on the dataset selected by command-line args,
    early-stopping on the training loss, then run node classification
    with the best checkpoint."""
    # get the parameters
    args = get_args()
    print(args.domain)
    # decide the device
    device = torch.device('cuda:2' if torch.cuda.is_available() and args.cuda else 'cpu')
    # load dataset
    if args.domain == 'Cora':
        dataset = Planetoid(root='/home/amax/xsx/data/gnn_datas/Cora', name='Cora', transform=T.NormalizeFeatures())
    elif args.domain == 'CiteSeer':
        dataset = Planetoid(root='/home/amax/xsx/data/gnn_datas/CiteSeer', name='CiteSeer', transform=T.NormalizeFeatures())
    elif args.domain == 'PubMed':
        dataset = Planetoid(root='/home/amax/xsx/data/gnn_datas/PubMed', name='PubMed', transform=T.NormalizeFeatures())
    elif args.domain == 'DBLP':
        dataset = DBLP(root='/home/amax/xsx/data/gnn_datas/DBLP', name='DBLP')
    elif args.domain == 'Cora-ML':
        dataset = CoraML(root='/home/amax/xsx/data/gnn_datas/Cora_ML', name='Cora_ML')
    elif args.domain == 'CS':
        dataset = Coauthor(root='/home/amax/xsx/data/gnn_datas/Coauthor/CS', name='CS')
    elif args.domain == 'Physics':
        dataset = Coauthor(root='/home/amax/xsx/data/gnn_datas/Coauthor/Physics', name='Physics')
    elif args.domain == 'Computers':
        dataset = Amazon(root='/home/amax/xsx/data/gnn_datas/Amazon/Computers', name='Computers')
    elif args.domain == 'Photo':
        dataset = Amazon(root='/home/amax/xsx/data/gnn_datas/Amazon/Photo', name='Photo')
    else:
        dataset = None
    if dataset is None:
        pdb.set_trace()
    data = dataset[0].to(device)
    # create the model and optimizer
    model = DeepGraphInfomax(hidden_channels=args.hidden_dim,
                             encoder=Encoder(dataset.num_features, args.hidden_dim),
                             summary=lambda z, *args, **kwargs: z.mean(dim=0),
                             corruption=corruption).to(device)
    optimizer = Adam(model.parameters(), lr=args.lr)
    # the information which need to be recorded
    start_time = time.time()
    bad_counter = 0
    best_epoch = 0
    least_loss = float("inf")
    best_model = None
    # begin training
    for epoch in range(args.epochs):
        # the steps of training
        model.train()
        optimizer.zero_grad()
        pos_z, neg_z, summary = model(data.x, data.edge_index)
        loss = model.loss(pos_z, neg_z, summary)
        current_loss = loss.item()
        loss.backward()
        optimizer.step()
        # save the model if it achieves the minimum loss in the current epoch
        if current_loss < least_loss:
            least_loss = current_loss
            best_epoch = epoch + 1
            best_model = copy.deepcopy(model)
            bad_counter = 0
        else:
            bad_counter += 1
        # early stop
        if bad_counter >= args.patience:
            break
    print("Optimization Finished!")
    used_time = time.time() - start_time
    # Fix: report the number of epochs actually executed; the original printed
    # ``best_epoch + 100`` (a hard-coded patience) which is wrong whenever
    # early stopping triggers at a different point or the epoch budget runs out.
    print("Total epochs: {:2d}".format(epoch + 1))
    print("Best epochs: {:2d}".format(best_epoch))
    # train a classification model
    node_classification(best_model, data, args, device, int(dataset.num_classes))
    print("Total time elapsed: {:.2f}s".format(used_time))
def main_model_dgi(data, hidden, if_all=False):
    """Train a DGI model for 10 epochs, tracking the embeddings with the best
    validation accuracy (via per-epoch logistic regression probes), then fit a
    final probe on those embeddings.

    Returns a ``Result`` with one-hot predictions for all nodes when
    ``if_all`` is True, otherwise raw predictions for the test mask.
    """
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = DeepGraphInfomax(
        hidden_channels=hidden,
        encoder=Encoder(hidden, data),
        summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
        corruption=corruption)
    data.split_train_valid()
    model = model.to(device)
    data = data.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    best_acc_valid = 0
    # Fix: ``result`` was only assigned when validation accuracy improved past
    # the initial 0; if every epoch scored 0.0 the post-loop code raised
    # NameError. Initialize and force assignment on the first epoch instead.
    result = None
    for epoch in range(10):
        model.train()
        optimizer.zero_grad()
        pos_z, neg_z, summary = model(data.x, data.edge_index)
        # Probe embedding quality with a logistic regression on the train split.
        lr = LogisticRegression().fit(
            pos_z[data.mask_train].detach().cpu().numpy().reshape(-1, hidden),
            data.y[data.mask_train].cpu().numpy())
        valid_pred = lr.predict(
            pos_z[data.mask_valid].detach().cpu().numpy().reshape(-1, hidden))
        acc_valid = accuracy_score(data.y[data.mask_valid].cpu().numpy(), valid_pred)
        if acc_valid > best_acc_valid or result is None:
            best_acc_valid = max(best_acc_valid, acc_valid)
            result = pos_z
        loss = model.loss(pos_z.to(device), neg_z.to(device), summary.to(device))
        loss.backward()
        optimizer.step()

    # Fit the final probe on the best-validation embeddings.
    lr = LogisticRegression().fit(
        result[data.mask_train].detach().cpu().numpy().reshape(-1, hidden),
        data.y[data.mask_train].cpu().numpy())
    train_pred = lr.predict(
        result[data.mask_train].detach().cpu().numpy().reshape(-1, hidden))
    all_pred = lr.predict(result.detach().cpu().numpy().reshape(-1, hidden))
    if if_all:
        return Result(
            result=torch.tensor(np.eye(data.num_class)[all_pred]).float().cpu(),
            loss_train=-1,
            loss_valid=-1,
            acc_train=accuracy_score(data.y[data.mask_train].cpu().numpy(), train_pred),
            acc_valid=best_acc_valid,
            epoch=10,
        )
    else:
        return Result(
            result=all_pred[data.mask_test],
            loss_train=-1,
            loss_valid=-1,
            acc_train=accuracy_score(data.y[data.mask_train].cpu().numpy(), train_pred),
            acc_valid=best_acc_valid,
            epoch=10,
        )