Exemplo n.º 1
0
    def __init__(self, dataset, predict):
        """Build the message-passing layers, MLP head and attention pools.

        Args:
            dataset: Dataset variant; ``'prompt_new'`` or ``'displaced'``
                (4 input node features) or ``'prompt_old'`` (3 features).
            predict: Target; ``'pT'`` or ``'1/pT'`` (1 output unit) or
                ``'pT_classes'`` (4 output units).

        Raises:
            AssertionError: If ``dataset`` or ``predict`` is not one of the
                values above. Raised explicitly rather than via ``assert`` so
                the check still runs under ``python -O`` (where ``conv1`` /
                ``lin4`` would otherwise silently be left undefined); the
                exception type matches the original asserts.
        """
        super(GNN, self).__init__()

        # Input node-feature count for each supported dataset variant.
        in_features = {'prompt_new': 4, 'displaced': 4, 'prompt_old': 3}
        # Output width of the final layer for each supported target.
        out_features = {'pT': 1, '1/pT': 1, 'pT_classes': 4}

        if dataset not in in_features:
            raise AssertionError(
                f'unknown dataset {dataset!r}; expected one of {tuple(in_features)}')
        if predict not in out_features:
            raise AssertionError(
                f'unknown predict {predict!r}; expected one of {tuple(out_features)}')

        self.predict = predict
        self.dataset = dataset

        # Modules are created in the same order as before so that seeded
        # weight initialisation and parameters() ordering are unchanged.
        self.conv1 = MPL(in_features[dataset], 128)
        self.conv2 = MPL(128, 32)
        self.conv3 = MPL(32, 64)
        self.conv4 = MPL(64, 32)
        # 32 * 2: presumably the two 32-wide attention-pooled embeddings
        # (global_att_pool1/2) concatenated in forward() -- TODO confirm.
        self.lin1 = torch.nn.Linear(32 * 2, 128)
        self.lin2 = torch.nn.Linear(128, 16)
        self.lin3 = torch.nn.Linear(16, 16)
        self.lin4 = torch.nn.Linear(16, out_features[predict])
        # Two independent attention pools, each gated by a single Linear that
        # scores 32-wide node embeddings (presumably conv4's output width).
        self.global_att_pool1 = gnn.GlobalAttention(
            torch.nn.Sequential(torch.nn.Linear(32, 1)))
        self.global_att_pool2 = gnn.GlobalAttention(
            torch.nn.Sequential(torch.nn.Linear(32, 1)))
Exemplo n.º 2
0
    def __init__(self,
                 node_attr_dim: int,
                 edge_attr_dim: int,
                 state_dim: int = 64,
                 num_conv: int = 3,
                 out_dim: int = 1,
                 attention_pooling: bool = False):
        """Build an edge-conditioned message-passing encoder with a graph readout.

        Args:
            node_attr_dim: Number of input features per node.
            edge_attr_dim: Number of input features per edge.
            state_dim: Width of the hidden node state.
            num_conv: Number of message-passing rounds; only stored here
                (presumably consumed by forward()).
            out_dim: Width of the final output.
            attention_pooling: If True, read out with GlobalAttention;
                otherwise with Set2Set. Both produce 2 * state_dim features.
        """
        super(MPNN, self).__init__()

        readout_dim = 2 * state_dim

        # Embed raw node attributes into the hidden state space.
        self.__in_linear = nn.Sequential(
            nn.Linear(node_attr_dim, state_dim),
            nn.ReLU(),
        )

        self.__num_conv = num_conv

        # Edge network: maps each edge's attributes to a flattened
        # state_dim x state_dim weight matrix used by NNConv.
        self.__nn_conv_linear = nn.Sequential(
            nn.Linear(edge_attr_dim, state_dim),
            nn.ReLU(),
            nn.Linear(state_dim, state_dim * state_dim),
        )
        self.__nn_conv = pyg_nn.NNConv(
            state_dim, state_dim, self.__nn_conv_linear,
            aggr='mean', root_weight=False)
        # Presumably updates node states between NNConv rounds -- confirm
        # in forward().
        self.__gru = nn.GRU(state_dim, state_dim)

        if not attention_pooling:
            # Per the original author, num_layers > 1 here is much slower.
            # Set2Set doubles its input width: state_dim -> readout_dim.
            self.__pooling = pyg_nn.Set2Set(state_dim, processing_steps=3)
        else:
            # Scalar gate per node; features projected to readout_dim so the
            # output width matches the Set2Set branch.
            self.__pooling = pyg_nn.GlobalAttention(
                nn.Linear(state_dim, 1), nn.Linear(state_dim, readout_dim))

        self.__out_linear = nn.Sequential(
            nn.Linear(readout_dim, readout_dim),
            nn.ReLU(),
            nn.Linear(readout_dim, out_dim),
        )
Exemplo n.º 3
0
Arquivo: gat.py Projeto: xduan7/MoReL
    def __init__(self,
                 node_attr_dim: int,
                 edge_attr_dim: int,
                 state_dim: int = 8,
                 num_heads: int = 8,
                 num_conv: int = 2,
                 out_dim: int = 1,
                 dropout: float = 0.2,
                 attention_pooling: bool = True):
        """Build an EdgeGAT node encoder followed by a graph readout and MLP.

        Args:
            node_attr_dim: Number of input features per node.
            edge_attr_dim: Number of input features per edge.
            state_dim: Hidden width passed to EdgeGAT (as both state_dim and
                out_dim) and used as the MLP's hidden width.
            num_heads: Number of attention heads passed to EdgeGAT.
            num_conv: Number of EdgeGAT convolution layers.
            out_dim: Width of the final output.
            dropout: Dropout rate passed to EdgeGAT.
            attention_pooling: If True, read out with GlobalAttention;
                otherwise with Set2Set.
        """
        super(EdgeGATEncoder, self).__init__()

        self.__edge_gat = EdgeGAT(node_attr_dim=node_attr_dim,
                                  edge_attr_dim=edge_attr_dim,
                                  state_dim=state_dim,
                                  num_heads=num_heads,
                                  num_conv=num_conv,
                                  out_dim=state_dim,
                                  dropout=dropout)

        # Per-node embedding width entering the pooling layer. The layers
        # below imply EdgeGAT emits state_dim * edge_attr_dim features per
        # node (presumably one state_dim block per edge channel -- confirm
        # in EdgeGAT).
        node_emb_dim = state_dim * edge_attr_dim
        graph_emb_dim = 2 * node_emb_dim

        # Pooling layer performs the following shape change:
        #   From [num_nodes, state_dim * edge_attr_dim]
        #   To   [num_graphs, 2 * state_dim * edge_attr_dim]
        # Set2Set doubles its input width by construction; the attention
        # variant matches it through its 2x-wide feature projection.
        if attention_pooling:
            self.__pooling = pyg_nn.GlobalAttention(
                nn.Linear(node_emb_dim, 1),
                nn.Linear(node_emb_dim, graph_emb_dim))
        else:
            self.__pooling = pyg_nn.Set2Set(node_emb_dim,
                                            processing_steps=3)

        self.__out_linear = nn.Sequential(
            nn.Linear(graph_emb_dim, state_dim), nn.ReLU(),
            nn.Linear(state_dim, out_dim))
Exemplo n.º 4
0
    def __init__(self, input_dim, hidden_dim, output_dim, args, task='node'):
        """Build a stack of graph-convolution layers plus an MLP head.

        Args:
            input_dim: Number of input node features.
            hidden_dim: Output width of every convolution layer.
            output_dim: Width of the final prediction.
            args: Config namespace; this constructor reads ``conv_type``,
                ``n_layers`` and ``dropout``.
            task: Either ``'node'`` or ``'graph'``.

        Raises:
            AssertionError: If ``args.n_layers`` is not >= 1.
            RuntimeError: If ``task`` is not ``'node'`` or ``'graph'``.
        """
        super(GNNStack, self).__init__()

        # Validate before building any layers. An explicit raise (instead of
        # ``assert``) still runs under ``python -O``; the exception types are
        # the same ones callers saw before.
        if not args.n_layers >= 1:
            raise AssertionError('Number of layers is not >=1')
        if task not in ('node', 'graph'):
            raise RuntimeError('Unknown task.')

        self.input_dim = input_dim
        self.conv_type = args.conv_type
        self.task = task
        self.dropout = args.dropout
        self.num_layers = args.n_layers
        self.hidden_dim = hidden_dim

        # First layer maps input_dim -> hidden_dim; the rest are
        # hidden_dim -> hidden_dim.
        conv_model = self.build_conv_model(args.conv_type)
        self.convs = nn.ModuleList()
        self.convs.append(conv_model(input_dim, hidden_dim))
        for _ in range(args.n_layers - 1):
            self.convs.append(conv_model(hidden_dim, hidden_dim))

        if self.conv_type == "gated":
            # Gate MLP scoring each node for soft-attention pooling.
            self.glob_soft_attn = pyg_nn.GlobalAttention(
                nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                              nn.Linear(hidden_dim, 1)))

        # post-message-passing head. Its input width is n_layers * hidden_dim,
        # presumably because forward() concatenates every layer's output --
        # TODO confirm.
        self.post_mp = nn.Sequential(
            nn.Linear(args.n_layers * hidden_dim, hidden_dim),
            nn.Dropout(args.dropout), nn.Linear(hidden_dim, output_dim))
Exemplo n.º 5
0
 def __init__(self, feat_in, hidden_features=10, treesup_out=20, dropout=0.1):
     """
     Whole-tree assessment model: a TreeSupport encoder, attention pooling
     over its node outputs, and a single-unit linear head.
     (The original docstring describes the encoder as two convolution layers.)

     :param feat_in: Number of features of one input node
     :param hidden_features: Hidden width passed to TreeSupport
     :param treesup_out: Output width of TreeSupport and of the pooled vector
     :param dropout: Dropout rate passed to TreeSupport
     """
     super(WholeTreeAssessor, self).__init__()
     self.tree_conv = TreeSupport(
         feat_in,
         hidden_features=hidden_features,
         output_dim=treesup_out,
         dropout=dropout,
     )
     # Gate network scoring each node. There is no nonlinearity between the
     # two Linear layers, so together they act as a single affine map.
     gate_layers = [
         nn.Linear(treesup_out, treesup_out),
         nn.Linear(treesup_out, 1),
     ]
     self.gpl = gnn.GlobalAttention(nn.Sequential(*gate_layers))
     # Final score head; the ReLU clamps the output to be non-negative.
     self.lin = nn.Sequential(nn.Linear(treesup_out, 1), nn.ReLU())
Exemplo n.º 6
0
 def build_pool(self, in_channels):
     """Return a GlobalAttention pooling layer with a single-Linear gate.

     :param in_channels: Object whose ``.node`` attribute is the node
         feature width (presumably per-element channel counts -- confirm).
     """
     # Gate projects each node's features to one attention score.
     gate = nn.Linear(in_channels.node, 1)
     return geom_nn.GlobalAttention(gate)