Example #1
    def forward(self, data):
        x, edge_index, batch = data.node_feature, data.edge_index, data.batch
        x = self.pre_mp(x)
        num_nodes = x.size(0)

        # [num nodes x current num layer x hidden_dim]
        all_emb = x.unsqueeze(1)
        # [num nodes x (curr num layer * hidden_dim)]
        emb = x
        for i in range(len(self.convs)):
            if self.args.skip == 'learnable':
                # gate all previous layer embeddings with learnable skip weights
                skip_vals = self.learnable_skip[i, :i + 1].unsqueeze(0).unsqueeze(-1)
                curr_emb = all_emb * torch.sigmoid(skip_vals)
                curr_emb = curr_emb.view(num_nodes, -1)
                x = self.convs[i](curr_emb, edge_index)
            elif self.args.skip == 'all':
                x = self.convs[i](emb, edge_index)
            else:
                x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            emb = torch.cat((emb, x), 1)
            if self.args.skip == 'learnable':
                all_emb = torch.cat((all_emb, x.unsqueeze(1)), 1)

        # x = pyg_nn.global_mean_pool(x, batch)
        emb = pyg_nn.global_add_pool(emb, batch)
        emb = self.post_mp(emb)
        out = F.log_softmax(emb, dim=1)
        return out
Example #2
    def forward(self, data):
        x = data.x

        # Compute graph convolutional part
        if self.net_type != 'gmmcn':
            for gcn_layer in self.gcn:
                x = F.relu(gcn_layer(x, data.edge_index))
        else:
            for gcn_layer in self.gcn:
                x = F.relu(
                    gcn_layer(x.float(), data.edge_index.long(),
                              data.pseudo.float()))

        # Apply global sum pooling and dropout
        x = global_add_pool(x, data.batch)
        x = self.drop(x)
        embedding = x

        # Compute fully-connected part
        if self.fc_dim > 0:
            x = F.relu(self.fc(x))

        output = self.fc_out(x)  # sigmoid in loss function

        return embedding, output
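Since Example #2 returns raw logits (the in-code comment notes the sigmoid lives in the loss), a typical pairing, assuming binary targets y, is a sketch like:

import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
# embedding, output = model(data)       # forward pass from Example #2
# loss = criterion(output, y.float())   # y is an assumed binary target tensor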
Example #3
def test_permuted_global_pool():
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.cat([torch.zeros(N_1), torch.ones(N_2)]).to(torch.long)
    perm = torch.randperm(N_1 + N_2)

    px = x[perm]
    pbatch = batch[perm]
    px1 = px[pbatch == 0]
    px2 = px[pbatch == 1]

    out = global_add_pool(px, pbatch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], px1.sum(dim=0))
    assert torch.allclose(out[1], px2.sum(dim=0))

    out = global_mean_pool(px, pbatch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], px1.mean(dim=0))
    assert torch.allclose(out[1], px2.mean(dim=0))

    out = global_max_pool(px, pbatch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], px1.max(dim=0)[0])
    assert torch.allclose(out[1], px2.max(dim=0)[0])
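For reference, a minimal sketch (not taken from any example on this page) of what global_add_pool computes: per-graph feature sums selected by the batch vector, reproduced here with index_add_.

import torch
from torch_geometric.nn import global_add_pool

x = torch.randn(10, 4)                    # 10 nodes with 4 features each
batch = torch.tensor([0] * 4 + [1] * 6)   # node-to-graph assignment vector
num_graphs = int(batch.max()) + 1

manual = torch.zeros(num_graphs, x.size(1)).index_add_(0, batch, x)
assert torch.allclose(global_add_pool(x, batch), manual)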
Example #4
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x1 = self.conv1(x, edge_index)

        y_molecules = global_add_pool(x1, batch)
        z_molecules = self.gather_layer(y_molecules)
        return z_molecules
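Example #4 shows only the forward pass. A self-contained sketch of such a model, assuming GCNConv for conv1 and a plain Linear for gather_layer (the original constructor is not shown, and the class name is made up here):

import torch.nn as nn
from torch_geometric.nn import GCNConv, global_add_pool

class MoleculeRegressor(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)          # assumed layer type
        self.gather_layer = nn.Linear(hidden_dim, out_dim)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x1 = self.conv1(x, edge_index)
        y_molecules = global_add_pool(x1, batch)          # one row per molecule
        return self.gather_layer(y_molecules)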
Example #5
 def forward(self, data):
     x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
     print('x ', x)
     print('edge_index ', edge_index)
     print('edge_attr ', edge_attr)
     x = F.relu(self.nnconv1(x, edge_index, edge_attr))
     print('x1 ', x.shape)
     x = self.bn1(x)
     print('x2 ', x.shape)
     x = F.relu(self.nnconv2(x, edge_index))
     print('x3 ', x.shape)
     x = self.bn2(x)
     print('x4 ', x.shape)
     x = global_add_pool(x, data.batch)
     print('x5 ', x.shape)
     x = F.relu(self.fc1(x))
     print('x6 ', x.shape)
     #  x = self.bn3(x)
     #   print('x7 ', x.shape)
     x = F.relu(self.fc2(x))
     print('x8 ', x.shape)
     x = F.dropout(x, p=0.2, training=self.training)
     print('x9 ', x.shape)
     x = self.fc3(x)
     print('x10 ', x.shape)
     x = F.log_softmax(x, dim=1)
     print('x11 ', x.shape)
     return x
Example #6
    def forward(self, batched_data):
        x, edge_index, node_depth, batch = batched_data.x, batched_data.edge_index,  batched_data.node_depth, batched_data.batch

        x = self.node_encoder(x, node_depth.view(-1, ))
        node_states_per_layer = []  # one entry per layer (final state of that layer), shape: number of nodes in batch v x D
        node_states_per_layer.append(x)

        for layer_idx, num_timesteps in enumerate(self.layer_timesteps):

            # Extract residual messages, if any:
            layer_residual_connections = self.residual_connections.get(str(layer_idx))
            layer_residual_states = [] if layer_residual_connections is None else \
                [node_states_per_layer[residual_layer_idx]
                                         for residual_layer_idx in layer_residual_connections]

            # Record new states for this layer. Initialised to last state, but will be updated below:
            node_states_layer = self.convs[layer_idx](node_states_per_layer[-1], edge_index, batched_data.edge_attr, layer_residual_states)
            node_states_per_layer.append(node_states_layer)

        hx = torch.cat([node_states_per_layer[-1], x], dim=-1)
        x = self.classifier_l(hx) * self.classifier_r(hx)
        output = global_add_pool(x, batch=batch)

        pred_list = []
        for i in range(self.max_seq_len):
            pred_list.append(self.graph_pred_linear_list[i](output))

        return pred_list
Example #7
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        target = data.target

        x = F.relu(self.conv1(x, edge_index))
        x = self.bn1(x)
        x = F.relu(self.conv2(x, edge_index))
        x = self.bn2(x)
        x = F.relu(self.conv3(x, edge_index))
        x = self.bn3(x)
        x = global_add_pool(x, batch)
        x = F.relu(self.fc1_xd(x))
        x = F.dropout(x, p=0.2, training=self.training)

        embedded_xt = self.embedding_xt(target)
        # flatten
        xt = embedded_xt.view(-1, 1000 * 128)
        xt = self.fc1_xt(xt)

        # concat
        xc = torch.cat((x, xt), 1)
        # add some dense layers
        xc = self.fc1(xc)
        xc = self.relu(xc)
        xc = self.dropout(xc)
        xc = self.fc2(xc)
        xc = self.relu(xc)
        xc = self.dropout(xc)
        out = self.out(xc)
        return out
Example #8
    def forward(self, z, edge_index, batch, x=None, edge_weight=None, node_id=None):
        z_emb = self.z_embedding(z)
        if z_emb.ndim == 3:  # in case z has multiple integer labels
            z_emb = z_emb.sum(dim=1)
        if self.use_feature and x is not None:
            x = torch.cat([z_emb, x.to(torch.float)], 1)
        else:
            x = z_emb
        if self.node_embedding is not None and node_id is not None:
            n_emb = self.node_embedding(node_id)
            x = torch.cat([x, n_emb], 1)
        for conv in self.convs[:-1]:
            x = conv(x, edge_index, edge_weight)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index, edge_weight)
        if True:  # center pooling
            _, center_indices = np.unique(batch.cpu().numpy(), return_index=True)
            x_src = x[center_indices]
            x_dst = x[center_indices + 1]
            x = (x_src * x_dst)
            x = F.relu(self.lin1(x))
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.lin2(x)
        else:  # sum pooling
            x = global_add_pool(x, batch)
            x = F.relu(self.lin1(x))
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.lin2(x)

        return x
Example #9
    def forward(self, batch):
        x, edge_index, batch_ids = batch.x, batch.edge_index, batch.batch

        out = None
        for _ in range(self.num_perm):
            new_x = torch.empty(x.size(0),
                                x.size(1) + self.fixed_size).to(x.device)
            for graph in range(torch.max(batch_ids).item() + 1):
                node_indices = (batch_ids == graph).nonzero().squeeze(1)

                graph_size = node_indices.size(0)
                perm = torch.randperm(graph_size)

                node_ids = self.__getattr__(f"node_ids").repeat(
                    graph_size // self.fixed_size + 1, 1)[:graph_size]
                permuted_node_ids = node_ids.to(x.device)[perm, :]
                new_x[node_indices] = torch.cat(
                    [x[node_indices], permuted_node_ids], dim=1)

            h_v = self.node_embedder.forward(new_x, edge_index)
            h_g = global_add_pool(h_v, batch_ids)

            if out is None:
                out = h_g / self.num_perm
            else:
                out += h_g / self.num_perm
        return out
Example #10
    def forward(self, x, edge_index, batch, pretr=False):
        x1 = F.relu(self.conv1(x, edge_index))
        #x1 = F.dropout(x1, training=self.training)
        x2 = self.conv2(x1, edge_index)
        #x2 = F.dropout(x2, training=self.training)
        x3 = self.conv3(x2, edge_index)
        #return F.log_softmax(x, dim=1)
        x = torch.cat([x1, x2, x3], dim=1)
        x = F.relu(self.fc1(x))

        x = global_add_pool(x, batch)
        #print(x.shape)
        #x = F.relu(self.fc1a(x))
        #x = F.dropout(x, p=0.2, training=self.training)
        #x = self.fc2(x)
        if pretr:
            out1 = self.fc2(x)
        #else:
        #x= self.fc2(x)

        x = self.fc3(x)
        x = F.log_softmax(x, dim=-1)
        if pretr:
            return out1, x  #F.log_softmax(x, dim=-1)
        else:
            return x
Example #11
    def forward(self, x, edge_index, batch, pretr=False):
        x = F.relu(self.conv1(x, edge_index))
        x = self.bn1(x)
        x = F.relu(self.conv2(x, edge_index))
        x = self.bn2(x)
        x = F.relu(self.conv3(x, edge_index))
        x = self.bn3(x)
        x = F.relu(self.conv4(x, edge_index))
        x = self.bn4(x)
        x = F.relu(self.conv5(x, edge_index))
        x = self.bn5(x)
        x = global_add_pool(x, batch)
        x = F.relu(self.fc1(x))
        #x = global_add_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        #x = self.fc3(x)

        if pretr:
            x = self.fc3(x)
        else:
            #x = F.dropout(x, p=0.5, training=self.training)
            x = self.fc2(x)
            x = F.log_softmax(x, dim=-1)

        return x
Example #12
    def forward(self, batched_data):
        x, edge_index, node_depth, batch = batched_data.x, batched_data.edge_index,  batched_data.node_depth, batched_data.batch

        x = self.node_encoder(x, node_depth.view(-1, ))
        node_states_per_layer = []  # one entry per layer (final state of that layer), shape: number of nodes in batch v x D
        node_states_per_layer.append(x)

        for layer_idx, num_timesteps in enumerate(self.layer_timesteps):

            # Record new states for this layer. Initialised to last state, but will be updated below:
            node_states_layer = self.convs[layer_idx](node_states_per_layer[-1], edge_index)
            node_states_per_layer.append(node_states_layer)

        hx = torch.cat([node_states_per_layer[-1], x], dim=-1)
        x = self.classifier_l(hx) * self.classifier_r(hx)
        output = global_add_pool(x, batch=batch)

        if self.num_class > 0:
            return self.graph_pred_linear(output)

        pred_list = []
        for i in range(self.max_seq_len):
            pred_list.append(self.graph_pred_linear_list[i](output))

        return pred_list
Example #13
def pool_func(x, batch, mode="sum"):
    if mode == "sum":
        return global_add_pool(x, batch)
    elif mode == "mean":
        return global_mean_pool(x, batch)
    elif mode == "max":
        return global_max_pool(x, batch)
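A hypothetical call to pool_func on a toy batch of two graphs (3 and 2 nodes), assuming the global_*_pool functions above are imported from torch_geometric.nn:

import torch

x = torch.randn(5, 8)                      # 5 nodes with 8 features
batch = torch.tensor([0, 0, 0, 1, 1])      # graph 0 has 3 nodes, graph 1 has 2
summed = pool_func(x, batch, mode="sum")   # shape [2, 8]
averaged = pool_func(x, batch, mode="mean")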
Example #14
    def forward(self, data):

        x, pos, batch, u = data.x, data.pos, data.batch, data.u

        # Get edges using positions by computing the kNNs or the neighbors within a radius
        #edge_index = knn_graph(pos, k=self.k_nn, batch=batch, loop=self.loop)
        edge_index = radius_graph(pos,
                                  r=self.k_nn,
                                  batch=batch,
                                  loop=self.loop)

        # Start message passing
        for layer in self.layers:
            if self.namemodel == "DeepSet":
                x = layer(x)
            elif self.namemodel == "PointNet":
                x = layer(x=x, pos=pos, edge_index=edge_index)
            elif self.namemodel == "MetaNet":
                x, dumb, u = layer(x, edge_index, None, u, batch)
            else:
                x = layer(x=x, edge_index=edge_index)
            self.h = x
            x = x.relu()

        # Mix different global pooling layers
        addpool = global_add_pool(x, batch)  # [num_examples, hidden_channels]
        meanpool = global_mean_pool(x, batch)
        maxpool = global_max_pool(x, batch)
        #self.pooled = torch.cat([addpool, meanpool, maxpool], dim=1)
        self.pooled = torch.cat([addpool, meanpool, maxpool, u], dim=1)

        # Final linear layer
        return self.lin(self.pooled)
Example #15
    def forward(self, data):
        x1 = F.relu(self.conv1(data.x, data.edge_index))
        x1 = self.bn1(x1)
        # x1_g = global_add_pool(x1, data.batch)

        x2 = F.relu(self.conv2(x1, data.edge_index))
        x2 = self.bn2(x2)
        # x2_g = global_add_pool(x2, data.batch)

        x3 = F.relu(self.conv3(x2, data.edge_index))
        x3 = self.bn3(x3)
        # x3_g = global_add_pool(x3, data.batch)

        x4 = F.relu(self.conv4(x3, data.edge_index))
        x4 = self.bn4(x4)
        # x4_g = global_add_pool(x4, data.batch)

        x5 = F.relu(self.conv5(x4, data.edge_index))
        x5 = self.bn5(x5)
        x5_g = global_add_pool(x5, data.batch)

        # x = torch.cat([x1_g, x2_g, x3_g, x4_g, x5_g], dim=-1)
        x = F.relu(self.fc1(x5_g))
        x = self.fc2(x)
        return x.view(-1)
Example #16
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        target = data.target
        u = F.relu(self.conv1(x, edge_index))
        u = self.bn1(u)
        u = F.relu(self.conv2(u, edge_index))
        u = self.bn2(u)
        u = F.relu(self.conv3(u, edge_index))
        u = self.bn3(u)
        u = F.relu(self.conv4(u, edge_index))
        u = self.bn4(u)
        u = F.relu(self.conv5(u, edge_index))
        u = self.bn5(u)
        u = global_add_pool(u, batch)
        u = F.relu(self.fc1_xd(u))
        u = F.dropout(u, p=0.2, training=self.training)

        embedded_xt = self.embedding_xt(target)
        conv_xt = self.conv_xt_1(embedded_xt)
        xt = conv_xt.view(-1, 32 * 121)
        xt = self.fc1_xt(xt)

        xc = torch.cat((u, xt), 1)
        xc = self.fc1(xc)
        xc = self.relu(xc)
        xc = self.dropout(xc)
        xc = self.fc2(xc)
        xc = self.relu(xc)
        xc = self.dropout(xc)
        out = self.out(xc)
        return out
Example #17
    def forward(self, batched_data):

        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch

        ### virtual node embeddings for graphs
        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(
                edge_index.device))

        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):
            ### add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]

            ### Message passing among graph nodes
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)

            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h),
                              self.drop_ratio,
                              training=self.training)

            if self.residual:
                h = h + h_list[layer]

            h_list.append(h)

            ### update the virtual nodes
            if layer < self.num_layer - 1:
                ### add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = global_add_pool(
                    h_list[layer], batch) + virtualnode_embedding
                ### transform virtual nodes using MLP

                if self.residual:
                    virtualnode_embedding = virtualnode_embedding + F.dropout(
                        self.mlp_virtualnode_list[layer]
                        (virtualnode_embedding_temp),
                        self.drop_ratio,
                        training=self.training)
                else:
                    virtualnode_embedding = F.dropout(
                        self.mlp_virtualnode_list[layer](
                            virtualnode_embedding_temp),
                        self.drop_ratio,
                        training=self.training)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layer):
                node_representation += h_list[layer]

        return node_representation
Example #18
    def forward(self, x, edge_index, edge_attr, batch):
        """"""
        # Atom Embedding:
        x = F.leaky_relu_(self.lin1(x))

        h = F.elu_(self.atom_convs[0](x, edge_index, edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        x = self.atom_grus[0](h, x).relu_()

        for conv, gru in zip(self.atom_convs[1:], self.atom_grus[1:]):
            h = F.elu_(conv(x, edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = gru(h, x).relu_()

        # Molecule Embedding:
        row = torch.arange(batch.size(0), device=batch.device)
        edge_index = torch.stack([row, batch], dim=0)

        out = global_add_pool(x, batch).relu_()
        for t in range(self.num_timesteps):
            h = F.elu_(self.mol_conv((x, out), edge_index))
            h = F.dropout(h, p=self.dropout, training=self.training)
            out = self.mol_gru(h, out).relu_()

        # Predictor:
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.lin2(out)
Example #19
    def get_distribution_parameters(self, node_embeddings, batch):

        if self.aggregate is not None:
            graph_embeddings = self.aggregate(
                self.node_transform(node_embeddings), batch)
            out = self.output_activation(
                self.final_transform(graph_embeddings))
        else:
            out = self.output_activation(
                self.final_transform(self.node_transform(node_embeddings)))

        if 'binomial' in self.output_type:
            params = torch.reshape(
                out, [-1, self.no_experts, 2])  # ? x no_experts x K
            # first parameter not used here
            _, p = torch.round(torch.relu(params[:, :, 0])) + 1, torch.sigmoid(
                params[:, :, 1])
            n = global_add_pool(
                torch.ones(node_embeddings.shape[0],
                           self.no_experts).to(node_embeddings.device), batch)

            distr_params = (n, p)
        elif 'gaussian' in self.output_type:
            # Assume isotropic gaussians
            params = torch.reshape(out,
                                   [-1, self.no_experts, 2, self.dim_target
                                    ])  # ? x no_experts x 2 x F
            mu, var = params[:, :, 0, :], params[:, :, 1, :]

            var = torch.nn.functional.softplus(var) + 1e-8
            # F is assumed to be 1 for now, add dimension to F

            distr_params = (mu, var)  # each has shape ? x no_experts X F

        return distr_params
Example #20
 def forward(self, x, edge_index, batch):
     x = self.conv1(x, edge_index)
     x = F.relu(x)
     x = self.conv2(x, edge_index)
     x = global_add_pool(x, batch)
     x = self.lin(x)
     return x.log_softmax(dim=1)
Example #21
    def forward(self, data, batch_size=None, **kwargs):
        x = data.x
        batch = data.batch
        edge_index = data.edge_index
        pos = data.pos

        # infer real batch size, in case empty sample
        if batch_size is None:
            batch_size = data['size'].sum().item()

        img_feature = self.encoder(x).flatten(1)
        x = torch.cat([img_feature, pos], dim=1)

        x = self.gnn(x=x, edge_index=edge_index)
        x = self.encoder2(x)

        if self.global_aggr == 'max':
            global_feature = gnn.global_max_pool(x, batch, size=batch_size)
        elif self.global_aggr == 'sum':
            global_feature = gnn.global_add_pool(x, batch, size=batch_size)
        else:
            raise NotImplementedError()

        logits = self.fc(global_feature)

        out_dict = {
            'logits': logits,
        }
        return out_dict
Example #22
    def forward(
        self, data: 'torch_geometric.data.Data'
    ) -> Tuple['torch.tensor', 'torch.tensor', 'torch.tensor']:
        """
        torch.nn.module forward operation

        Args:
            data (torch_geometric.data.Data): data to be fed forward; must have
                node attributes, edge attributes, edge index defined

        Returns:
            Tuple[torch.tensor, torch.tensor, torch.tensor]: (GCN output, node
                embeddings, edge embeddings)
        """

        # Get batch
        x, edge_attr, edge_index, batch = data.x, data.edge_attr,\
            data.edge_index, data.batch
        row, col = edge_index
        if data.num_node_features == 0:
            x = torch.ones(data.num_nodes, 1)

        out = F.relu(self.lin0(x))
        out_edge = F.relu(self.lin0_edge(edge_attr))
        h = out.unsqueeze(0)
        h_edge = out_edge.unsqueeze(0)

        # Feed forward, node and edge messages
        for i in range(self._n_messages):
            m = F.relu(self.node_conv(out, edge_index))
            emb_node = m
            m = F.dropout(m, p=self._dropout, training=self.training)
            out, h = self.node_gru(m.unsqueeze(0), h)
            out = out.squeeze(0)

            m_edge = F.relu(self.edge_conv(out_edge, edge_index))
            emb_edge = m_edge
            m_edge = F.dropout(m_edge, p=self._dropout, training=self.training)
            out_edge, h_edge = self.edge_gru(m_edge.unsqueeze(0), h_edge)
            out_edge = out_edge.squeeze(0)

        # Concatenate node network and edge network output tensors
        out = torch.cat([out[row], out_edge[col]], dim=1)

        # Perform scatter add, reshape to original node dimensionality
        out = scatter_add(out, col, dim=0, dim_size=x.size(0))

        # Perform summation over all nodes w.r.t. current batch
        out = pyg_nn.global_add_pool(out, batch)

        # Perform post-message passing feed forward operations
        for layer in self.post_conv[:-1]:
            out = layer(out)
            out = F.relu(out)
            out = F.dropout(out, p=self._dropout, training=self.training)
        out = self.post_conv[-1](out)

        # Return fed-forward data, node embedding, edge embedding
        return out, emb_node, emb_edge
Example #23
    def forward(self, data):
        #if data.x is None:
        #    data.x = torch.ones((data.num_nodes, 1), device=utils.get_device())

        #x = self.pre_mp(x)
        if self.feat_preprocess is not None:
            if not hasattr(data, "preprocessed"):
                data = self.feat_preprocess(data)
                data.preprocessed = True
        if 'aifb' == 'aifb' or 'wn18' == 'wn18':  # hard-coded dataset check: condition is always True, so the edge-typed branch is always taken
            x, edge_index, batch, edge_type = data.node_feature, data.edge_index, data.batch, data.edge_feature
            edge_type = edge_type.reshape(-1)
            x = self.pre_mp(x)
        else:
            x, edge_index, batch = data.node_feature, data.edge_index, data.batch
            x = self.pre_mp(x)

        all_emb = x.unsqueeze(1)
        emb = x
        for i in range(
                len(self.convs_sum) if self.conv_type ==
                "PNA" else len(self.convs)):
            if self.skip == 'learnable':
                skip_vals = self.learnable_skip[i, :i +
                                                1].unsqueeze(0).unsqueeze(-1)
                curr_emb = all_emb * torch.sigmoid(skip_vals)
                curr_emb = curr_emb.view(x.size(0), -1)
                if self.conv_type == "PNA":
                    x = torch.cat((self.convs_sum[i](curr_emb, edge_index),
                                   self.convs_mean[i](curr_emb, edge_index),
                                   self.convs_max[i](curr_emb, edge_index)),
                                  dim=-1)
                elif self.conv_type == "RGCN":
                    # edge_type_ = torch.randint_like(edge_index, low=0, high=2)[1].detach().to(edge_index.device)
                    x = self.convs[i](curr_emb, edge_index, edge_type)
                else:
                    x = self.convs[i](curr_emb, edge_index)
            elif self.skip == 'all':
                if self.conv_type == "PNA":
                    x = torch.cat((self.convs_sum[i](
                        emb, edge_index), self.convs_mean[i](emb, edge_index),
                                   self.convs_max[i](emb, edge_index)),
                                  dim=-1)
                else:
                    x = self.convs[i](emb, edge_index)
            else:
                x = self.convs[i](x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            emb = torch.cat((emb, x), 1)
            if self.skip == 'learnable':
                all_emb = torch.cat((all_emb, x.unsqueeze(1)), 1)

        # x = pyg_nn.global_mean_pool(x, batch)
        emb = pyg_nn.global_add_pool(emb, batch)
        emb = self.post_mp(emb)
        #emb = self.batch_norm(emb)   # TODO: test
        #out = F.log_softmax(emb, dim=1)
        return emb
Example #24
    def forward(self, data):
        data.x = self.conv(data.x, data.edge_index)
        att_x, att_edge_index, att_edge_attr, att_batch, att_perm, att_scores = self.readout(
            data.x, data.edge_index, batch=data.batch)
        global_graph_emb = global_add_pool(att_x, att_batch)

        # data = max_pool_neighbor_x(data)
        return data, global_graph_emb
Example #25
    def forward(self, anchor_batch, negative_batch, positive_batch,
                anchor: Tensor, negative: Tensor, positive: Tensor,
                anchor_gt: Tensor, negative_gt: Tensor,
                positive_gt: Tensor) -> Tensor:
        anchor = global_add_pool(anchor, anchor_batch)

        positive = global_add_pool(positive, positive_batch)

        negative = global_add_pool(negative, negative_batch)

        pos_distance = torch.linalg.norm(positive - anchor, dim=1)
        negative_distance = torch.linalg.norm(negative - anchor, dim=1)

        coeff = torch.div(torch.abs(negative_gt - anchor_gt),
                          (torch.abs(positive_gt - anchor_gt) + self.eps))
        loss = F.relu((pos_distance - coeff * negative_distance) + self.margin)
        return torch.mean(loss)
Example #26
 def forward(self, data):
     subgraph_data = subgraph_loader( data, k, super_node_size, num_tours, num_cpus )
     subgraphs = [get_subgraph(data[subgraph_data.batch[i].item()], subgraph_data.subgraphs[i].squeeze()) for i in range(len(subgraph_data.subgraphs))]
     subgraphs_lst = []
     for i in range(0, len(subgraphs), 500):
         subgraphs_b =  Batch().from_data_list(subgraphs[i:i+min([500,len(subgraphs)-i])])
         subgraphs_b = self.gnn_layer(subgraphs_b.x.cuda(), subgraphs_b.edge_index.cuda(), subgraphs_b.batch.cuda()) \
         if next(self.parameters()).get_device() != -1 else self.gnn_layer(subgraphs_b.x, subgraphs_b.edge_index, subgraphs_b.batch)
         subgraphs_lst.append(subgraphs_b)
     subgraphs = torch.cat(subgraphs_lst,dim=0)
     subgraphs = self.output_layer(subgraphs)
     weights = subgraph_data.weights.cuda() if next(self.parameters()).get_device() != -1 else subgraph_data.weights
     batch = subgraph_data.batch.cuda() if next(self.parameters()).get_device() != -1 else subgraph_data.batch
     subgraphs = subgraphs*weights
     norm = global_add_pool(weights, batch)
     energy = global_add_pool(subgraphs, batch)
     return energy/norm
Example #27
 def forward(self, data):
     x, edge_index, batch = data.x, data.edge_index, data.batch
     x = F.relu(self.conv1(x, edge_index))
     xs = [global_add_pool(x, batch)]
     for i, conv in enumerate(self.convs):
         x = F.relu(conv(x, edge_index))
         xs += [global_add_pool(x, batch)]
         if i % 2 == 0 and i < len(self.convs) - 1:
             pool = self.pools[i // 2]
             x, edge_index, _, batch, _, _ = pool(x,
                                                  edge_index,
                                                  batch=batch)
     x = self.jump(xs)
     x = F.relu(self.lin1(x))
     x = F.dropout(x, p=0.5, training=self.training)
     x = self.lin2(x)
     return F.log_softmax(x, dim=-1)
Example #28
 def forward(self, x, edge_index, batch):
     for conv, batch_norm in zip(self.convs, self.batch_norms):
         x = F.relu(batch_norm(conv(x, edge_index)))
     x = global_add_pool(x, batch)
     x = F.relu(self.batch_norm1(self.lin1(x)))
     x = F.dropout(x, p=0.5, training=self.training)
     x = self.lin2(x)
     return F.log_softmax(x, dim=-1)
Example #29
def test_permuted_global_pool():
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.cat([torch.zeros(N_1), torch.ones(N_2)]).to(torch.long)
    perm = torch.randperm(N_1 + N_2)

    out_1 = global_add_pool(x, batch)
    out_2 = global_add_pool(x[perm], batch[perm])
    assert torch.allclose(out_1, out_2)

    out_1 = global_mean_pool(x, batch)
    out_2 = global_mean_pool(x[perm], batch[perm])
    assert torch.allclose(out_1, out_2)

    out_1 = global_max_pool(x, batch)
    out_2 = global_max_pool(x[perm], batch[perm])
    assert torch.allclose(out_1, out_2)
Example #30
    def forward(self, x, edge_index, edge_attr, batch):
        x = self.node_emb(x.squeeze())
        edge_attr = self.edge_emb(edge_attr)

        for conv, batch_norm in zip(self.convs, self.batch_norms):
            x = F.relu(batch_norm(conv(x, edge_index, edge_attr)))

        x = global_add_pool(x, batch)
        return self.mlp(x)
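Finally, a short end-to-end sketch (not from any example above) showing how node features of two graphs batched with Batch.from_data_list are pooled into one embedding per graph:

import torch
from torch_geometric.data import Data, Batch
from torch_geometric.nn import global_add_pool

g1 = Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2 = Data(x=torch.randn(2, 4), edge_index=torch.tensor([[0], [1]]))
batch = Batch.from_data_list([g1, g2])

pooled = global_add_pool(batch.x, batch.batch)  # shape [2, 4]: one row per graph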