Example #1
    def forward(self, data, conv_train=False):
        """Run the three-stage GNN and, unless ``conv_train``, the SOM-based head.

        :param data: batch object; this method reads ``data.x``,
            ``data.edge_index`` and ``data.batch``.
        :param conv_train: when True, compute only the GNN-only output and skip
            the SOM stage entirely.
        :return: ``(h, h_conv, gnn_out)`` where ``h`` is the SOM-head output,
            ``h_conv`` the concatenation of the three conv-stage outputs and
            ``gnn_out`` the GNN-only output; ``(None, None, gnn_out)`` when
            ``conv_train`` is True.
        """
        x = data.x
        edge_index = data.edge_index

        # Three conv -> activation -> norm stages. Dropout is applied only to the
        # input of the next stage, so x1/x2/x3 keep the un-dropped stage outputs.
        x1 = self.norm1(self.act1(self.conv1(x, edge_index)))
        x = self.dropout(x1)

        x2 = self.norm2(self.act2(self.conv2(x, edge_index)))
        x = self.dropout(x2)

        x3 = self.norm3(self.act3(self.conv3(x, edge_index)))

        h_conv = torch.cat([x1, x2, x3], dim=1)

        # GNN-only output: pool h_conv per graph (data.batch) with gap/gadd/gmp
        # and project through self.lin_GNN and self.out_fun.
        conv_batch_avg = gap(h_conv, data.batch)
        conv_batch_add = gadd(h_conv, data.batch)
        conv_batch_max = gmp(h_conv, data.batch)

        # Concatenate along dim=1, matching the SOM-branch pooling below.
        h_GNN = torch.cat([conv_batch_avg, conv_batch_add, conv_batch_max],
                          dim=1)

        gnn_out = self.out_fun(self.lin_GNN(h_GNN))

        if conv_train:
            # Conv-only training: the SOM head is not evaluated.
            return None, None, gnn_out

        # Each SOM returns a 3-tuple; only its last element is used here.
        _, _, som_out_1 = self.som1(x1)
        _, _, som_out_2 = self.som2(x2)
        _, _, som_out_3 = self.som3(x3)

        h1 = self.out_norm1(self.act1(self.out_conv1(som_out_1, edge_index)))
        h2 = self.out_norm2(self.act2(self.out_conv2(som_out_2, edge_index)))
        h3 = self.out_norm3(self.act3(self.out_conv3(som_out_3, edge_index)))

        som_out_conv = torch.cat([h1, h2, h3], dim=1)

        som_batch_avg = gap(som_out_conv, data.batch)
        som_batch_add = gadd(som_out_conv, data.batch)
        som_batch_max = gmp(som_out_conv, data.batch)

        h = torch.cat([som_batch_avg, som_batch_add, som_batch_max], dim=1)

        h = self.out_norm4(h)

        # Two hidden layers (linear -> self.out_act -> dropout), then the final
        # projection through self.out_fun.
        h = self.out_act(self.lin_out1(h))
        h = self.dropout(h)

        h = self.out_act(self.lin_out2(h))
        h = self.dropout(h)

        h = self.out_fun(self.lin_out3(h))

        return h, h_conv, gnn_out
Example #2
File: MRGNN.py Project: lpasa/MRGNN
    def forward(self, data, hidden_layer_aggregator=None):
        """Model forward pass over Laplacian-power node features.

        :param data: batch object; this method reads ``data.x``,
            ``data.edge_index`` and ``data.batch``.
        :param hidden_layer_aggregator: not referenced in this method's body.
        :return: the output of ``self.funnel_output``, ``self.one_layer_out`` or
            ``self.restricted_funnel_output``, chosen by ``self.output``
            (``None`` selects the funnel head); any other value currently
            falls through and returns None.
        """
        X = data.x
        k = self.max_k

        #compute Laplacian
        # NOTE(review): the two statements below are truncated in this copy --
        # their remaining arguments and closing parentheses are missing, so the
        # method does not parse as shown. The lost arguments (e.g. any Laplacian
        # normalization and the sparse tensor size) cannot be recovered from here.
        L_edge_index, L_values = get_laplacian(data.edge_index,
        L = torch.sparse.FloatTensor(L_edge_index, L_values,

        H = [X]
        # Presumably meant to collect L^1 X, ..., L^(k-1) X after X.
        for i in range(k - 1):
            xhi_layer_i = torch.mm(torch.matrix_power(L, i + 1), X)
            # NOTE(review): xhi_layer_i is never stored, so H stays [X]; an
            # `H.append(xhi_layer_i)` appears to be missing here.

        # Project the concatenated features with self.lin (given
        # self.xhi_layer_mask), then apply self.reservoir_act_fun and
        # self.bn_hidden_rec.
        H = self.lin(torch.cat(H, dim=1), self.xhi_layer_mask)
        H = self.reservoir_act_fun(H)
        H = self.bn_hidden_rec(H)

        # Graph-level representation: pool per graph (data.batch) with three
        # strategies (gap/gadd/gmp -- presumably mean/sum/max pooling; confirm
        # against the file's imports) and concatenate along dim 1.
        H_avg = gap(H, data.batch)
        H_add = gadd(H, data.batch)
        H_max = gmp(H, data.batch)

        H = torch.cat([H_avg, H_add, H_max], dim=1)

        # Select the output head from self.output.
        if self.output == "funnel" or self.output is None:
            return self.funnel_output(H)
        elif self.output == "one_layer":
            return self.one_layer_out(H)
        elif self.output == "restricted_funnel":
            return self.restricted_funnel_output(H)
            # NOTE(review): this assert follows a return and is unreachable, so
            # an unrecognized self.output silently returns None; an `else:`
            # line appears to have been lost above it.
            assert False, "error in output stage"
Example #3
    def forward(self, data):
        # Model forward pass over adjacency-power node features.
        # NOTE(review): the next three lines look like a docstring whose triple
        # quotes were lost; as written they are not valid Python.
        model forward method
        :param data: current batch
        :return: the model output given the batch
        X = data.x

        edge_index = data.edge_index
        # Initial node embedding: self.conv0 followed by self.norm0.
        X = self.norm0(self.conv0(X, edge_index))

        k = self.k

        # compute adjacency matrix A
        # Every edge in data.edge_index gets value 1.0 (moved to self.device);
        # A_shape is [X.shape[0], X.shape[0]].
        adjacency_indexes = data.edge_index
        A_rows = adjacency_indexes[0]
        A_data = [1] * A_rows.shape[0]
        v_index = torch.FloatTensor(A_data).to(self.device)
        A_shape = [X.shape[0], X.shape[0]]
        # NOTE(review): this call is truncated in this copy (remaining arguments
        # and closing parenthesis are missing). A_shape, otherwise unused here,
        # is presumably the size argument -- the exact text cannot be recovered.
        A = torch.sparse.FloatTensor(adjacency_indexes, v_index,

        H = [X]

        # compute the parts of matrix H for each k
        for i in range(k - 1):
            xhi_layer_i = torch.mm(torch.matrix_power(A, i + 1), X)
            # NOTE(review): xhi_layer_i is never stored, so H stays [X]; an
            # `H.append(xhi_layer_i)` appears to be missing here.
        # project H by W
        # (self.lin with self.xhi_layer_mask, followed by self.bn_hidden_rec)
        H = self.bn_hidden_rec(
            self.lin(torch.cat(H, dim=1), self.xhi_layer_mask))

        # compute the graph layer representation using 3 different pooling strategies
        H_avg = gap(H, data.batch)
        H_add = gadd(H, data.batch)
        H_max = gmp(H, data.batch)
        H = torch.cat([H_avg, H_add, H_max], dim=1)

        #compute the readout
        # self.output None selects the funnel head.
        if self.output == "funnel" or self.output is None:
            return self.funnel_output(H)
        elif self.output == "restricted_funnel":
            return self.restricted_funnel_output(H)
            # NOTE(review): this assert follows a return and is unreachable, so
            # an unrecognized self.output silently returns None; an `else:`
            # line appears to have been lost above it.
            assert False, "error in output stage"
Example #4
File: MRGNN.py Project: lpasa/MRGNN
    def readout_fw(self, data):
        """Run the readout stage on a reservoir already stored on the batch.

        Applies ``self.reservoir_act_fun`` and ``self.bn_hidden_rec`` to
        ``data.reservoir``, pools the result per graph (``data.batch``) with
        ``gap``, ``gadd`` and ``gmp``, concatenates the three pooled tensors
        along dim 1 and feeds them to the output head selected by
        ``self.output``.

        :param data: batch object; reads ``data.reservoir`` and ``data.batch``.
        :return: the selected output head's result.
        :raises ValueError: if ``self.output`` is not ``None``, ``"funnel"``,
            ``"one_layer"``, ``"restricted_funnel"`` or ``"svm"``.
        """
        # Resolve the output head up front so an unsupported mode raises before
        # any tensor work, instead of falling through and returning None.
        if self.output == "funnel" or self.output is None:
            output_head = self.funnel_output
        elif self.output == "one_layer":
            output_head = self.one_layer_out
        elif self.output == "restricted_funnel":
            output_head = self.restricted_funnel_output
        elif self.output == "svm":
            output_head = self.svm_output
        else:
            raise ValueError(
                f"error in output stage: unsupported output {self.output!r}")

        H = self.reservoir_act_fun(data.reservoir)
        H = self.bn_hidden_rec(H)
        H = torch.cat(
            [gap(H, data.batch), gadd(H, data.batch), gmp(H, data.batch)],
            dim=1)

        return output_head(H)