Example #1
0
 def __init__(self, dim_in):
     """Encode integer node/edge features into dense embeddings.

     Encoders are looked up by name in the global registry and sized by
     ``cfg.gnn.dim_inner``; optional BatchNorm layers follow each encoder
     when enabled in the config.

     Args:
         dim_in (int): Raw input feature dimension. Overwritten with
             ``cfg.gnn.dim_inner`` when a node encoder is enabled, since
             the encoder output replaces the raw features.
     """
     super().__init__()
     self.dim_in = dim_in
     if cfg.dataset.node_encoder:
         # Encode integer node features via nn.Embeddings
         NodeEncoder = register.node_encoder_dict[
             cfg.dataset.node_encoder_name]
         self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)
         if cfg.dataset.node_encoder_bn:
             # BatchNorm over encoded node features; the -1 placeholders and
             # has_act/has_bias=False indicate only the BN dim is relevant.
             self.node_encoder_bn = BatchNorm1dNode(
                 new_layer_config(cfg.gnn.dim_inner,
                                  -1,
                                  -1,
                                  has_act=False,
                                  has_bias=False,
                                  cfg=cfg))
         # Update dim_in to reflect the new dimension of the node features
         self.dim_in = cfg.gnn.dim_inner
     if cfg.dataset.edge_encoder:
         # Encode integer edge features via nn.Embeddings
         EdgeEncoder = register.edge_encoder_dict[
             cfg.dataset.edge_encoder_name]
         self.edge_encoder = EdgeEncoder(cfg.gnn.dim_inner)
         if cfg.dataset.edge_encoder_bn:
             self.edge_encoder_bn = BatchNorm1dNode(
                 new_layer_config(cfg.gnn.dim_inner,
                                  -1,
                                  -1,
                                  has_act=False,
                                  has_bias=False,
                                  cfg=cfg))
Example #2
0
 def __init__(self, dim_in, dim_out):
     """Prediction head for edge-level / link prediction tasks.

     Args:
         dim_in (int): Dimension of each input node embedding.
         dim_out (int): Output dimension (number of classes). Must be 1
             unless ``cfg.model.edge_decoding == 'concat'``, since the
             binary decoders produce a single score per edge.

     Raises:
         ValueError: If a binary decoder ('dot', 'cosine_similarity') is
             requested together with ``dim_out > 1``, or if
             ``cfg.model.edge_decoding`` names an unknown scheme.
     """
     # Use zero-argument super() for consistency with the other heads.
     super().__init__()
     # module to decode edges from node embeddings
     if cfg.model.edge_decoding == 'concat':
         # Concatenate the two endpoint embeddings, then map to dim_out.
         self.layer_post_mp = MLP(
             new_layer_config(dim_in * 2, dim_out, cfg.gnn.layers_post_mp,
                              has_act=False, has_bias=True, cfg=cfg))
         # requires parameter
         self.decode_module = lambda v1, v2: \
             self.layer_post_mp(torch.cat((v1, v2), dim=-1))
     else:
         if dim_out > 1:
             # Fixed message: previously read "({})is used", which was
             # missing a space and stated the opposite of the problem.
             raise ValueError(
                 'Binary edge decoding ({}) cannot be used for '
                 'multi-class edge/link prediction.'.format(
                     cfg.model.edge_decoding))
         # Binary decoders transform embeddings first, then score pairs.
         self.layer_post_mp = MLP(
             new_layer_config(dim_in, dim_in, cfg.gnn.layers_post_mp,
                              has_act=False, has_bias=True, cfg=cfg))
         if cfg.model.edge_decoding == 'dot':
             self.decode_module = lambda v1, v2: torch.sum(v1 * v2, dim=-1)
         elif cfg.model.edge_decoding == 'cosine_similarity':
             self.decode_module = nn.CosineSimilarity(dim=-1)
         else:
             raise ValueError('Unknown edge decoding {}.'.format(
                 cfg.model.edge_decoding))
Example #3
0
 def __init__(self, dim_in, dim_out):
     """Node-level prediction head: a post-message-passing MLP mapping
     embeddings from ``dim_in`` to ``dim_out``.

     Args:
         dim_in (int): Input embedding dimension.
         dim_out (int): Output (prediction) dimension.
     """
     super().__init__()
     head_cfg = new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp,
                                 has_act=False, has_bias=True, cfg=cfg)
     self.layer_post_mp = MLP(head_cfg)
Example #4
0
 def __init__(self, dim_in, dim_out):
     """Graph-level prediction head: registry-selected pooling followed
     by a post-message-passing MLP.

     Args:
         dim_in (int): Input node-embedding dimension.
         dim_out (int): Output (prediction) dimension.
     """
     super().__init__()
     head_cfg = new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp,
                                 has_act=False, has_bias=True, cfg=cfg)
     self.layer_post_mp = MLP(head_cfg)
     # Pooling function is looked up by name in the global registry.
     self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling]
Example #5
0
def GNNPreMP(dim_in, dim_out):
    """
    Wrapper for the NN layer stack applied before GNN message passing.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension

    Returns:
        GeneralMultiLayer: A single 'linear' pre-MP layer.
    """
    pre_mp_cfg = new_layer_config(dim_in, dim_out, 1, has_act=False,
                                  has_bias=False, cfg=cfg)
    return GeneralMultiLayer('linear', layer_config=pre_mp_cfg)
Example #6
0
def GNNLayer(dim_in, dim_out, has_act=True):
    """
    Wrapper for a single GNN layer of the configured type.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension
        has_act (bool): Whether an activation function follows the layer

    Returns:
        GeneralLayer: The constructed layer of type ``cfg.gnn.layer_type``.
    """
    layer_cfg = new_layer_config(dim_in, dim_out, 1, has_act=has_act,
                                 has_bias=False, cfg=cfg)
    return GeneralLayer(cfg.gnn.layer_type, layer_config=layer_cfg)