def __init__(self, dataset, predict):
    """Build the layers of the network for one dataset / target pairing.

    :param dataset: one of ``'prompt_new'``, ``'displaced'`` or
        ``'prompt_old'``. The first two give ``conv1`` 4 input features,
        ``'prompt_old'`` gives it 3.
    :param predict: one of ``'pT'``, ``'1/pT'`` or ``'pT_classes'``. The
        first two give ``lin4`` a single output, ``'pT_classes'`` gives
        it 4 outputs.
    """
    super(GNN, self).__init__()
    assert dataset in ('prompt_new', 'displaced', 'prompt_old')
    assert predict in ('pT', '1/pT', 'pT_classes')
    self.predict = predict
    self.dataset = dataset

    linear = torch.nn.Linear

    # Layers are created in a fixed order (conv1..conv4, lin1..lin4, then
    # the two pools) so that seeded initialisation and state_dict key
    # order stay stable.

    # Only the input width of the first MPL layer depends on the dataset.
    if dataset == 'prompt_old':
        self.conv1 = MPL(3, 128)
    elif dataset in ('prompt_new', 'displaced'):
        self.conv1 = MPL(4, 128)
    self.conv2 = MPL(128, 32)
    self.conv3 = MPL(32, 64)
    self.conv4 = MPL(64, 32)

    # NOTE(review): the 32 * 2 input width presumably matches the two
    # 32-wide pooled vectors concatenated in forward() -- confirm there.
    self.lin1 = linear(32 * 2, 128)
    self.lin2 = linear(128, 16)
    self.lin3 = linear(16, 16)

    # Output width of the last layer depends on the prediction target.
    if predict == 'pT_classes':
        self.lin4 = linear(16, 4)
    elif predict in ('pT', '1/pT'):
        self.lin4 = linear(16, 1)

    # Two separate attention poolings, each gated by its own Linear(32, 1).
    self.global_att_pool1 = gnn.GlobalAttention(
        torch.nn.Sequential(linear(32, 1)))
    self.global_att_pool2 = gnn.GlobalAttention(
        torch.nn.Sequential(linear(32, 1)))
def __init__(self,
             node_attr_dim: int,
             edge_attr_dim: int,
             state_dim: int = 64,
             num_conv: int = 3,
             out_dim: int = 1,
             attention_pooling: bool = False):
    """Create the sub-modules of the message passing network.

    :param node_attr_dim: width of the raw node features; projected to
        ``state_dim`` by the input layer.
    :param edge_attr_dim: width of the raw edge features; fed to the edge
        network handed to ``NNConv``.
    :param state_dim: width of the hidden node state.
    :param num_conv: stored on the instance; not otherwise used here.
    :param out_dim: width of the final output.
    :param attention_pooling: if truthy, pool with ``GlobalAttention``;
        otherwise pool with ``Set2Set``.
    """
    super(MPNN, self).__init__()

    # The output head below consumes 2 * state_dim features.
    pooled_dim = 2 * state_dim

    # Node feature projection: node_attr_dim -> state_dim.
    self.__in_linear = nn.Sequential(
        nn.Linear(node_attr_dim, state_dim),
        nn.ReLU(),
    )
    self.__num_conv = num_conv

    # Edge network: maps each edge attribute vector to
    # state_dim * state_dim values for NNConv.
    edge_net = nn.Sequential(
        nn.Linear(edge_attr_dim, state_dim),
        nn.ReLU(),
        nn.Linear(state_dim, state_dim * state_dim),
    )
    self.__nn_conv_linear = edge_net
    self.__nn_conv = pyg_nn.NNConv(
        state_dim,
        state_dim,
        self.__nn_conv_linear,
        aggr='mean',
        root_weight=False,
    )
    self.__gru = nn.GRU(state_dim, state_dim)

    if not attention_pooling:
        # Setting the num_layers > 1 will take significantly more time
        self.__pooling = pyg_nn.Set2Set(state_dim, processing_steps=3)
    else:
        # Gate scores each node state; the second Linear widens the
        # state to pooled_dim.
        self.__pooling = pyg_nn.GlobalAttention(
            nn.Linear(state_dim, 1),
            nn.Linear(state_dim, pooled_dim),
        )

    # Output head: pooled_dim -> pooled_dim -> out_dim.
    self.__out_linear = nn.Sequential(
        nn.Linear(pooled_dim, pooled_dim),
        nn.ReLU(),
        nn.Linear(pooled_dim, out_dim),
    )
def __init__(self,
             node_attr_dim: int,
             edge_attr_dim: int,
             state_dim: int = 8,
             num_heads: int = 8,
             num_conv: int = 2,
             out_dim: int = 1,
             dropout: float = 0.2,
             attention_pooling: bool = True):
    """Create the EdgeGAT backbone, the pooling layer and the output head.

    :param node_attr_dim: forwarded to ``EdgeGAT``.
    :param edge_attr_dim: forwarded to ``EdgeGAT``; also scales the width
        of the pooling layer and of the output head input.
    :param state_dim: forwarded to ``EdgeGAT`` both as ``state_dim`` and
        as ``out_dim``; also the hidden width of the output head.
    :param num_heads: forwarded to ``EdgeGAT``.
    :param num_conv: forwarded to ``EdgeGAT``.
    :param out_dim: width of the final output of this encoder.
    :param dropout: forwarded to ``EdgeGAT``.
    :param attention_pooling: if truthy, pool with ``GlobalAttention``;
        otherwise pool with ``Set2Set``.
    """
    super(EdgeGATEncoder, self).__init__()

    self.__edge_gat = EdgeGAT(
        node_attr_dim=node_attr_dim,
        edge_attr_dim=edge_attr_dim,
        state_dim=state_dim,
        num_heads=num_heads,
        num_conv=num_conv,
        out_dim=state_dim,
        dropout=dropout,
    )

    # The pooling layer is sized for node features of width
    # state_dim * edge_attr_dim; the output head consumes twice that.
    # NOTE(review): presumably this is the per-node output width of
    # EdgeGAT -- confirm against its forward().
    node_width = state_dim * edge_attr_dim
    pooled_width = 2 * state_dim * edge_attr_dim

    if not attention_pooling:
        self.__pooling = pyg_nn.Set2Set(node_width, processing_steps=3)
    else:
        self.__pooling = pyg_nn.GlobalAttention(
            nn.Linear(node_width, 1),
            nn.Linear(node_width, pooled_width),
        )

    # Output head: pooled_width -> state_dim -> out_dim.
    self.__out_linear = nn.Sequential(
        nn.Linear(pooled_width, state_dim),
        nn.ReLU(),
        nn.Linear(state_dim, out_dim),
    )
def __init__(self, input_dim, hidden_dim, output_dim, args, task='node'):
    """Assemble the convolution stack and the post-message-passing head.

    :param input_dim: input width of the first convolution.
    :param hidden_dim: output width of every convolution.
    :param output_dim: output width of the post-message-passing head.
    :param args: object providing ``conv_type``, ``n_layers`` and
        ``dropout``.
    :param task: either ``'node'`` or ``'graph'``.
    :raises AssertionError: if ``args.n_layers`` is smaller than 1.
    :raises RuntimeError: if ``task`` is neither ``'node'`` nor ``'graph'``.
    """
    super(GNNStack, self).__init__()
    self.input_dim = input_dim

    layer_factory = self.build_conv_model(args.conv_type)
    n_layers = args.n_layers

    # First layer maps input_dim -> hidden_dim; the rest keep hidden_dim.
    self.convs = nn.ModuleList()
    self.convs.append(layer_factory(input_dim, hidden_dim))
    self.conv_type = args.conv_type
    assert (n_layers >= 1), 'Number of layers is not >=1'
    for _ in range(n_layers - 1):
        self.convs.append(layer_factory(hidden_dim, hidden_dim))

    # The attention pooling is only created for the "gated" conv type.
    if self.conv_type == "gated":
        gate = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        self.glob_soft_attn = pyg_nn.GlobalAttention(gate)

    # post-message-passing
    # NOTE(review): the n_layers * hidden_dim input width presumably
    # matches the per-layer outputs concatenated in forward() -- confirm.
    self.post_mp = nn.Sequential(
        nn.Linear(n_layers * hidden_dim, hidden_dim),
        nn.Dropout(args.dropout),
        nn.Linear(hidden_dim, output_dim),
    )

    self.task = task
    if self.task not in ('node', 'graph'):
        raise RuntimeError('Unknown task.')

    self.dropout = args.dropout
    self.num_layers = n_layers
    self.hidden_dim = hidden_dim
def __init__(self, feat_in, hidden_features=10, treesup_out=20, dropout=0.1):
    """
    A class for whole tree assessment. Builds a TreeSupport block, an
    attention pooling layer and a final linear layer followed by ReLU.

    :param feat_in: Number of features of one node (passed to TreeSupport)
    :param hidden_features: Passed to TreeSupport as ``hidden_features``
    :param treesup_out: Passed to TreeSupport as ``output_dim``; also the
        input width of the pooling gate and of the final linear layer
    :param dropout: Passed to TreeSupport as ``dropout``
    """
    super(WholeTreeAssessor, self).__init__()

    self.tree_conv = TreeSupport(
        feat_in,
        hidden_features=hidden_features,
        output_dim=treesup_out,
        dropout=dropout,
    )

    # Gate network for the attention pooling: treesup_out -> treesup_out -> 1.
    # NOTE(review): there is no activation between the two Linear layers,
    # so together they form a single affine map -- confirm this is intended.
    gate = nn.Sequential(
        nn.Linear(treesup_out, treesup_out),
        nn.Linear(treesup_out, 1),
    )
    self.gpl = gnn.GlobalAttention(gate)

    # Final layer: treesup_out -> 1, passed through ReLU.
    self.lin = nn.Sequential(
        nn.Linear(treesup_out, 1),
        nn.ReLU(),
    )
def build_pool(self, in_channels):
    """Return a GlobalAttention pooling layer gated by a single Linear.

    :param in_channels: object whose ``node`` attribute gives the input
        width of the gate.
    :return: a new ``geom_nn.GlobalAttention`` instance.
    """
    # The gate maps each node feature vector to one score.
    gate = nn.Linear(in_channels.node, 1)
    return geom_nn.GlobalAttention(gate)