def forward(self, mol_graph: dgl.DGLGraph, pairwise_indices: torch.Tensor):
        """Predict one distance for every indexed pair of nodes in ``mol_graph``.

        ``self.gnn`` is called on the graph first; its return value is ignored
        and the node embeddings are read back from ``mol_graph.ndata['feat']``.

        Args:
            mol_graph: graph handed to ``self.gnn``; afterwards it must hold the
                node embeddings under ``ndata['feat']``.
            pairwise_indices: indexed with ``[0]`` and ``[1]`` to obtain the row
                indices (into the embeddings) of the first and second member
                of each pair. Presumably both members of a pair belong to the
                same molecule of the batched graph -- confirm with the caller.

        Returns:
            The output of ``self.distance_net`` on the concatenated pair
            embeddings if a distance net is set; otherwise the ``torch.norm``
            of the embedding differences over the last dimension, kept as a
            trailing dimension of size 1.
        """
        self.gnn(mol_graph)

        # The down projection is only applied when distances are measured
        # directly in embedding space, i.e. there is no distance net.
        if self.node_projection_net and not self.distance_net:
            mol_graph.apply_nodes(self.node_projection)

        embeddings = mol_graph.ndata['feat']
        first = embeddings.index_select(0, pairwise_indices[0])
        second = embeddings.index_select(0, pairwise_indices[1])

        if self.distance_net:
            # Learned distance: the net sees both embeddings side by side.
            return self.distance_net(torch.cat((first, second), dim=1))

        # keepdim=True leaves the reduced axis in place as size 1.
        return torch.norm(first - second, dim=-1, keepdim=True)
Пример #2
0
    def forward(self, graph: dgl.DGLGraph):
        """Apply the input functions, every layer, then the output node function.

        All work happens through calls on ``graph``; the method itself returns
        ``None``.

        Args:
            graph: graph passed to ``apply_nodes`` / ``apply_edges`` and to each
                layer in ``self.mp_layers``. The edge input function is applied
                to the ``'bond'`` edge type only.
        """
        graph.apply_nodes(self.input_node_func)
        graph.apply_edges(self.input_edge_func, etype='bond')

        for layer in self.mp_layers:
            layer(graph)

        graph.apply_nodes(self.output_node_func)
Пример #3
0
    def forward(self, graph: dgl.DGLGraph):
        """Run the layers on ``graph`` and map a mean/max node readout through ``self.output``.

        Args:
            graph: graph passed to ``apply_nodes``, to each layer in
                ``self.mp_layers`` and to the DGL readout functions, which read
                the ``'feat'`` node field.

        Returns:
            ``self.output`` applied to the mean readout and the max readout of
            ``'feat'``, concatenated along the last dimension (mean first).
        """
        graph.apply_nodes(self.input_node_func)

        for layer in self.mp_layers:
            layer(graph)

        pooled = torch.cat(
            (dgl.mean_nodes(graph, 'feat'), dgl.max_nodes(graph, 'feat')),
            dim=-1,
        )
        return self.output(pooled)
Пример #4
0
 def forward(self, graph: dgl.DGLGraph):
     """Push the node field ``"f"`` through every layer with a skip connection.

     The node input function is applied first. Each layer in
     ``self.mp_layers`` is then called with the graph, the current node
     features and the edge field ``"w"``, and its input is added to its
     output. The final node features are written back to ``ndata["f"]`` and
     the (unmodified) edge features are re-assigned to ``edata["w"]``.
     Nothing is returned.

     Args:
         graph: graph providing ``ndata["f"]`` and ``edata["w"]``.
     """
     graph.apply_nodes(self.input_node_func)
     node_feats = graph.ndata["f"]
     edge_feats = graph.edata["w"]
     for layer in self.mp_layers:
         # Residual update: layer output plus the features it was given.
         node_feats = layer(graph, node_feats, edge_feats) + node_feats
     graph.ndata["f"] = node_feats
     graph.edata["w"] = edge_feats
Пример #5
0
    def forward(self, graph: dgl.DGLGraph):
        """Run the layers on ``graph`` and feed a sum/max node readout to ``self.output_network``.

        Args:
            graph: graph passed to ``apply_nodes``, to each layer in
                ``self.mp_layers`` and to the DGL readout functions, which read
                the ``'feat'`` node field.

        Returns:
            ``self.output_network`` applied to the sum readout and the max
            readout of ``'feat'``, concatenated along the last dimension
            (sum first).
        """
        graph.apply_nodes(self.input_node_func)

        for layer in self.mp_layers:
            layer(graph)

        graph.apply_nodes(self.output_node_func)

        pooled = torch.cat(
            [dgl.sum_nodes(graph, 'feat'), dgl.max_nodes(graph, 'feat')],
            dim=-1,
        )
        return self.output_network(pooled)
Пример #6
0
    def forward(self, graph: dgl.DGLGraph):
        """Initialise nodes and edges, run every layer, and return the ``'feat'`` node field.

        Args:
            graph: graph passed to ``apply_nodes`` / ``apply_edges`` and to each
                layer in ``self.mp_layers``. Its ``'d'`` edge field is replaced
                by ``fourier_encode_dist`` of itself when
                ``self.fourier_encodings > 0``.

        Returns:
            ``graph.ndata['feat']`` after all layers, and after the output node
            function when ``self.node_wise_output_layers > 0``.
        """
        graph.apply_nodes(self.input_node_func)

        if self.fourier_encodings > 0:
            raw_d = graph.edata['d']
            graph.edata['d'] = fourier_encode_dist(
                raw_d, num_encodings=self.fourier_encodings)
        graph.apply_edges(self.input_edge_func)

        for layer in self.mp_layers:
            layer(graph)

        # The node-wise output function is optional.
        if self.node_wise_output_layers > 0:
            graph.apply_nodes(self.output_node_func)

        return graph.ndata['feat']
Пример #7
0
    def forward(self, graph: dgl.DGLGraph):
        """Run the layers on ``graph`` and map the concatenated readouts through ``self.output``.

        Args:
            graph: graph passed to ``apply_nodes``, to each layer in
                ``self.mp_layers`` and to ``dgl.readout_nodes``. Its ``'d'``
                edge field is replaced by the squeezed result of
                ``fourier_encode_dist`` when ``self.fourier_encodings > 0``.

        Returns:
            ``self.output`` applied to the ``'feat'`` readouts, one per entry of
            ``self.readout_aggregators``, concatenated along the last dimension.
        """
        graph.apply_nodes(self.input_node_func)

        if self.fourier_encodings > 0:
            encoded = fourier_encode_dist(
                graph.edata['d'], num_encodings=self.fourier_encodings)
            # squeeze() drops every size-1 dimension of the encoding.
            graph.edata['d'] = encoded.squeeze()

        for layer in self.mp_layers:
            layer(graph)

        graph.apply_nodes(self.output_node_func)

        pooled = torch.cat(
            [dgl.readout_nodes(graph, 'feat', op=op)
             for op in self.readout_aggregators],
            dim=-1,
        )
        return self.output(pooled)
Пример #8
0
    def forward(self, graph: dgl.DGLGraph):
        """Run the layers on ``graph`` and map the concatenated readouts through ``self.output``.

        Args:
            graph: graph passed to ``apply_nodes``, to each layer in
                ``self.mp_layers`` and to ``dgl.readout_nodes``, which reads
                the ``'feat'`` node field.

        Returns:
            ``self.output`` applied to the ``'feat'`` readouts, one per entry of
            ``self.readout_aggregators``, concatenated along the last dimension.
        """
        graph.apply_nodes(self.input_node_func)

        for layer in self.mp_layers:
            layer(graph)

        graph.apply_nodes(self.output_node_func)

        # One readout per configured aggregator, in configuration order.
        per_aggregator = [
            dgl.readout_nodes(graph, 'feat', op=op)
            for op in self.readout_aggregators
        ]
        return self.output(torch.cat(per_aggregator, dim=-1))
    def forward(self, graph: dgl.DGLGraph):
        """Do one ``update_all`` pass over ``graph`` and map the readouts through ``self.output``.

        Args:
            graph: graph passed to ``apply_edges``, ``update_all`` and
                ``dgl.readout_nodes``. Its ``'d'`` edge field is replaced by
                ``fourier_encode_dist`` of itself when
                ``self.fourier_encodings > 0``.

        Returns:
            ``self.output`` applied to the ``'feat'`` readouts, one per entry of
            ``self.readout_aggregators``, concatenated along the last dimension.
        """
        if self.fourier_encodings > 0:
            raw_d = graph.edata['d']
            graph.edata['d'] = fourier_encode_dist(
                raw_d, num_encodings=self.fourier_encodings)
        graph.apply_edges(self.input_edge_func)

        # The reducer is built to read message field 'm' and write 'm_sum'.
        graph.update_all(
            message_func=self.message_function,
            reduce_func=self.reduce_func(msg='m', out='m_sum'),
        )

        # The node-wise output function is optional.
        if self.node_wise_output_layers > 0:
            graph.apply_nodes(self.output_node_func)

        per_aggregator = [
            dgl.readout_nodes(graph, 'feat', op=op)
            for op in self.readout_aggregators
        ]
        return self.output(torch.cat(per_aggregator, dim=-1))
Пример #10
0
    def forward(self, graph: dgl.DGLGraph):
        """Give every node the same embedding, run the layers, and map the readouts through ``self.output``.

        Args:
            graph: graph whose ``'feat'`` node field is overwritten with
                ``self.node_embedding`` expanded to one row per node. Its
                ``'d'`` edge field is replaced by ``fourier_encode_dist`` of
                itself when ``self.fourier_encodings > 0``.

        Returns:
            ``self.output`` applied to the ``'feat'`` readouts, one per entry of
            ``self.readout_aggregators``, concatenated along the last dimension.
        """
        # A leading axis is added to the embedding, which is then expanded to
        # the number of nodes in the graph.
        shared_row = self.node_embedding[None, :]
        graph.ndata['feat'] = shared_row.expand(graph.number_of_nodes(), -1)

        if self.fourier_encodings > 0:
            raw_d = graph.edata['d']
            graph.edata['d'] = fourier_encode_dist(
                raw_d, num_encodings=self.fourier_encodings)
        graph.apply_edges(self.input_edge_func)

        for layer in self.mp_layers:
            layer(graph)

        # The node-wise output function is optional.
        if self.node_wise_output_layers > 0:
            graph.apply_nodes(self.output_node_func)

        per_aggregator = [
            dgl.readout_nodes(graph, 'feat', op=op)
            for op in self.readout_aggregators
        ]
        return self.output(torch.cat(per_aggregator, dim=-1))
Пример #11
0
    def forward(self, graph: dgl.DGLGraph,
                feature: torch.Tensor) -> torch.Tensor:
        """Run one round of message passing over ``graph`` starting from ``feature``.

        Args:
            graph: graph on which ``update_all`` and ``apply_nodes`` are called.
            feature: value stored as the ``'h'`` node field before messages
                are sent.

        Returns:
            torch.Tensor: the ``'h'`` node field after the update; it is
            removed from ``graph.ndata`` before being returned.
        """
        # Seed the hidden state on the nodes.
        graph.ndata['h'] = feature

        # `msg` and `reduce` are defined outside this method. The original
        # notes describe `reduce` as averaging the neighbour features --
        # NOTE(review): confirm against their definitions.
        graph.update_all(msg, reduce)

        # Per-node transformation of the aggregated state.
        graph.apply_nodes(func=self.apply_mod)

        # pop() hands back the result and leaves no 'h' field on the graph.
        hidden = graph.ndata.pop('h')
        return hidden
Пример #12
0
    def forward(self, graph: dgl.DGLGraph,
                feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Update node features from mean and max aggregates over each edge type.

        For every edge type in ``self.etypes`` the ``"ft"`` field is copied
        along the edges and reduced into the fields ``"mean"`` and ``"max"``.
        ``self._concatenate_node_feat`` is then applied to the nodes of the
        type named by the edge type's third element, and that type's
        ``"new_ft"`` field becomes its new ``"ft"``.

        Args:
            graph: the graph; all writes go to the object returned by
                ``graph.local_var()``.
            feats: node features keyed by node type. Per the original
                documentation each tensor has shape (N, D), with N the number
                of nodes of that type and D the feature size.

        Returns:
            The ``"ft"`` field of every node type present in ``feats``, keyed
            by node type.
        """
        local_graph = graph.local_var()

        # Load the incoming features onto the graph.
        for node_type, node_feat in feats.items():
            local_graph.nodes[node_type].data.update({"ft": node_feat})

        # (output field, reducer factory) pairs, run in this order per edge type.
        reducers = (("mean", fn.mean), ("max", fn.max))

        for etype in self.etypes:
            for out_field, make_reducer in reducers:
                local_graph[etype].update_all(fn.copy_u("ft", "m"),
                                              make_reducer("m", out_field),
                                              etype=etype)

            # Third element of the edge type; presumably the destination node
            # type of a (src, relation, dst) triple -- confirm.
            dst_type = etype[2]
            local_graph.apply_nodes(self._concatenate_node_feat,
                                    ntype=dst_type)

            # Promote the freshly built feature so later edge types see it.
            dst_data = local_graph.nodes[dst_type].data
            dst_data.update({"ft": dst_data["new_ft"]})

        return {
            node_type: local_graph.nodes[node_type].data["ft"]
            for node_type in feats
        }
Пример #13
0
 def forward(self, graph: dgl.DGLGraph):
     """Run ``self.node_gnn`` on ``graph``, apply ``self.projection`` to its nodes, and return ``'feat'``.

     Args:
         graph: graph passed to ``self.node_gnn`` and ``apply_nodes``.

     Returns:
         ``graph.ndata['feat']`` as it stands after both steps.
     """
     self.node_gnn(graph)
     graph.apply_nodes(self.projection)
     projected = graph.ndata['feat']
     return projected
Пример #14
0
    def forward(self, graph: dgl.DGLGraph):
        graph.apply_nodes(self.input_node_func)

        for mp_layer in self.mp_layers:
            mp_layer(graph)