Example #1
0
class DGILearner:
    """Training wrapper around a ``DeepGraphInfomax`` (DGI) model.

    Owns the encoder, the DGI model and an Adam optimizer, and exposes
    single-step ``train``/``test`` helpers that return the DGI loss as a float.
    """

    def __init__(self, inp_dim, out_dim, device):
        """Build the encoder, wrap it in DGI, move it to *device* and create the optimizer.

        Args:
            inp_dim: Input feature dimension passed to ``DGIEncoderNet``.
            out_dim: Embedding dimension; also DGI's hidden size.
            device: Device the DGI model (and therefore the encoder) is moved to.
        """
        self.encoder = DGIEncoderNet(inp_dim, out_dim)
        self.dgi = DeepGraphInfomax(out_dim, encoder=self.encoder, summary=self.readout, corruption=self.corrupt)
        self.dgi = self.dgi.to(device)

        # Adam with its default hyper-parameters.
        self.optimizer = torch.optim.Adam(self.dgi.parameters())

    def embed(self, data):
        """Return the positive (uncorrupted) node embeddings for the whole graph.

        Gradients are not disabled here; wrap the call in ``torch.no_grad()``
        for pure inference.
        """
        pos_z, _, _ = self.dgi(data.x, data.edge_index, data.edge_attr, msk=None)
        return pos_z

    def readout(self, z, x, edge_index, edge_attr, msk=None):
        """Summarize node embeddings into one graph-level vector.

        Returns ``sigmoid(mean(z, dim=0))`` over all nodes, or over only the
        rows selected by *msk* (a boolean mask or an index tensor) when given.
        ``x``, ``edge_index`` and ``edge_attr`` are unused; presumably they are
        accepted because DGI calls the summary with the encoder's arguments.
        """
        if msk is None:
            return torch.sigmoid(torch.mean(z, 0))
        # Mean over the selected rows equals sum/count for a boolean mask and
        # is also correct for an index mask, where ``torch.sum(msk)`` would
        # add up the indices instead of counting the selected nodes.
        return torch.sigmoid(torch.mean(z[msk], 0))

    def corrupt(self, x, edge_index, edge_attr, msk=None):
        """Produce DGI negative samples by shuffling node features across nodes.

        The graph structure (``edge_index``, ``edge_attr``) is returned as-is;
        *msk* is accepted for signature compatibility and ignored.
        """
        shuffled_rows = torch.randperm(len(x))
        shuffled_x = x[shuffled_rows, :]
        return shuffled_x, edge_index, edge_attr

    def evaluate_loss(self, data, mode):
        """Compute the DGI loss for one split.

        ``mode == 'train'`` uses ``data.train_mask``; any other value uses
        ``data.test_mask``. Only the split that is actually needed is run
        through the model.

        Note that the mask only restricts the summary (see ``readout``): the
        positive/negative embeddings of *all* nodes enter the loss.
        NOTE(review): if a strictly per-split loss is intended, the
        embeddings should be masked as well. Confirm the intended semantics.
        """
        msk = data.train_mask if mode == 'train' else data.test_mask
        pos_z, neg_z, summary = self.dgi(data.x, data.edge_index, data.edge_attr, msk=msk)
        return self.dgi.loss(pos_z, neg_z, summary)

    def train(self, data):
        """Run one optimization step on the training split and return the loss."""
        self.dgi.train()
        self.optimizer.zero_grad()
        loss = self.evaluate_loss(data, mode='train')
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def test(self, data):
        """Return the test-split DGI loss without tracking gradients."""
        self.dgi.eval()
        # No backward pass happens here, so skip building the autograd graph.
        with torch.no_grad():
            return self.evaluate_loss(data, mode='test').item()
            x = self.convs[i]((x, x_target), edge_index)
            x = self.activations[i](x)
        return x


def corruption(x, edge_index):
    """Create a DGI negative sample by permuting the rows of ``x``.

    Only the node features are shuffled across nodes; the edge index is
    passed through unchanged.

    Returns:
        Tuple ``(x_shuffled, edge_index)``.
    """
    permutation = torch.randperm(x.size(0))
    x_shuffled = x[permutation]
    return x_shuffled, edge_index


# Use the GPU when available; the model and full-graph tensors live there.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# `.to(device)` moves the module in place and returns it, so a single call
# is enough.
model = DeepGraphInfomax(
    hidden_channels=512, encoder=Encoder(dataset.num_features, 512),
    # Graph summary: sigmoid of the mean node embedding (extra args ignored).
    summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
    corruption=corruption).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Full-graph node features and labels moved to the device once, up front.
x, y = data.x.to(device), data.y.to(device)


def train(epoch):
    """Intended to run one epoch of mini-batch training of ``model`` over ``train_loader``.

    NOTE(review): this body appears truncated. After ``optimizer.zero_grad()``
    there is no forward pass, no loss, no backward/step, no update of
    ``total_loss``/``total_examples`` and no return value, so as written it
    only iterates the loader and moves each batch's adjacencies to ``device``.
    Restore the missing part from the original source before using it.
    """
    # Put the model in training mode.
    model.train()

    # Running sums for a per-epoch average loss (never updated in the visible code).
    total_loss = total_examples = 0
    for batch_size, n_id, adjs in tqdm(train_loader,
                                       desc=f'Epoch {epoch:02d}'):
        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
        adjs = [adj.to(device) for adj in adjs]

        optimizer.zero_grad()
Example #3
0
def main_model_dgi(data, hidden, if_all=False, epochs=10):
    """Train a Deep Graph Infomax model and classify nodes from its embeddings.

    Each epoch, the current node embeddings are scored by fitting a
    ``LogisticRegression`` on the training nodes and measuring accuracy on
    the validation nodes. The embeddings with the best validation accuracy
    are kept and used for the final predictions.

    Args:
        data: Graph dataset object providing ``x``, ``edge_index``, ``y``,
            ``mask_train``, ``mask_valid``, ``mask_test``, ``num_class``,
            ``split_train_valid()`` and ``to(device)``.
        hidden: Embedding dimensionality of the DGI encoder.
        if_all: If True, return one-hot predictions for every node;
            otherwise return class predictions for the test nodes only.
        epochs: Number of DGI training epochs (default 10, the previously
            hard-coded value).

    Returns:
        A ``Result`` with the predictions, training accuracy and best
        validation accuracy. Losses are not tracked and are reported as -1.

    Raises:
        ValueError: If ``epochs`` is less than 1.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    # Global setting: affects every cuDNN call in the process, not just this one.
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = DeepGraphInfomax(
        hidden_channels=hidden,
        encoder=Encoder(hidden, data),
        summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
        corruption=corruption)

    data.split_train_valid()
    model = model.to(device)
    data = data.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    def to_numpy(z):
        # sklearn needs detached host arrays of shape (n_nodes, hidden).
        return z.detach().cpu().numpy().reshape(-1, hidden)

    # Labels do not change during training, so convert them once.
    y_train = data.y[data.mask_train].cpu().numpy()
    y_valid = data.y[data.mask_valid].cpu().numpy()

    best_acc_valid = 0
    result = None
    for _ in range(epochs):
        model.train()
        optimizer.zero_grad()
        pos_z, neg_z, summary = model(data.x, data.edge_index)

        clf = LogisticRegression().fit(to_numpy(pos_z[data.mask_train]), y_train)
        acc_valid = accuracy_score(y_valid, clf.predict(to_numpy(pos_z[data.mask_valid])))

        # `result is None` guarantees some embedding is kept even when every
        # epoch scores 0 validation accuracy (previously an UnboundLocalError).
        if result is None or acc_valid > best_acc_valid:
            best_acc_valid = acc_valid
            # Detach so this epoch's autograd graph can be freed.
            result = pos_z.detach()

        loss = model.loss(pos_z, neg_z, summary)
        loss.backward()
        optimizer.step()

    clf = LogisticRegression().fit(to_numpy(result[data.mask_train]), y_train)
    train_pred = clf.predict(to_numpy(result[data.mask_train]))
    all_pred = clf.predict(to_numpy(result))

    if if_all:
        prediction = torch.tensor(np.eye(data.num_class)[all_pred]).float().cpu()
    else:
        test_mask = data.mask_test
        # After data.to(device) the mask may be a CUDA tensor, which cannot
        # index a numpy array; bring it to host memory first.
        if isinstance(test_mask, torch.Tensor):
            test_mask = test_mask.cpu().numpy()
        prediction = all_pred[test_mask]

    return Result(
        result=prediction,
        loss_train=-1,
        loss_valid=-1,
        acc_train=accuracy_score(y_train, train_pred),
        acc_valid=best_acc_valid,
        epoch=epochs,
    )