class DGILearner:
    """Self-supervised node-embedding learner based on Deep Graph Infomax.

    Wraps a ``DeepGraphInfomax`` model around a ``DGIEncoderNet`` encoder and
    exposes simple ``train`` / ``test`` / ``embed`` entry points.  The readout
    (summary) and corruption functions required by DGI are provided as methods
    of this class.
    """

    def __init__(self, inp_dim, out_dim, device):
        # Encoder maps node features (inp_dim) to embeddings (out_dim).
        self.encoder = DGIEncoderNet(inp_dim, out_dim)
        self.dgi = DeepGraphInfomax(out_dim,
                                    encoder=self.encoder,
                                    summary=self.readout,
                                    corruption=self.corrupt)
        self.dgi = self.dgi.to(device)
        # NOTE(review): uses Adam's default lr (1e-3); confirm that is intended.
        self.optimizer = torch.optim.Adam(self.dgi.parameters())

    def embed(self, data):
        """Return the positive (uncorrupted) node embeddings for ``data``."""
        pos_z, _, _ = self.dgi(data.x, data.edge_index, data.edge_attr, msk=None)
        return pos_z

    def readout(self, z, x, edge_index, edge_attr, msk=None):
        """Summary function: sigmoid of the (masked) mean node embedding."""
        if msk is None:
            return torch.sigmoid(torch.mean(z, 0))
        # Masked mean: sum over selected rows divided by the number selected.
        return torch.sigmoid(torch.sum(z[msk], 0) / torch.sum(msk))

    def corrupt(self, x, edge_index, edge_attr, msk=None):
        """Corruption function: shuffle node features, keep the graph fixed."""
        shuffled_rows = torch.randperm(len(x))
        shuffled_x = x[shuffled_rows, :]
        return shuffled_x, edge_index, edge_attr

    def evaluate_loss(self, data, mode):
        """Compute the DGI loss restricted to the train or test node mask.

        BUG FIX: the original version ran *both* the train-mask and the
        test-mask forward passes on every call and discarded one of them,
        doubling the work (and building an unused autograd graph).  Only the
        forward pass for the requested ``mode`` is performed now.
        """
        msk = data.train_mask if mode == 'train' else data.test_mask
        pos_z, neg_z, summary = self.dgi(data.x, data.edge_index,
                                         data.edge_attr, msk=msk)
        return self.dgi.loss(pos_z, neg_z, summary)

    def train(self, data):
        """Run one optimization step; return the scalar training loss."""
        self.dgi.train()
        self.optimizer.zero_grad()
        loss = self.evaluate_loss(data, mode='train')
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def test(self, data):
        """Return the scalar DGI loss on the test mask (eval mode)."""
        self.dgi.eval()
        return self.evaluate_loss(data, mode='test').item()
def train():
    """Train DeepGraphInfomax on the dataset selected via CLI args, with
    early stopping on the training loss, then fit a downstream node
    classifier on the best encoder's embeddings."""
    args = get_args()
    print(args.domain)
    # Decide the device.
    # NOTE(review): GPU index 2 is hard-coded; confirm it exists on the host.
    device = torch.device('cuda:2' if torch.cuda.is_available() and args.cuda
                          else 'cpu')
    # Map each domain name to its dataset constructor (replaces the long
    # if/elif ladder); lambdas defer construction until the lookup succeeds.
    root = '/home/amax/xsx/data/gnn_datas'
    loaders = {
        'Cora': lambda: Planetoid(root=f'{root}/Cora', name='Cora',
                                  transform=T.NormalizeFeatures()),
        'CiteSeer': lambda: Planetoid(root=f'{root}/CiteSeer', name='CiteSeer',
                                      transform=T.NormalizeFeatures()),
        'PubMed': lambda: Planetoid(root=f'{root}/PubMed', name='PubMed',
                                    transform=T.NormalizeFeatures()),
        'DBLP': lambda: DBLP(root=f'{root}/DBLP', name='DBLP'),
        'Cora-ML': lambda: CoraML(root=f'{root}/Cora_ML', name='Cora_ML'),
        'CS': lambda: Coauthor(root=f'{root}/Coauthor/CS', name='CS'),
        'Physics': lambda: Coauthor(root=f'{root}/Coauthor/Physics',
                                    name='Physics'),
        'Computers': lambda: Amazon(root=f'{root}/Amazon/Computers',
                                    name='Computers'),
        'Photo': lambda: Amazon(root=f'{root}/Amazon/Photo', name='Photo'),
    }
    loader = loaders.get(args.domain)
    dataset = loader() if loader is not None else None
    if dataset is None:
        # Unknown domain: drop into the debugger (kept from the original;
        # consider raising ValueError in production code).
        pdb.set_trace()
    data = dataset[0].to(device)
    # Create the model and optimizer.
    model = DeepGraphInfomax(
        hidden_channels=args.hidden_dim,
        encoder=Encoder(dataset.num_features, args.hidden_dim),
        summary=lambda z, *args, **kwargs: z.mean(dim=0),
        corruption=corruption).to(device)
    optimizer = Adam(model.parameters(), lr=args.lr)
    # Bookkeeping for early stopping.
    start_time = time.time()
    bad_counter = 0
    best_epoch = 0
    least_loss = float("inf")
    best_model = None
    epoch = -1  # guards the post-loop report when args.epochs == 0
    for epoch in range(args.epochs):
        model.train()
        optimizer.zero_grad()
        pos_z, neg_z, summary = model(data.x, data.edge_index)
        loss = model.loss(pos_z, neg_z, summary)
        current_loss = loss.item()
        loss.backward()
        optimizer.step()
        # Keep a copy of the model achieving the lowest loss so far.
        if current_loss < least_loss:
            least_loss = current_loss
            best_epoch = epoch + 1
            best_model = copy.deepcopy(model)
            bad_counter = 0
        else:
            bad_counter += 1
        # Early stop after `patience` epochs without improvement.
        if bad_counter >= args.patience:
            break
    print("Optimization Finished!")
    used_time = time.time() - start_time
    # BUG FIX: the original printed `best_epoch + 100`, hard-coding a patience
    # of 100 and mis-reporting whenever training stopped by exhausting
    # args.epochs; report the number of epochs actually run instead.
    print("Total epochs: {:2d}".format(epoch + 1))
    print("Best epochs: {:2d}".format(best_epoch))
    # Train a classification model on the best encoder's embeddings.
    node_classification(best_model, data, args, device, int(dataset.num_classes))
    print("Total time elapsed: {:.2f}s".format(used_time))
def main_model_dgi(data, hidden, if_all=False):
    """Train DGI for 10 epochs, selecting the embedding whose logistic-
    regression probe scores best on the validation split, then evaluate it.

    Args:
        data: project graph-data object (expects .x, .edge_index, .y,
              .mask_train/.mask_valid/.mask_test, .num_class,
              .split_train_valid()) -- assumed, confirm against caller.
        hidden: embedding dimensionality.
        if_all: if True, return one-hot predictions for ALL nodes;
                otherwise return raw predictions on the test mask.

    Returns:
        A ``Result`` record (losses are not tracked and reported as -1).
    """
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = DeepGraphInfomax(
        hidden_channels=hidden,
        encoder=Encoder(hidden, data),
        summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
        corruption=corruption)
    data.split_train_valid()
    model = model.to(device)
    data = data.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    best_acc_valid = 0
    result = None  # best embedding seen so far (by validation accuracy)
    for epoch in range(10):
        model.train()
        optimizer.zero_grad()
        pos_z, neg_z, summary = model(data.x, data.edge_index)
        # Probe the current embedding with a logistic-regression classifier.
        lr = LogisticRegression().fit(
            pos_z[data.mask_train].detach().cpu().numpy().reshape(-1, hidden),
            data.y[data.mask_train].cpu().numpy())
        valid_pred = lr.predict(
            pos_z[data.mask_valid].detach().cpu().numpy().reshape(-1, hidden))
        acc_valid = accuracy_score(data.y[data.mask_valid].cpu().numpy(),
                                   valid_pred)
        # BUG FIX: `result` was only assigned when accuracy strictly improved
        # past 0; if every epoch scored 0.0 the post-loop code raised
        # UnboundLocalError.  Seed it on the first epoch as a fallback.
        if acc_valid > best_acc_valid or result is None:
            best_acc_valid = max(best_acc_valid, acc_valid)
            result = pos_z
        # All tensors already live on `device`; the original's extra
        # .to(device) calls were no-ops and have been dropped.
        loss = model.loss(pos_z, neg_z, summary)
        loss.backward()
        optimizer.step()
    # Refit the probe on the best embedding and produce final predictions.
    lr = LogisticRegression().fit(
        result[data.mask_train].detach().cpu().numpy().reshape(-1, hidden),
        data.y[data.mask_train].cpu().numpy())
    train_pred = lr.predict(
        result[data.mask_train].detach().cpu().numpy().reshape(-1, hidden))
    all_pred = lr.predict(result.detach().cpu().numpy().reshape(-1, hidden))
    acc_train = accuracy_score(data.y[data.mask_train].cpu().numpy(), train_pred)
    if if_all:
        return Result(
            result=torch.tensor(np.eye(data.num_class)[all_pred]).float().cpu(),
            loss_train=-1,
            loss_valid=-1,
            acc_train=acc_train,
            acc_valid=best_acc_valid,
            epoch=10,
        )
    return Result(
        result=all_pred[data.mask_test],
        loss_train=-1,
        loss_valid=-1,
        acc_train=acc_train,
        acc_valid=best_acc_valid,
        epoch=10,
    )