Esempio n. 1
0
 def __init__(
     self,
     dim=64,
     edge_dim=12,
     node_in=8,
     edge_in=19,
     edge_in3=8,
     num_layers=1,
 ):
     """Create every sub-module of the network.

     The constructor only instantiates layers; how they are wired together
     is decided by the forward pass, which is not part of this method.

     Args:
         dim: Base hidden width. The first NNConv/GRU stage works at ``dim``
             and the second at ``2 * dim``.
         edge_dim: Output width of ``lin_edge_attr`` and input width of the
             edge network handed to ``conv1``.
         node_in: Input width of ``lin_node`` and feature count of
             ``norm_x``.
         edge_in: Input width of ``lin_edge_attr``.
         edge_in3: Input width of the edge network handed to ``conv2`` and
             of ``lin_weight``, ``lin_bias`` and ``h_lin``.
         num_layers: Stored on the instance, forwarded to the ``head``
             Set2Set, and used as a multiplier on the ``h_lin`` output width.
     """
     super(Net_int_2Edges_Set2Set3, self).__init__()
     self.num_layers = num_layers
     self.dim = dim

     # Derived widths used repeatedly below.
     hidden = dim * 2    # width of the second conv stage
     paired = hidden * 2  # 4 * dim

     # Input projections for node features and edge attributes.
     self.lin_node = torch.nn.Linear(node_in, dim)
     self.lin_edge_attr = torch.nn.Linear(edge_in, edge_dim)

     # Edge networks: NNConv needs in_channels * out_channels outputs.
     edge_net_1 = Linear(edge_dim, dim * dim, bias=False)
     edge_net_2 = Linear(edge_in3, hidden * hidden, bias=False)

     # Stage 1 at width ``dim``.
     self.conv1 = NNConv(dim, dim, edge_net_1, aggr='mean', root_weight=False)
     self.gru1 = GRU(dim, dim)

     # Widening block: dim -> 2 * dim.
     widen_layers = [
         BatchNorm1d(dim),
         Linear(dim, hidden),
         RReLU(),
         Dropout(),
         Linear(hidden, hidden),
         RReLU(),
     ]
     self.lin_covert = Sequential(*widen_layers)

     # Stage 2 at width ``2 * dim``.
     self.conv2 = NNConv(hidden, hidden, edge_net_2, aggr='mean', root_weight=False)
     self.gru2 = GRU(hidden, hidden)

     # Bias-free linear maps from the ``edge_in3``-wide input.
     self.lin_weight = Linear(edge_in3, paired, bias=False)
     self.lin_bias = Linear(edge_in3, 1, bias=False)

     self.norm = BatchNorm1d(paired)
     self.norm_x = BatchNorm1d(node_in)

     # Two Set2Set modules; only ``head`` receives ``num_layers``.
     self.head = Set2Set(hidden, processing_steps=3, num_layers=num_layers)
     self.pool = Set2Set(hidden, processing_steps=3)

     self.h_lin = Linear(edge_in3, num_layers * paired)
     self.q_star_lin = Linear(paired, paired)
Esempio n. 2
0
 def __init__(self, dim=64, edge_dim=12, node_in=8, edge_in=19, edge_in3=8):
     """Create the layers of the four-stage GATConv network.

     Four ``GATConv(dim, dim)`` layers are created, each paired with a
     ``lin_covert*`` block of the form
     ``BatchNorm1d(dim) -> Linear(dim, 2*dim) -> RReLU -> Dropout ->
     Linear(2*dim, dim) -> RReLU``. How the layers are connected is decided
     by the forward pass, which is not part of this method.

     Args:
         dim: Hidden width used by every GATConv and ``lin_covert*`` block.
         edge_dim: Accepted but not read by this constructor.
         node_in: Input width of ``lin_node`` and feature count of
             ``norm_x``.
         edge_in: Accepted but not read by this constructor.
         edge_in3: Accepted but not read by this constructor.
     """
     super(Net_int_2Edges_attention2, self).__init__()

     def make_convert():
         # Expand to 2 * dim, then project back to dim so the result can
         # feed the next GATConv(dim, dim).
         return Sequential(
             BatchNorm1d(dim),
             Linear(dim, dim * 2),
             RReLU(),
             Dropout(),
             Linear(dim * 2, dim),
             RReLU(),
         )

     self.lin_node = torch.nn.Linear(node_in, dim)

     # lin_covert1 and lin_covert2 previously ended in Linear(dim, dim),
     # which cannot consume the 2 * dim wide output of the preceding
     # Linear(dim, 2 * dim). All four blocks now share the layout that
     # lin_covert3 and lin_covert4 already used.
     self.conv1 = GATConv(dim, dim, negative_slope=0.2, bias=True)
     self.lin_covert1 = make_convert()

     self.conv2 = GATConv(dim, dim, negative_slope=0.2, bias=True)
     self.lin_covert2 = make_convert()

     self.conv3 = GATConv(dim, dim, negative_slope=0.2, bias=True)
     self.lin_covert3 = make_convert()

     self.conv4 = GATConv(dim, dim, negative_slope=0.2, bias=True)
     self.lin_covert4 = make_convert()

     # NOTE(review): the input width 8 is hard-coded here; it equals the
     # default of ``edge_in3`` but does not follow it -- confirm whether it
     # should be ``edge_in3``.
     self.lin_weight = Linear(8, dim * 3, bias=False)
     self.lin_bias = Linear(8, 1, bias=False)
     self.norm = BatchNorm1d(dim * 3)
     self.norm_x = BatchNorm1d(node_in)
Esempio n. 3
0
 def __init__(
     self,
     dim=64,
     edge_dim=12,
     node_in=8,
     edge_in=19,
     edge_in3=8,
 ):
     """Create every sub-module of the network.

     The constructor only instantiates layers; how they are wired together
     is decided by the forward pass, which is not part of this method.

     Args:
         dim: Base hidden width. The first NNConv/GRU stage works at ``dim``
             and the second at ``2 * dim``.
         edge_dim: Output width of ``lin_edge_attr`` and input width of the
             edge network handed to ``conv1``.
         node_in: Input width of ``lin_node`` and feature count of
             ``norm_x``.
         edge_in: Input width of ``lin_edge_attr``.
         edge_in3: Input width of the edge network handed to ``conv2``; also
             the first argument given to ``InteractionNet``.
     """
     super(Net_int_2Edges_intHead, self).__init__()

     # Derived widths used repeatedly below.
     hidden = dim * 2      # width of the second conv stage
     head_in = hidden * 3  # 6 * dim

     # Input projections for node features and edge attributes.
     self.lin_node = torch.nn.Linear(node_in, dim)
     self.lin_edge_attr = torch.nn.Linear(edge_in, edge_dim)

     # Edge networks: NNConv needs in_channels * out_channels outputs.
     edge_net_1 = Linear(edge_dim, dim * dim, bias=False)
     edge_net_2 = Linear(edge_in3, hidden * hidden, bias=False)

     # Stage 1 at width ``dim``.
     self.conv1 = NNConv(dim, dim, edge_net_1, aggr='mean', root_weight=False)
     self.gru1 = GRU(dim, dim)

     # Widening block: dim -> 2 * dim.
     widen_layers = [
         BatchNorm1d(dim),
         Linear(dim, hidden),
         RReLU(),
         Dropout(),
         Linear(hidden, hidden),
         RReLU(),
     ]
     self.lin_covert = Sequential(*widen_layers)

     # Stage 2 at width ``2 * dim``.
     self.conv2 = NNConv(hidden, hidden, edge_net_2, aggr='mean', root_weight=False)
     self.gru2 = GRU(hidden, hidden)

     # NOTE(review): InteractionNet is defined elsewhere; the two lists look
     # like per-layer sizes and activations -- confirm against its definition.
     head_sizes = [head_in, hidden, 1]
     head_activations = [F.relu, None]
     self.head = InteractionNet(edge_in3, head_sizes, head_activations)

     self.norm = BatchNorm1d(head_in)
     self.norm_x = BatchNorm1d(node_in)