Example #1
File: models.py  Project: wxEdward/GIAAD
class GIN(torch.nn.Module):
    """
    多層化対応モデル(AutoGraphで使われていたモデル)
    グラフ構造が重要なので最初から畳み込み層に入力する
    """
    def __init__(self,
                 num_node_features=100,
                 num_class=18,
                 hidden=16,
                 dropout_rate=0.5,
                 num_layers=2,
                 eps=0,
                 train_eps=True):
        super(GIN, self).__init__()
        self.first_conv = GINConv(
            Sequential(Linear(num_node_features, hidden), ReLU(),
                       Linear(hidden, hidden)), eps, train_eps)
        self.first_bn = BatchNorm1d(hidden)
        self.nns = torch.nn.ModuleList()
        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
        for i in range(num_layers):
            self.nns.append(
                Sequential(Linear(hidden, hidden), ReLU(),
                           Linear(hidden, hidden)))
            self.bns.append(BatchNorm1d(hidden))
            self.convs.append(GINConv(self.nns[i], eps, train_eps))
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_class)
        self.dropout_rate = dropout_rate

    def reset_parameters(self):
        self.first_conv.reset_parameters()
        self.first_bn.reset_parameters()
        for conv, bn in zip(self.convs, self.bns):
            conv.reset_parameters()  # also resets the wrapped Sequential; a plain Sequential has no reset_parameters()
            bn.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        #x = F.relu(self.first_lin(x))
        x = F.relu(self.first_conv(x, edge_index))
        x = self.first_bn(x)
        for conv, bn in zip(self.convs, self.bns):
            x = F.relu(conv(x, edge_index))
            x = bn(x)
            #x = F.dropout(x, self.dropout_rate, training=self.training)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, self.dropout_rate, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
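A minimal usage sketch for the model above (hedged: the toy graph, its sizes, and tensor values are hypothetical; only torch and torch_geometric are assumed):

import torch
from torch_geometric.data import Data

# Hypothetical toy graph: 4 nodes with 100-dim features, 3 undirected edges.
x = torch.randn(4, 100)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]], dtype=torch.long)
edge_weight = torch.ones(edge_index.size(1))  # unpacked by forward() but never used
data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight)

model = GIN(num_node_features=100, num_class=18)
model.eval()
with torch.no_grad():
    log_probs = model(data)  # [4, 18] per-node log-probabilities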
Example #2
class GIN0WithJK(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden, mode='cat'):
        super(GIN0WithJK, self).__init__()
        self.conv1 = GINConv(Sequential(
            Linear(dataset.num_features, hidden),
            ReLU(),
            Linear(hidden, hidden),
            ReLU(),
            BN(hidden),
        ),
                             train_eps=False)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(
                GINConv(Sequential(
                    Linear(hidden, hidden),
                    ReLU(),
                    Linear(hidden, hidden),
                    ReLU(),
                    BN(hidden),
                ),
                        train_eps=False))
        self.jump = JumpingKnowledge(mode)
        if mode == 'cat':
            self.lin1 = Linear(num_layers * hidden, hidden)
        else:
            self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        xs = [x]
        for conv in self.convs:
            x = conv(x, edge_index)
            xs += [x]
        x = self.jump(xs)
        x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
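A hedged batching sketch for graph classification (TUDataset/MUTAG is a stand-in; any PyG graph-classification dataset works). With mode='cat', JumpingKnowledge concatenates the per-layer node embeddings along the feature dimension, which is exactly why lin1 takes num_layers * hidden inputs:

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader in older PyG versions

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = GIN0WithJK(dataset, num_layers=3, hidden=64, mode='cat')
batch = next(iter(loader))
out = model(batch)  # [num_graphs_in_batch, dataset.num_classes]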
Example #3
class GIN0(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden, add_pool=False):
        super(GIN0, self).__init__()
        self.conv1 = GINConv(Sequential(
            Linear(dataset.num_features, hidden),
            ReLU(),
            Linear(hidden, hidden),
            ReLU(),
            BN(hidden),
        ),
                             train_eps=False)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(
                GINConv(Sequential(
                    Linear(hidden, hidden),
                    ReLU(),
                    Linear(hidden, hidden),
                    ReLU(),
                    BN(hidden),
                ),
                        train_eps=False))
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)
        self.add_pool = add_pool

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        for conv in self.convs:
            x = conv(x, edge_index)

        if self.add_pool:
            x = global_add_pool(x, batch)
        else:
            x = global_mean_pool(x, batch)

        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
Example #4
class GINLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GINLayer, self).__init__()
        seq = nn.Sequential(
            nn.Linear(in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        )
        self.conv = GINConv(seq)
        self.bn = nn.BatchNorm1d(out_channels)

    def reset_parameters(self):
        self.conv.reset_parameters()
        self.bn.reset_parameters()  # resolves the original "TODO: reset bn?"

    def forward(self, x, edge_index):
        x = F.relu(self.conv(x, edge_index))
        x = self.bn(x)
        return x
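A sketch of how this building block might be stacked into a small node encoder (GINStack is a hypothetical wrapper, not part of the original project):

import torch.nn as nn

class GINStack(nn.Module):
    def __init__(self, in_channels, hidden, num_layers=2):
        super().__init__()
        dims = [in_channels] + [hidden] * num_layers
        self.layers = nn.ModuleList(
            GINLayer(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:]))

    def reset_parameters(self):
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, x, edge_index):
        for layer in self.layers:
            x = layer(x, edge_index)
        return x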
Example #5
class GIN(torch.nn.Module):
    def __init__(self, num_layers=3, hidden=32, features_num=32, num_class=2):
        super(GIN, self).__init__()

        self.lin1 = Linear(hidden, hidden)
        self.lin3 = Linear(hidden, hidden)
        self.conv1 = GINConv(self.lin1)
        self.conv2 = GINConv(self.lin3)
        self.lin2 = Linear(hidden, num_class)

        self.fuse_weight = torch.nn.Parameter(torch.FloatTensor(num_layers), requires_grad=True)
        self.fuse_weight.data.fill_(1.0 / num_layers)  # the original hardcoded 1/3, matching the default num_layers=3
        self.first_lin = Linear(features_num, hidden)

    def reset_parameters(self):
        self.first_lin.reset_parameters()
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()  # the original looped over self.convs, which is never defined in this class
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        x = F.relu(self.first_lin(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x_first = x
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.2, training=self.training)

        x = x + self.fuse_weight[0] * x_first
        x_first = x
        x = F.relu(self.conv2(x, edge_index))    #, edge_attr=edge_weight))
        x = F.dropout(x, p=0.2, training=self.training)
        x = x + self.fuse_weight[1] * x_first
        #x=x+first_x
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)
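fuse_weight implements a learnable residual mix, x <- x + w_k * x_prev, after each convolution. A quick sanity check that gradients reach the mixing weights (toy tensors; all shapes and values are hypothetical):

import torch
import torch.nn.functional as F
from torch_geometric.data import Data

model = GIN(num_layers=3, hidden=32, features_num=16, num_class=2)
x = torch.randn(5, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]], dtype=torch.long)
data = Data(x=x, edge_index=edge_index, edge_weight=torch.ones(4))

out = model(data)  # [5, 2] log-probabilities
loss = F.nll_loss(out, torch.tensor([0, 1, 0, 1, 0]))
loss.backward()
print(model.fuse_weight.grad)  # the mixing weights receive gradients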
Example #6
class GIN(torch.nn.Module):
    def __init__(self, args):
        super(GIN, self).__init__()
        self.args = set_default(args, {
            'num_layers': 2,
            'hidden': 64,
            'hidden2': 32,
            'dropout': 0.5,
            'lr': 0.005,
            'epoches': 300,
            'weight_decay': 5e-4,
            'act': 'leaky_relu',
            'withbn': True,
        })
        self.timer = self.args['timer']
        self.dropout = self.args['dropout']
        self.agg = self.args['agg']
        self.conv1 = GINConv(
                        Sequential(
                            Linear(self.args['features_num'], self.args['hidden']),
                            ReLU(),
                            BatchNorm1d(self.args['hidden']),
                        ),
                        train_eps=True)
        self.convs = torch.nn.ModuleList()
        hd = [self.args['hidden']]
        for i in range(self.args['num_layers'] - 1):
            hd.append(self.args['hidden2'])
            self.convs.append(
                    GINConv(
                        Sequential(
                            Linear(self.args['hidden'], self.args['hidden2']),
                            ReLU(),
                            BatchNorm1d(self.args['hidden2']),
                        ),
                        train_eps=True))
        if self.args['agg'] == 'concat':
            outdim = sum(hd)
        elif self.args['agg'] == 'self':
            outdim = hd[-1]
        if self.args['act'] == 'leaky_relu':
            self.act = F.leaky_relu
        elif self.args['act'] == 'tanh':
            self.act = torch.tanh
        else:
            self.act = lambda x: x
        self.lin2 = Linear(outdim, self.args['num_class'])

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        x = self.act(self.conv1(x, edge_index))
        xs = [x]
        for conv in self.convs:
            x = self.act(conv(x, edge_index))
            xs.append(x)
        #x = F.dropout(x, p=self.dropout, training=self.training)
        if self.agg == 'concat':
            x = torch.cat(xs, dim=1)
        elif self.agg == 'self':
            x = xs[-1]
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def train_predict(self, data, train_mask=None, val_mask=None, return_out=True):
        if train_mask is None:
            train_mask = data.train_mask
        optimizer = torch.optim.Adam(self.parameters(), lr=self.args['lr'], weight_decay=self.args['weight_decay'])
        flag_end = False
        st = time.time()
        for epoch in range(1, self.args['epoches'] + 1):
            self.train()
            optimizer.zero_grad()
            res = self.forward(data)
            loss = F.nll_loss(res[train_mask], data.y[train_mask])
            loss.backward()
            optimizer.step()
            if epoch%50 == 0:
                cost = (time.time()-st)/epoch*50
                if max(cost*10, 5) > self.timer.remain_time():
                    flag_end = True
                    break

        test_mask = data.test_mask
        self.eval()
        with torch.no_grad():
            res = self.forward(data)
            if return_out:
                pred = res
            else:
                pred = res[test_mask]
            if val_mask is not None:
                return pred, res[val_mask], flag_end
        return pred, flag_end
 
    def __repr__(self):
        return self.__class__.__name__
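set_default is not defined in this snippet. A minimal sketch of its presumed behavior (user-supplied args override the defaults; note that 'timer', 'agg', 'features_num', and 'num_class' have no defaults above, so the caller must pass them in args):

def set_default(args, defaults):
    # Presumed helper: keep user-supplied values, fall back to defaults otherwise.
    merged = dict(defaults)
    merged.update(args or {})
    return merged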
Example #7
class JK(nn.Module):
    def __init__(self,
                 nfeat,
                 nhid,
                 nclass,
                 dropout=0.5,
                 lr=0.01,
                 weight_decay=5e-4,
                 n_edge=1,
                 with_relu=True,
                 drop=False,
                 with_bias=True,
                 device=None):

        super(JK, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.hidden_sizes = [nhid]
        self.nclass = int(nclass)
        self.dropout = dropout
        self.lr = lr
        self.drop = drop
        if not with_relu:
            self.weight_decay = 0
        else:
            self.weight_decay = weight_decay
        self.with_relu = with_relu
        self.with_bias = with_bias
        self.n_edge = n_edge
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.gate = Parameter(torch.rand(1))  # create a learnable gate initialized uniformly in [0, 1]
        # self.beta = Parameter(torch.Tensor(self.n_edge))
        nclass = int(nclass)
        """JK from torch-geometric"""
        num_features = nfeat
        dim = nhid
        nn1 = Sequential(
            Linear(num_features, dim),
            ReLU(),
        )
        self.gc1 = GINConv(nn1)
        self.bn1 = torch.nn.BatchNorm1d(dim)

        nn2 = Sequential(
            Linear(dim, dim),
            ReLU(),
        )
        self.gc2 = GINConv(nn2)
        nn3 = Sequential(
            Linear(dim, dim),
            ReLU(),
        )
        self.gc3 = GINConv(nn3)

        self.jump = JumpingKnowledge(mode='cat')  # 'cat', 'lstm', 'max'
        self.bn2 = torch.nn.BatchNorm1d(dim)
        # self.fc1 = Linear(dim*3, dim)
        self.fc2 = Linear(dim * 2, int(nclass))

    def forward(self, x, adj):
        """We don't change the edge_index, just update the edge_weight;
        an edge is treated as removed when its weight equals zero."""
        x = x.to_dense()
        edge_index = adj._indices()
        """GJK-Nets"""
        if self.attention:
            adj = self.att_coef(x, adj, i=0)
        x1 = F.relu(
            self.gc1(x, edge_index=edge_index, edge_weight=adj._values()))
        if self.attention:  # if attention=True, use attention mechanism
            adj_2 = self.att_coef(x1, adj, i=1)
            adj_values = self.gate * adj._values() + (
                1 - self.gate) * adj_2._values()
        else:
            adj_values = adj._values()
        x1 = F.dropout(x1, self.dropout, training=self.training)
        x2 = F.relu(self.gc2(x1, edge_index=edge_index,
                             edge_weight=adj_values))
        # x2 = self.bn1(x2)
        # if self.attention:  # if attention=True, use attention mechanism
        #     adj_3 = self.att_coef(x2, adj, i=1)
        #     adj_values = self.gate * adj_2._values() + (1 - self.gate) * adj_3._values()
        # else:
        #     adj_values = adj._values()
        x2 = F.dropout(x2, self.dropout, training=self.training)
        # x3 = F.relu(self.gc2(x2, edge_index=edge_index, edge_weight=adj_values))
        # x3 = F.dropout(x3, self.dropout, training=self.training)

        x_last = self.jump([x1, x2])
        x_last = F.dropout(x_last, self.dropout, training=self.training)
        x_last = self.fc2(x_last)

        return F.log_softmax(x_last, dim=1)

    def initialize(self):
        self.gc1.reset_parameters()
        self.gc2.reset_parameters()
        self.fc2.reset_parameters()
        try:
            self.jump.reset_parameters()
            self.gc3.reset_parameters()
            self.fc1.reset_parameters()  # fc1 is commented out in __init__, so this may raise
        except AttributeError:
            pass

    def att_coef(self, fea, edge_index, is_lil=False, i=0):
        if is_lil == False:
            edge_index = edge_index._indices()
        else:
            edge_index = edge_index.tocoo()

        n_node = fea.shape[0]
        row, col = edge_index[0].cpu().data.numpy()[:], edge_index[1].cpu(
        ).data.numpy()[:]
        # row, col = edge_index[0], edge_index[1]

        fea_copy = fea.cpu().data.numpy()
        sim_matrix = cosine_similarity(X=fea_copy,
                                       Y=fea_copy)  # try cosine similarity
        sim = sim_matrix[row, col]
        sim[sim < 0.1] = 0
        # print('dropped {} edges'.format(1-sim.nonzero()[0].shape[0]/len(sim)))

        # """use jaccard for binary features and cosine for numeric features"""
        # fea_start, fea_end = fea[edge_index[0]], fea[edge_index[1]]
        # isbinray = np.array_equal(fea_copy, fea_copy.astype(bool))  # check is the fea are binary
        # np.seterr(divide='ignore', invalid='ignore')
        # if isbinray:
        #     fea_start, fea_end = fea_start.T, fea_end.T
        #     sim = jaccard_score(fea_start, fea_end, average=None)  # similarity scores of each edge
        # else:
        #     fea_copy[np.isinf(fea_copy)] = 0
        #     fea_copy[np.isnan(fea_copy)] = 0
        #     sim_matrix = cosine_similarity(X=fea_copy, Y=fea_copy)  # try cosine similarity
        #     sim = sim_matrix[edge_index[0], edge_index[1]]
        #     sim[sim < 0.01] = 0
        """build a attention matrix"""
        att_dense = lil_matrix((n_node, n_node), dtype=np.float32)
        att_dense[row, col] = sim
        if att_dense[0, 0] == 1:
            att_dense = att_dense - sp.diags(
                att_dense.diagonal(), offsets=0, format="lil")
        # normalization: make each row sum to 1
        att_dense_norm = normalize(att_dense, axis=1, norm='l1')
        """add learnable dropout, make character vector"""
        if self.drop:
            character = np.vstack(
                (att_dense_norm[row, col].A1, att_dense_norm[col, row].A1))
            character = torch.from_numpy(character.T)
            drop_score = self.drop_learn_1(character)
            drop_score = torch.sigmoid(
                drop_score
            )  # do not use softmax since we only have one element
            mm = torch.nn.Threshold(0.5, 0)
            drop_score = mm(drop_score)
            mm_2 = torch.nn.Threshold(-0.49, 1)
            drop_score = mm_2(-drop_score)
            drop_decision = drop_score.clone().requires_grad_()
            # print('rate of left edges', drop_decision.sum().data/drop_decision.shape[0])
            drop_matrix = lil_matrix((n_node, n_node), dtype=np.float32)
            drop_matrix[row,
                        col] = drop_decision.cpu().data.numpy().squeeze(-1)
            att_dense_norm = att_dense_norm.multiply(
                drop_matrix.tocsr())  # update, remove the 0 edges

        if att_dense_norm[
                0,
                0] == 0:  # add the weights of self-loop only add self-loop at the first layer
            degree = (att_dense_norm != 0).sum(1).A1
            # degree = degree.squeeze(-1).squeeze(-1)
            lam = 1 / (degree + 1)  # degree +1 is to add itself
            self_weight = sp.diags(np.array(lam), offsets=0, format="lil")
            att = att_dense_norm + self_weight  # add the self loop
        else:
            att = att_dense_norm

        att_adj = edge_index
        att_edge_weight = att[row, col]
        att_edge_weight = np.exp(att_edge_weight)  # exponent, kind of softmax
        att_edge_weight = torch.tensor(np.array(att_edge_weight)[0],
                                       dtype=torch.float32).cuda()

        shape = (n_node, n_node)
        new_adj = torch.sparse.FloatTensor(att_adj, att_edge_weight, shape)
        return new_adj

    def add_loop_sparse(self, adj, fill_value=1):
        # build a sparse identity matrix (torch.range is deprecated; arange covers the same indices)
        row = torch.arange(0, int(adj.shape[0]), dtype=torch.int64)
        i = torch.stack((row, row), dim=0)
        v = torch.ones(adj.shape[0], dtype=torch.float32)
        shape = adj.shape
        I_n = torch.sparse.FloatTensor(i, v, shape)
        return adj + I_n.to(self.device)

    def fit(
        self,
        features,
        adj,
        labels,
        idx_train,
        idx_val=None,
        idx_test=None,
        train_iters=81,
        att_0=None,
        attention=False,
        model_name=None,
        initialize=True,
        verbose=False,
        normalize=False,
        patience=500,
    ):
        '''
        Train the model. When idx_val is not None, pick the best model
        according to the validation loss.
        '''
        self.sim = None
        self.attention = attention
        self.idx_test = idx_test
        # self.device = self.gc1.weight.device

        if initialize:
            self.initialize()

        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features,
                                                    adj,
                                                    labels,
                                                    device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        # normalize = False # we don't need normalize here, the norm is conducted in the GCN (self.gcn1) model
        # if normalize:
        #     if utils.is_sparse_tensor(adj):
        #         adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
        #     else:
        #         adj_norm = utils.normalize_adj_tensor(adj)
        # else:
        #     adj_norm = adj

        adj = self.add_loop_sparse(adj)
        """Make the coefficient D^{-1/2}(A+I)D^{-1/2}"""
        self.adj_norm = adj
        self.features = features
        self.labels = labels

        if idx_val is None:
            self._train_without_val(labels, idx_train, train_iters, verbose)
        else:
            if patience < train_iters:
                self._train_with_early_stopping(labels, idx_train, idx_val,
                                                train_iters, patience, verbose)
            else:
                self._train_with_val(labels, idx_train, idx_val, train_iters,
                                     verbose)

    def _train_without_val(self, labels, idx_train, train_iters, verbose):
        self.train()
        optimizer = optim.Adam(self.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)
        for i in range(train_iters):
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train],
                                    weight=None)  # optional per-class rescaling weights; unused here
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(
                    i, loss_train.item()))

        self.eval()
        output = self.forward(self.features, self.adj_norm)
        self.output = output

    def _train_with_val(self, labels, idx_train, idx_val, train_iters,
                        verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        best_loss_val = 100
        best_acc_val = 0

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            # pred = output[self.idx_test].max(1)[1]
            # acc_test =accuracy(output[self.idx_test], labels[self.idx_test])
            # acc_test = pred.eq(labels[self.idx_test]).sum().item() / self.idx_test.shape[0]

            self.eval()
            output = self.forward(self.features, self.adj_norm)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])

            if verbose and i % 20 == 0:
                print('Epoch {}, training loss: {}, val acc: {}'.format(
                    i, loss_train.item(), acc_val))

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print(
                '=== picking the best model according to the performance on validation ==='
            )
        self.load_state_dict(weights)

    def _train_with_early_stopping(self, labels, idx_train, idx_val,
                                   train_iters, patience, verbose):
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)

        early_stopping = patience
        best_loss_val = 100

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            self.eval()
            output = self.forward(self.features, self.adj_norm)

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(
                    i, loss_train.item()))

            loss_val = F.nll_loss(output[idx_val], labels[idx_val])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())
                patience = early_stopping
            else:
                patience -= 1
            if i > early_stopping and patience <= 0:
                break

        if verbose:
            print('=== early stopping at {0}, loss_val = {1} ==='.format(
                i, best_loss_val))
        self.load_state_dict(weights)

    def test(self, idx_test, model_name=None):
        # self.model_name = model_name
        self.eval()
        output = self.predict()
        # output = self.output
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:", "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test, output

    def _set_parameters(self):
        # TODO
        pass

    def predict(self, features=None, adj=None):
        '''By default, inputs are unnormalized data'''

        # self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features,
                                                adj,
                                                device=self.device)

            self.features = features
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)
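utils here follows the deeprobust-style helper module (to_tensor, accuracy, normalize_adj_tensor). For reference, a hedged sketch of the accuracy helper assumed above, not necessarily the exact project code:

def accuracy(output, labels):
    # Fraction of correct predictions, given log-probabilities and integer labels.
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double().sum()
    return correct / len(labels)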
Example #8
class cut_MPNN(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden1, hidden2, deltas, elasticity=0.01, num_iterations=30):
        super(cut_MPNN, self).__init__()
        self.hidden1 = hidden1
        self.hidden2 = hidden2
        self.conv1 = GINConv(Sequential(
            Linear(1, self.hidden1),
            ReLU(),
            Linear(self.hidden1, self.hidden1),
            ReLU(),
            BN(self.hidden1),
        ), train_eps=False)
        self.num_iterations = num_iterations
        self.deltas = deltas
        self.numlayers = num_layers
        self.elasticity = elasticity

        self.bns = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.bns.append(BN(self.hidden1))
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GINConv(Sequential(
                Linear(self.hidden1, self.hidden1),
                ReLU(),
                Linear(self.hidden1, self.hidden1),
                ReLU(),
                BN(self.hidden1),
            ), train_eps=False))

        self.conv2 = GATConv(self.hidden1, self.hidden2, heads=8)  # "GATAConv" in the original looks like a typo for GATConv
        self.lin1 = Linear(8 * self.hidden2, self.hidden1)
        self.bn2 = BN(self.hidden1)
        self.lin2 = Linear(self.hidden1, 1)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters() 
        for conv in self.convs:
            conv.reset_parameters()    
        for bn in self.bns:
            bn.reset_parameters()
        self.lin1.reset_parameters()
        self.bn2.reset_parameters()
        self.lin2.reset_parameters()


    def forward(self, data, tvol=None):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        xinit = x.clone()
        row, col = edge_index
        mask = get_mask(x, edge_index, 1).to(x.dtype).unsqueeze(-1)

        x = self.conv1(x, edge_index)
        xpostconv1 = x.detach()
        x = x * mask
        for conv, bn in zip(self.convs, self.bns):
            if x.dim() > 1:
                x = x + conv(x, edge_index)
                mask = get_mask(mask, edge_index, 1).to(x.dtype)
                x = x * mask
                x = bn(x)


        x = self.conv2(x, edge_index)
        mask = get_mask(mask, edge_index, 1).to(x.dtype)
        x = x * mask
        xpostconvs = x.detach()

        x = F.leaky_relu(self.lin1(x))
        x = x * mask
        x = self.bn2(x)

        xpostlin1 = x.detach()
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.leaky_relu(self.lin2(x))
        x = x * mask

        xprethresh = x.detach()
        N_size = x.shape[0]
        batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
        batch_max = torch.index_select(batch_max, 0, batch)
        batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
        batch_min = torch.index_select(batch_min, 0, batch)

        # min-max normalize
        x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
        x = x * mask + mask * 1e-6
        

        # add dirac in the set
        x = x + xinit.unsqueeze(-1)

        # collect statistics
        x2 = x.detach()
        r, c = edge_index
        tv = total_var(x, edge_index, batch)
        deg = degree(r).unsqueeze(-1)
        conduct_1 = tv
        totalvol = scatter_add(deg.detach() * torch.ones_like(x, device=device), batch, 0) + 1e-6
        totalcard = scatter_add(torch.ones_like(x, device=device), batch, 0) + 1e-6

        # receptive field
        recvol_hard = scatter_add(deg * mask.float(), batch, 0, dim_size=batch.max().item() + 1) + 1e-6
        reccard_hard = scatter_add(mask.float(), batch, 0, dim_size=batch.max().item() + 1) + 1e-6

        assert recvol_hard.mean() / totalvol.mean() <= 1, "Something went wrong! Receptive field is larger than total volume."
        target = torch.zeros_like(totalvol)
        
        # generate target volume
        if tvol is None:
            feasible_vols = data.recfield_vol / data.total_vol
            target = torch.rand_like(feasible_vols, device=device) * feasible_vols * 0.85 + 0.1
            target = target.squeeze(-1) * totalvol.squeeze(-1)
        else:
            target = tvol * totalvol.squeeze(-1)
        a = torch.ones((batch.max().item() + 1, 1), device=device)
        xfilt = x
        
        ###############################################################################
        # iterative rescaling
        counter_no2 = 0
        for iteration in range(self.num_iterations):
            counter_no2 += 1
            keep = ((a[batch] * xfilt) < 1).to(x.dtype)

            x_k, d_k, d_nk = xfilt * keep * mask, deg * keep * mask, deg * (1 - keep) * mask

            diff = target.unsqueeze(-1) - scatter_add(d_nk, batch, 0)
            dot = scatter_add(x_k * d_k, batch, 0)
            a = diff / (dot + 1e-5)
            volcur = scatter_add(torch.clamp(a[batch] * xfilt, max=1., min=0.) * deg, batch, 0)

            volcheck = (torch.abs(target - volcur.squeeze(-1)) > 0.1)
            checki = torch.abs(target.squeeze(-1) - volcur.squeeze(-1)) > 0.01
            targetcheck = torch.abs(volcur.squeeze(-1) - target)
            check = (targetcheck <= self.elasticity * target).to(x.dtype)

            if check.sum() >= batch.max().item() + 1:
                break

        probs = torch.clamp(a[batch] * x * mask, max=1., min=0.)
        ###############################################################################

            
            
        # collect useful numbers
        x2 = ((probs - torch.rand_like(x, device=device)) > 0).float()
        vol_1 = scatter_add(probs * deg, batch, 0) + 1e-6
        card_1 = scatter_add(probs, batch, 0)
        rec_field = scatter_add(mask, batch, 0) + 1e-6
        cut_size = scatter_add(x2, batch, 0)
        tv_hard = total_var(x2, edge_index, batch)
        vol_hard = scatter_add(deg * x2, batch, 0, dim_size=batch.max().item() + 1) + 1e-6
        conduct_hard = tv_hard / vol_hard
        rec_field_ratio = cut_size / rec_field
        rec_field_volratio = vol_hard / recvol_hard
        total_vol_ratio = vol_hard / totalvol

        # calculate loss
        expected_cut = scatter_add(probs * deg, batch, 0) - scatter_add(probs[row] * probs[col], batch[row], 0)
        loss = expected_cut


        # return dict
        retdict = {}
        retdict["output"] = [probs.squeeze(-1), "hist"]  # output
        #retdict["|Expected_vol - Target|"] = [targetcheck, "sequence"]  # absolute distance from target volume
        retdict["Expected_volume"] = [vol_1.mean(), "sequence"]
        retdict["Expected_cardinality"] = [card_1.mean(), "sequence"]
        retdict["volume_hard"] = [vol_hard.mean(), "sequence"]
        #retdict["cut1"] = [tv.mean(), "sequence"]
        retdict["cut_hard"] = [tv_hard.mean(), "sequence"]
        retdict["Average cardinality ratio of receptive field"] = [rec_field_ratio.mean(), "sequence"]
        retdict["Recfield volume/Total volume"] = [recvol_hard.mean() / totalvol.mean(), "sequence"]
        retdict["Average ratio of receptive field volume"] = [rec_field_volratio.mean(), "sequence"]
        retdict["Average ratio of total volume"] = [total_vol_ratio.mean(), "sequence"]
        retdict["mask"] = [mask, "aux"]
        retdict["xinit"] = [xinit, "hist"]  # layer input diracs
        retdict["xpostlin1"] = [xpostlin1.mean(1), "hist"]  # after first linear layer
        retdict["xprethresh"] = [xprethresh.mean(1), "hist"]  # pre-thresholding activations
        # the original also logged "lossvol" and "losscard", but those variables are never defined here
        retdict["loss"] = [loss.mean().squeeze(), "sequence"]  # final loss

        return retdict
    
    def __repr__(self):
        return self.__class__.__name__
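get_mask and total_var are project helpers that are not shown here (device is likewise a module-level global). Hedged sketches of their presumed behavior, consistent with how they are called above: the receptive-field mask grows by one hop per call, and total variation sums |x_i - x_j| over edges per graph:

import torch
from torch_scatter import scatter_add, scatter_max

def get_mask(x, edge_index, hops):
    # Presumed: mark nodes within `hops` edges of the currently active (nonzero) nodes.
    active = x.detach().abs()
    if active.dim() > 1:
        active = active.sum(dim=-1)
    mask = (active > 0).float()
    row, col = edge_index
    for _ in range(hops):
        neighbor = scatter_max(mask[row], col, dim=0, dim_size=mask.size(0))[0]
        mask = torch.max(mask, neighbor)
    return mask

def total_var(x, edge_index, batch):
    # Presumed: per-graph total variation; undirected edges appear twice in edge_index, hence /2.
    row, col = edge_index
    return scatter_add((x[row] - x[col]).abs(), batch[row], 0) / 2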
Example #9
class clique_MPNN(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden1, hidden2, deltas, elasticity=0.01, num_iterations=30):
        super(clique_MPNN, self).__init__()  # the original called super() with a different class name, which fails at runtime
        self.hidden1 = hidden1
        self.hidden2 = hidden2
        self.momentum = 0.1
        self.num_iterations = num_iterations
        self.deltas = deltas
        self.numlayers = num_layers
        self.elasticity = elasticity
        self.heads = 8
        self.concat = True

        self.bns = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.bns.append(BN(self.heads * self.hidden1, momentum=self.momentum))
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GINConv(Sequential(
                Linear(self.heads * self.hidden1, self.heads * self.hidden1),
                ReLU(),
                Linear(self.heads * self.hidden1, self.heads * self.hidden1),
                ReLU(),
                BN(self.heads * self.hidden1, momentum=self.momentum),
            ), train_eps=True))
        self.bn1 = BN(self.heads * self.hidden1)
        self.conv1 = GINConv(Sequential(
            Linear(self.hidden2, self.heads * self.hidden1),
            ReLU(),
            Linear(self.heads * self.hidden1, self.heads * self.hidden1),
            ReLU(),
            BN(self.heads * self.hidden1, momentum=self.momentum),
        ), train_eps=True)

        if self.concat:
            self.lin1 = Linear(self.heads * self.hidden1, self.hidden1)
        else:
            self.lin1 = Linear(self.hidden1, self.hidden1)
        self.lin2 = Linear(self.hidden1, 1)
        self.gnorm = GraphSizeNorm()
                    


    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        self.bn1.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        num_graphs = batch.max().item() + 1
        row, col = edge_index
        total_num_edges = edge_index.shape[1]
        N_size = x.shape[0]

        if edge_dropout is not None:
            edge_index = dropout_adj(
                edge_index,
                edge_attr=torch.ones(edge_index.shape[1], device=device).long(),
                p=edge_dropout,
                force_undirected=True)[0]
            edge_index = add_remaining_self_loops(edge_index, num_nodes=batch.shape[0])[0]

        reduced_num_edges = edge_index.shape[1]
        current_edge_percentage = reduced_num_edges / total_num_edges
        no_loop_index, _ = remove_self_loops(edge_index)
        no_loop_row, no_loop_col = no_loop_index

        xinit = x.clone()
        x = x.unsqueeze(-1)
        mask = get_mask(x, edge_index, 1).to(x.dtype)
        x = F.leaky_relu(self.conv1(x, edge_index))  # +x
        x = x * mask
        x = self.gnorm(x)
        x = self.bn1(x)

        for conv, bn in zip(self.convs, self.bns):
            if x.dim() > 1:
                x = x + F.leaky_relu(conv(x, edge_index))
                mask = get_mask(mask, edge_index, 1).to(x.dtype)
                x = x * mask
                x = self.gnorm(x)
                x = bn(x)

        xpostconvs = x.detach()

        x = F.leaky_relu(self.lin1(x))
        x = x * mask

        xpostlin1 = x.detach()
        x = F.leaky_relu(self.lin2(x))
        x = x * mask

        # calculate min and max
        batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
        batch_max = torch.index_select(batch_max, 0, batch)
        batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
        batch_min = torch.index_select(batch_min, 0, batch)

        # min-max normalize
        x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
        probs = x

        x2 = x.detach()
        deg = degree(row).unsqueeze(-1)
        totalvol = scatter_add(deg.detach() * torch.ones_like(x, device=device), batch, 0) + 1e-6
        totalcard = scatter_add(torch.ones_like(x, device=device), batch, 0) + 1e-6
        x2 = ((x2 - torch.rand_like(x, device=device)) > 0).float()
        vol_1 = scatter_add(probs * deg, batch, 0) + 1e-6
        card_1 = scatter_add(probs, batch, 0)
        set_size = scatter_add(x2, batch, 0)
        vol_hard = scatter_add(deg * x2, batch, 0, dim_size=batch.max().item() + 1) + 1e-6
        total_vol_ratio = vol_hard / totalvol
        
        
        # terms for the expected distance between clique and graph
        pairwise_prodsums = torch.zeros(num_graphs, device=device)
        for graph in range(num_graphs):
            batch_graph = (batch == graph)
            pairwise_prodsums[graph] = (torch.conv1d(probs[batch_graph].unsqueeze(-1), probs[batch_graph].unsqueeze(-1))).sum() / 2

        ### calculate loss terms
        self_sums = scatter_add((probs * probs), batch, 0, dim_size=num_graphs)
        expected_weight_G = scatter_add(probs[no_loop_row] * probs[no_loop_col], batch[no_loop_row], 0, dim_size=num_graphs) / 2.
        expected_clique_weight = pairwise_prodsums.unsqueeze(-1) - self_sums
        expected_distance = expected_clique_weight - expected_weight_G

        ### useful numbers
        max_set_weight = (scatter_add(torch.ones_like(x)[no_loop_row], batch[no_loop_row], 0, dim_size=num_graphs) / 2).squeeze(-1)
        set_weight = (scatter_add(x2[no_loop_row] * x2[no_loop_col], batch[no_loop_row], 0, dim_size=num_graphs) / 2) + 1e-6
        clique_edges_hard = (set_size * (set_size - 1) / 2) + 1e-6
        clique_dist_hard = set_weight / clique_edges_hard
        clique_check = (clique_edges_hard != clique_edges_hard)  # NaN check
        setedge_check = (set_weight != set_weight)  # NaN check

        assert ((clique_dist_hard >= 1.1).sum()) <= 1e-6, "Invalid set vol/clique vol ratio."

        ### calculate loss
        expected_loss = penalty_coefficient * expected_distance * 0.5 - 0.5 * expected_weight_G
        loss = expected_loss


        retdict = {}
        retdict["output"] = [probs.squeeze(-1), "hist"]  # output
        retdict["Expected_cardinality"] = [card_1.mean(), "sequence"]
        retdict["Expected_cardinality_hist"] = [card_1, "hist"]
        retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
        retdict["Set sizes"] = [set_size.squeeze(-1), "hist"]
        retdict["volume_hard"] = [vol_hard.mean(), "aux"]
        retdict["cardinality_hard"] = [set_size[0], "sequence"]
        retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
        retdict["Expected maximum weight"] = [expected_clique_weight.mean(), "sequence"]
        retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
        retdict["Currvol/Cliquevol"] = [clique_dist_hard.mean(), "sequence"]
        retdict["Currvol/Cliquevol all graphs in batch"] = [clique_dist_hard.squeeze(-1), "hist"]
        retdict["Average ratio of total volume"] = [total_vol_ratio.mean(), "sequence"]
        # the original also logged "cardinalities", but that variable is never defined in this forward pass
        retdict["Current edge percentage"] = [torch.tensor(current_edge_percentage), "sequence"]
        retdict["loss"] = [loss.mean().squeeze(), "sequence"]  # final loss

        return retdict
    
    def __repr__(self):
        return self.__class__.__name__
Example #10
class clique_MPNN(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden1, hidden2, deltas):
        super(clique_MPNN, self).__init__()
        self.hidden1 = hidden1
        self.hidden2 = hidden2
        self.momentum = 0.1
        self.convs = torch.nn.ModuleList()
        self.deltas = deltas
        self.numlayers = num_layers
        self.heads = 8
        self.concat = True

        self.bns = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.bns.append(
                BN(self.heads * self.hidden1, momentum=self.momentum))
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(
                GINConv(Sequential(
                    Linear(self.heads * self.hidden1,
                           self.heads * self.hidden1),
                    ReLU(),
                    Linear(self.heads * self.hidden1,
                           self.heads * self.hidden1),
                    ReLU(),
                    BN(self.heads * self.hidden1, momentum=self.momentum),
                ),
                        train_eps=True))
        self.bn1 = BN(self.heads * self.hidden1)
        self.conv1 = GINConv(Sequential(
            Linear(self.hidden2, self.heads * self.hidden1),
            ReLU(),
            Linear(self.heads * self.hidden1, self.heads * self.hidden1),
            ReLU(),
            BN(self.heads * self.hidden1, momentum=self.momentum),
        ),
                             train_eps=True)

        if self.concat:
            self.lin1 = Linear(self.heads * self.hidden1, self.hidden1)
        else:
            self.lin1 = Linear(self.hidden1, self.hidden1)
        self.lin2 = Linear(self.hidden1, 1)
        self.gnorm = GraphSizeNorm()

    def reset_parameters(self):
        self.conv1.reset_parameters()

        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        self.bn1.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        num_graphs = batch.max().item() + 1
        row, col = edge_index
        total_num_edges = edge_index.shape[1]
        N_size = x.shape[0]

        if edge_dropout is not None:
            edge_index = dropout_adj(
                edge_index,
                edge_attr=(torch.ones(edge_index.shape[1],
                                      device=device)).long(),
                p=edge_dropout,
                force_undirected=True)[0]
            edge_index = add_remaining_self_loops(edge_index,
                                                  num_nodes=batch.shape[0])[0]

        reduced_num_edges = edge_index.shape[1]
        current_edge_percentage = (reduced_num_edges / total_num_edges)
        no_loop_index, _ = remove_self_loops(edge_index)
        no_loop_row, no_loop_col = no_loop_index

        xinit = x.clone()
        x = x.unsqueeze(-1)
        mask = get_mask(x, edge_index, 1).to(x.dtype)
        x = F.leaky_relu(self.conv1(x, edge_index))  # +x
        x = x * mask
        x = self.gnorm(x)
        x = self.bn1(x)

        for conv, bn in zip(self.convs, self.bns):
            if (x.dim() > 1):
                x = x + F.leaky_relu(conv(x, edge_index))
                mask = get_mask(mask, edge_index, 1).to(x.dtype)
                x = x * mask
                x = self.gnorm(x)
                x = bn(x)

        xpostconvs = x.detach()
        #
        x = F.leaky_relu(self.lin1(x))
        x = x * mask

        xpostlin1 = x.detach()
        x = F.leaky_relu(self.lin2(x))
        x = x * mask

        #calculate min and max
        batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
        batch_max = torch.index_select(batch_max, 0, batch)
        batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
        batch_min = torch.index_select(batch_min, 0, batch)

        #min-max normalize
        x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
        probs = x

        #calculating the terms for the expected distance between clique and graph
        pairwise_prodsums = torch.zeros(num_graphs, device=device)
        for graph in range(num_graphs):
            batch_graph = (batch == graph)
            pairwise_prodsums[graph] = (torch.conv1d(
                probs[batch_graph].unsqueeze(-1),
                probs[batch_graph].unsqueeze(-1))).sum() / 2

        ###calculate loss terms
        self_sums = scatter_add((probs * probs), batch, 0, dim_size=num_graphs)
        expected_weight_G = scatter_add(
            probs[no_loop_row] * probs[no_loop_col],
            batch[no_loop_row],
            0,
            dim_size=num_graphs) / 2.
        expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) -
                                  self_sums) / 1.
        expected_distance = (expected_clique_weight - expected_weight_G)

        ###calculate loss
        expected_loss = (penalty_coefficient
                         ) * expected_distance * 0.5 - 0.5 * expected_weight_G

        loss = expected_loss

        retdict = {}

        retdict["output"] = [probs.squeeze(-1), "hist"]  #output
        retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
        retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
        retdict["Expected maximum weight"] = [
            expected_clique_weight.mean(), "sequence"
        ]
        retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
        retdict["loss"] = [loss.mean().squeeze(), "sequence"]  #final loss

        return retdict

    def __repr__(self):
        return self.__class__.__name__
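One observation on the pairwise_prodsums loop in both clique_MPNN versions: conv1d of p with itself sums every product p_i * p_j, so each entry is simply (sum_i p_i)^2 / 2. The Python loop over graphs can therefore be replaced by a vectorized equivalent (an optional rewrite with a hypothetical helper name, not the original code):

from torch_scatter import scatter_add

def pairwise_prodsums_vectorized(probs, batch, num_graphs):
    # Per-graph (sum_i p_i)^2 / 2 -- the same value the conv1d loop produces.
    sums = scatter_add(probs, batch, 0, dim_size=num_graphs)  # [num_graphs, 1]
    return (sums.pow(2) / 2).squeeze(-1)                      # [num_graphs]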
Example #11
File: model.py  Project: waljan/GNNpT1
class GIN(torch.nn.Module):
    def __init__(self, num_layers, num_input_features, hidden):
        super(GIN, self).__init__()
        self.conv1 = GINConv(
            torch.nn.Sequential(Linear(num_input_features, hidden),
                                torch.nn.ReLU(), Linear(hidden, hidden)))
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(
                GINConv(
                    torch.nn.Sequential(Linear(hidden,
                                               hidden), torch.nn.ReLU(),
                                        Linear(hidden, hidden))))

        self.lin1 = torch.nn.Linear(3 * hidden, hidden)
        self.lin2 = torch.nn.Linear(hidden, 2)

    def reset_parameters(self):  # reset all conv and linear layers
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()  # reset_parameters() is a method of the torch_geometric.nn.GINConv class
        self.lin1.reset_parameters()  # reset_parameters() is a method of the torch.nn.Linear class
        self.lin2.reset_parameters()

    def forward(self, data):
        # data: Batch(batch=[num_nodes_in_batch],
        #               edge_attr=[2*num_nodes_in_batch,num_edge_features_per_edge],
        #               edge_index=[2,2*num_nodes_in_batch],
        #               pos=[num_nodes_in_batch,2],
        #               x=[num_nodes_in_batch, num_input_features_per_node],
        #               y=[num_graphs_in_batch, num_classes])
        # example: Batch(batch=[2490], edge_attr=[4980,1], edge_index=[2,4980], pos=[2490,2], x=[2490,33], y=[32,2])

        x, edge_index, batch = data.x, data.edge_index, data.batch
        # x.shape: torch.Size([num_nodes_in_batch, num_input_features_per_node])
        # edge_index.shape: torch.Size([2, 2*num_nodes_in_batch])
        # batch.shape: torch.Size([num_nodes_in_batch])
        # example:  x.shape = torch.Size([2490,33])
        #           edge_index.shape = torch.Size([2,4980])
        #           batch.shape = torch.Size([2490])

        # graph convolutions and relu activation
        x = F.relu(self.conv1(x, edge_index))
        # x.shape:  torch.Size([num_nodes_in_batch, hidden])
        # example:  x.shape = torch.Size([2490, 66])

        for conv in self.convs:
            x = F.relu(conv(x, edge_index))

        # x.shape:  torch.Size([num_nodes_in_batch, hidden])
        # example:  x.shape = torch.Size([2490, 66])

        x = torch.cat([
            global_add_pool(x, batch),
            global_mean_pool(x, batch),
            global_max_pool(x, batch)
        ], dim=1)
        # x.shape:  torch.Size([num_graphs_in_batch, 3*hidden])
        # example:  x.shape = torch.Size([32, 3*66])

        # linear layers, activation function, dropout
        x = F.relu(self.lin1(x))
        # x.shape:  torch.Size([num_graphs_in_batch, hidden])
        # example:  x.shape = torch.Size([32, 66])
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        # x.shape:  torch.Size([num_graphs_in_batch, num_classes])
        # example:  x.shape = torch.Size([32, 2])

        output = F.log_softmax(x, dim=-1)

        return output

    def __repr__(self):
        # returns a printable representation of the object
        return self.__class__.__name__
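A hedged end-to-end training sketch (MUTAG is a stand-in dataset; since lin2 hardcodes 2 output classes, the dataset must be binary):

import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')  # binary graph classification
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = GIN(num_layers=3, num_input_features=dataset.num_features, hidden=66)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for data in loader:
    optimizer.zero_grad()
    out = model(data)               # [num_graphs_in_batch, 2] log-probabilities
    loss = F.nll_loss(out, data.y)  # MUTAG stores integer class labels in data.y
    loss.backward()
    optimizer.step()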