def __init__(self, dim=64, edge_dim=12, node_in=8, edge_in=19, edge_in3=8, num_layers=1):
    """Assemble the two-stage NNConv/GRU encoder and its Set2Set readouts.

    Args:
        dim: Base hidden width; the second message-passing stage runs at
            ``2 * dim``.
        edge_dim: Width the first edge-attribute set is projected to by
            ``lin_edge_attr``; this is the input width of the edge network
            feeding ``conv1``.
        node_in: Number of raw node features (input of ``lin_node`` and
            ``norm_x``).
        edge_in: Number of raw features in the first edge-attribute set.
        edge_in3: Number of features in the second edge-attribute set; it
            drives ``conv2``'s edge network and the ``lin_weight``,
            ``lin_bias`` and ``h_lin`` projections.
        num_layers: Passed to ``self.head`` (Set2Set) and also multiplies
            the output width of ``h_lin``.
    """
    super().__init__()
    wide = dim * 2      # width after lin_covert / of the second stage
    wider = wide * 2    # == dim * 4

    self.num_layers = num_layers
    self.dim = dim

    # Input projections for node features and the first edge-attribute set.
    self.lin_node = torch.nn.Linear(node_in, dim)
    self.lin_edge_attr = torch.nn.Linear(edge_in, edge_dim)

    # Edge networks for the two NNConv layers: each output size equals
    # in_channels * out_channels of the NNConv that consumes it.
    edge_net1 = Linear(edge_dim, dim * dim, bias=False)
    edge_net2 = Linear(edge_in3, wide * wide, bias=False)

    # Stage 1: message passing at width `dim`, paired with a GRU.
    self.conv1 = NNConv(dim, dim, edge_net1, aggr='mean', root_weight=False)
    self.gru1 = GRU(dim, dim)

    # Widening block dim -> 2*dim between the two stages.
    self.lin_covert = Sequential(
        BatchNorm1d(dim),
        Linear(dim, wide),
        RReLU(),
        Dropout(),
        Linear(wide, wide),
        RReLU(),
    )

    # Stage 2: message passing at width 2*dim, paired with a GRU.
    self.conv2 = NNConv(wide, wide, edge_net2, aggr='mean', root_weight=False)
    self.gru2 = GRU(wide, wide)

    # Linear maps of the second edge-attribute set (both bias-free).
    self.lin_weight = Linear(edge_in3, wider, bias=False)
    self.lin_bias = Linear(edge_in3, 1, bias=False)

    self.norm = BatchNorm1d(wider)
    self.norm_x = BatchNorm1d(node_in)

    # Set2Set readouts over 2*dim-wide node features; `head` forwards
    # num_layers, `pool` uses the library default.
    self.head = Set2Set(wide, processing_steps=3, num_layers=num_layers)
    self.pool = Set2Set(wide, processing_steps=3)
    # NOTE(review): output width num_layers * 4 * dim presumably seeds the
    # LSTM state of `head` from the second edge set -- confirm in forward().
    self.h_lin = Linear(edge_in3, num_layers * wider)
    self.q_star_lin = Linear(wider, wider)
def __init__(self, dim=64, edge_dim=12, node_in=8, edge_in=19, edge_in3=8):
    """Build a four-layer GATConv encoder with bottleneck blocks and edge heads.

    Args:
        dim: Hidden width of every GATConv layer and of each bottleneck
            block's input/output.
        edge_dim: Not used by this constructor; kept so the signature
            matches the sibling models.
        node_in: Number of raw node features (input of ``lin_node`` and
            ``norm_x``).
        edge_in: Not used by this constructor; kept for signature parity.
        edge_in3: Number of features in the edge-attribute set projected by
            ``lin_weight`` and ``lin_bias``.  Previously hard-coded to 8,
            which is also the default, so default behavior is unchanged.
    """
    super(Net_int_2Edges_attention2, self).__init__()

    def _covert():
        # Bottleneck block: BatchNorm -> widen dim -> 2*dim -> RReLU ->
        # Dropout -> project 2*dim -> dim -> RReLU.  The second Linear must
        # take 2*dim inputs; the original lin_covert1/lin_covert2 used
        # Linear(dim, dim) there, which raises a shape mismatch when called.
        return Sequential(BatchNorm1d(dim), Linear(dim, dim * 2),
                          RReLU(), Dropout(), Linear(dim * 2, dim), RReLU())

    self.lin_node = torch.nn.Linear(node_in, dim)

    # Modules are created in the original order so registration order
    # (state_dict keys, RNG consumption during init) is unchanged.
    self.conv1 = GATConv(dim, dim, negative_slope=0.2, bias=True)
    self.lin_covert1 = _covert()
    self.conv2 = GATConv(dim, dim, negative_slope=0.2, bias=True)
    self.lin_covert2 = _covert()
    self.conv3 = GATConv(dim, dim, negative_slope=0.2, bias=True)
    self.lin_covert3 = _covert()
    self.conv4 = GATConv(dim, dim, negative_slope=0.2, bias=True)
    self.lin_covert4 = _covert()

    # Bias-free projections of the edge-attribute set.
    self.lin_weight = Linear(edge_in3, dim * 3, bias=False)
    self.lin_bias = Linear(edge_in3, 1, bias=False)

    self.norm = BatchNorm1d(dim * 3)
    self.norm_x = BatchNorm1d(node_in)
def __init__(self, dim=64, edge_dim=12, node_in=8, edge_in=19, edge_in3=8):
    """Assemble the two-stage NNConv/GRU encoder with an InteractionNet head.

    Args:
        dim: Base hidden width; the second message-passing stage runs at
            ``2 * dim``.
        edge_dim: Width the first edge-attribute set is projected to by
            ``lin_edge_attr``; this is the input width of the edge network
            feeding ``conv1``.
        node_in: Number of raw node features (input of ``lin_node`` and
            ``norm_x``).
        edge_in: Number of raw features in the first edge-attribute set.
        edge_in3: Number of features in the second edge-attribute set; it
            drives ``conv2``'s edge network and is the first argument of
            ``InteractionNet``.
    """
    super().__init__()
    wide = dim * 2          # width after lin_covert / of the second stage
    head_in = dim * 3 * 2   # input width of the head and of `norm`

    # Input projections for node features and the first edge-attribute set.
    self.lin_node = torch.nn.Linear(node_in, dim)
    self.lin_edge_attr = torch.nn.Linear(edge_in, edge_dim)

    # Edge networks for the two NNConv layers: each output size equals
    # in_channels * out_channels of the NNConv that consumes it.
    edge_net1 = Linear(edge_dim, dim * dim, bias=False)
    edge_net2 = Linear(edge_in3, wide * wide, bias=False)

    # Stage 1: message passing at width `dim`, paired with a GRU.
    self.conv1 = NNConv(dim, dim, edge_net1, aggr='mean', root_weight=False)
    self.gru1 = GRU(dim, dim)

    # Widening block dim -> 2*dim between the two stages.
    self.lin_covert = Sequential(
        BatchNorm1d(dim),
        Linear(dim, wide),
        RReLU(),
        Dropout(),
        Linear(wide, wide),
        RReLU(),
    )

    # Stage 2: message passing at width 2*dim, paired with a GRU.
    self.conv2 = NNConv(wide, wide, edge_net2, aggr='mean', root_weight=False)
    self.gru2 = GRU(wide, wide)

    # NOTE(review): InteractionNet is defined elsewhere; the list presumably
    # gives layer widths 6*dim -> 2*dim -> 1 with activations (relu, none).
    # Confirm against its definition.
    self.head = InteractionNet(edge_in3, [head_in, wide, 1], [F.relu, None])
    self.norm = BatchNorm1d(head_in)
    self.norm_x = BatchNorm1d(node_in)