Example #1
import torch
from torch_geometric.nn import DeepGraphInfomax

# DGIEncoderNet is the project's GNN encoder, defined elsewhere; it must accept
# the (x, edge_index, edge_attr, msk=None) signature that DeepGraphInfomax
# forwards to it.

class DGILearner:
    def __init__(self, inp_dim, out_dim, device):
        self.encoder = DGIEncoderNet(inp_dim, out_dim)
        self.dgi = DeepGraphInfomax(
            hidden_channels=out_dim,
            encoder=self.encoder,
            summary=self.readout,
            corruption=self.corrupt,
        )
        self.dgi = self.dgi.to(device)

        self.optimizer = torch.optim.Adam(self.dgi.parameters())

    def embed(self, data):
        # positive-pass embeddings; these are the node representations
        # used downstream
        pos_z, _, _ = self.dgi(data.x, data.edge_index, data.edge_attr, msk=None)
        return pos_z

    def readout(self, z, x, edge_index, edge_attr, msk=None):
        # summary vector: sigmoid of the (optionally masked) mean of the
        # node embeddings
        if msk is None:
            return torch.sigmoid(torch.mean(z, 0))
        else:
            return torch.sigmoid(torch.sum(z[msk], 0) / torch.sum(msk))

    def corrupt(self, x, edge_index, edge_attr, msk=None):
        # negative samples: permute the feature rows, keep the graph structure
        shuffled_rows = torch.randperm(len(x))
        shuffled_x = x[shuffled_rows, :]
        return shuffled_x, edge_index, edge_attr

    def evaluate_loss(self, data, mode):
        # evaluate the DGI loss on the requested node mask only, so the other
        # split does not trigger a wasted forward pass
        msk = data.train_mask if mode == 'train' else data.test_mask
        pos_z, neg_z, summary = self.dgi(data.x, data.edge_index, data.edge_attr, msk=msk)
        return self.dgi.loss(pos_z, neg_z, summary)

    def train(self, data):
        # one full-batch optimization step
        self.dgi.train()
        self.optimizer.zero_grad()
        loss = self.evaluate_loss(data, mode='train')
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def test(self, data):
        # report the loss on the test mask without building an autograd graph
        self.dgi.eval()
        with torch.no_grad():
            return self.evaluate_loss(data, mode='test').item()
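
A minimal usage sketch for DGILearner follows, assuming a Planetoid-style dataset whose Data object carries train_mask and test_mask (for citation graphs such as Cora, edge_attr is simply None, so DGIEncoderNet must tolerate that); the dataset, embedding width, and epoch count here are illustrative, not part of the original code.

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data = dataset[0].to(device)

learner = DGILearner(dataset.num_features, 64, device)
for epoch in range(1, 301):
    train_loss = learner.train(data)
    if epoch % 50 == 0:
        print(f'Epoch {epoch:03d}, train {train_loss:.4f}, test {learner.test(data):.4f}')

emb = learner.embed(data)  # final node embeddings for downstream tasks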
Example #2
import json
import os
import os.path as osp

import numpy as np
import torch
from scipy import io
from torch_geometric.data import Data
from torch_geometric.nn import DeepGraphInfomax
from torch_geometric.utils import from_scipy_sparse_matrix

def infomax(fp, PARAMS, feature):
    """Generate a DGI embedding.

    Args:
        fp (string): file path of the root of the data
        PARAMS (dict): the parameters of the DGI model,
                        KEYS: {
                                GRAPH_NAME: the name of the graph file,
                                SUMMARY: dimension of the embedding,
                                HIDDEN_CHANNELS: the hidden channels of the encoder,
                                LEARNING_RATE: learning rate,
                                BATCH_SIZE: batch size (unused here; training is full-batch),
                                NUM_EPOCH: number of epochs to train,
                                EMBEDDING_NAME: file name under which the embedding is saved,
                                CUDA: whether to use the GPU
                                }
        feature (np.array): the node features

    Returns:
        np.array: numpy array format of the DGI embedding
    """
    g = io.loadmat(osp.join(fp, 'interim', 'graph', PARAMS['GRAPH_NAME']))
    N = g['N']                  # sparse adjacency matrix of the graph
    p_cate = feature
    post_indx = g['post_indx']  # node indices that the given features refer to
    # from_scipy_sparse_matrix returns (edge_index, edge_weight); the weight
    # vector is reshaped and used here as the first feature column
    edge_idx, x = from_scipy_sparse_matrix(N)
    x = x.view(-1, 1).float()
    # scatter the given features into a full matrix and append them
    feature = np.zeros((x.shape[0], p_cate.shape[1]))
    feature[post_indx, :] = p_cate
    x = torch.cat([x, torch.FloatTensor(feature)], 1)
    data = Data(x=x, edge_index=edge_idx)
    if PARAMS['CUDA']:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = 'cpu'
    data = data.to(device)
    # Encoder and corruption are module-level helpers defined elsewhere in the
    # project (see the sketch after this example); note that DeepGraphInfomax
    # expects the encoder's output width to equal hidden_channels
    model = DeepGraphInfomax(
        hidden_channels=PARAMS['HIDDEN_CHANNELS'],
        encoder=Encoder(data.x.shape[1], PARAMS['SUMMARY']),
        summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
        corruption=corruption).to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=PARAMS['LEARNING_RATE'])

    def train():
        # one full-batch DGI step: contrast positive and corrupted embeddings
        model.train()
        optimizer.zero_grad()
        pos_z, neg_z, summary = model(data.x, data.edge_index)
        loss = model.loss(pos_z, neg_z, summary)
        loss.backward()
        optimizer.step()
        return loss.item()

    losses = []
    for epoch in range(1, PARAMS['NUM_EPOCH'] + 1):
        loss = train()
        losses.append(loss)
        print('Epoch: {:03d}, Loss: {:.4f}'.format(epoch, loss))
    model.eval()
    with torch.no_grad():
        z, _, _ = model(data.x, data.edge_index)
    out_dir = osp.join(fp, 'processed', 'infomax')
    os.makedirs(out_dir, exist_ok=True)
    # save the per-epoch losses next to the embedding
    with open(osp.join(out_dir, PARAMS['EMBEDDING_NAME'] + 'log.json'), 'w') as f:
        json.dump({'loss': losses}, f)
    # keep only the rows indexed by post_indx
    z = z.cpu().numpy()[post_indx.reshape(-1), :]
    np.save(osp.join(out_dir, PARAMS['EMBEDDING_NAME']), z)
    print('embedding infomax created')
    return z
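
The function above relies on an Encoder module and a corruption function defined elsewhere in the project. Below is a minimal sketch of plausible definitions, modeled on the standard PyTorch Geometric DGI example (a single GCNConv with a PReLU nonlinearity, and row-shuffling corruption); the actual layers in the original repository may differ, and under this sketch PARAMS['SUMMARY'] must equal PARAMS['HIDDEN_CHANNELS'], since DeepGraphInfomax ties its discriminator size to the encoder's output width.

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv

class Encoder(nn.Module):
    # hypothetical stand-in: one GCN layer followed by PReLU
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, out_channels)
        self.act = nn.PReLU(out_channels)

    def forward(self, x, edge_index):
        return self.act(self.conv(x, edge_index))

def corruption(x, edge_index):
    # negative samples: shuffle feature rows, keep the edges unchanged
    return x[torch.randperm(x.size(0))], edge_index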