def __init__(self, num_features, n_classes, num_hidden, num_hidden_layers,
             dropout, activation, K=3, improved=True, bias=True):
    """Assemble the PTAG stack: one input conv, hidden convs, one output conv.

    Args:
        num_features: Input width of the first TAGConv.
        n_classes: Output width of the last TAGConv.
        num_hidden: Width used by every hidden TAGConv.
        num_hidden_layers: Number of hidden-to-hidden TAGConv layers.
        dropout: Dropout probability; any falsy value results in ``p=0``.
        activation: Stored unchanged on ``self.activation``.
        K: Forwarded to every TAGConv.
        improved: Accepted but not read by this constructor.
        bias: Forwarded to every TAGConv.
    """
    super(PTAG, self).__init__()

    # A falsy dropout (0, None, False) still yields a Dropout module, with p=0.
    self.dropout = nn.Dropout(p=dropout if dropout else 0.)
    self.activation = activation

    # Input -> hidden
    self.conv_input = TAGConv(num_features, num_hidden, K=K, bias=bias)
    # Hidden -> hidden, repeated num_hidden_layers times
    self.layers = nn.ModuleList([
        TAGConv(num_hidden, num_hidden, K=K, bias=bias)
        for _ in range(num_hidden_layers)
    ])
    # Hidden -> classes
    self.conv_output = TAGConv(num_hidden, n_classes, K=K, bias=bias)
def __init__(self, n_features, n_labels, classification=False, width=128,
             conv_depth=3, point_depth=3, lin_depth=5):
    """Build the ConvNet layers: a TAGConv stack followed by a linear stack.

    Args:
        n_features: Input width of the first TAGConv.
        n_labels: Output width of both output layers.
        classification: Stored on ``self.classification``.
        width: Output width of every TAGConv.
        conv_depth: Total number of TAGConv layers (``conv1`` plus
            ``conv_depth - 1`` entries in ``convfkt``).
        point_depth: Accepted but not read by this constructor.
        lin_depth: Number of (Linear, Dropout) pairs in the dense stack.
    """
    super(ConvNet, self).__init__()
    self.classification = classification
    self.n_features = n_features
    self.n_labels = n_labels
    self.lin_depth = lin_depth
    self.conv_depth = conv_depth
    self.width = width

    n_intermediate = self.width
    # Width of the dense stack: two slots per conv layer.
    n_intermediate2 = 2 * self.conv_depth * n_intermediate

    self.conv1 = TAGConv(self.n_features, n_intermediate, 2)
    self.convfkt = torch.nn.ModuleList([
        TAGConv(n_intermediate, n_intermediate, 2)
        for _ in range(self.conv_depth - 1)
    ])

    self.batchnorm1 = BatchNorm1d(n_intermediate2)
    self.linearfkt = torch.nn.ModuleList([
        torch.nn.Linear(n_intermediate2, n_intermediate2)
        for _ in range(self.lin_depth)
    ])
    self.drop = torch.nn.ModuleList(
        [torch.nn.Dropout(.3) for _ in range(self.lin_depth)])

    self.out = torch.nn.Linear(n_intermediate2, self.n_labels)
    self.out2 = torch.nn.Linear(self.n_labels, self.n_labels)
def test_tag_conv():
    """Check TAGConv on dense and sparse adjacency inputs, eagerly and scripted.

    Verifies the repr, the output shape, agreement between edge-index and
    transposed SparseTensor inputs, and agreement between the eager layer and
    its TorchScript-compiled variants.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    value = torch.rand(src.size(0))
    adj_weighted = SparseTensor(row=src, col=dst, value=value,
                                sparse_sizes=(4, 4))
    adj_plain = adj_weighted.set_value(None)

    conv = TAGConv(16, 32)
    assert repr(conv) == 'TAGConv(16, 32, K=3)'

    # Without edge weights.
    out_plain = conv(x, edge_index)
    assert out_plain.shape == (4, 32)
    assert torch.allclose(conv(x, adj_plain.t()), out_plain, atol=1e-6)

    # With edge weights.
    out_weighted = conv(x, edge_index, value)
    assert out_weighted.shape == (4, 32)
    assert torch.allclose(conv(x, adj_weighted.t()), out_weighted, atol=1e-6)

    # Scripted, edge-index signature: results must match exactly.
    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
    assert scripted(x, edge_index).tolist() == out_plain.tolist()
    assert scripted(x, edge_index, value).tolist() == out_weighted.tolist()

    # Scripted, SparseTensor signature: results must match within tolerance.
    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj_plain.t()), out_plain, atol=1e-6)
    assert torch.allclose(scripted(x, adj_weighted.t()), out_weighted,
                          atol=1e-6)
def __init__(self, num_features):
    """Create two TAGConv layers (num_features -> 8 -> 16) and a linear head.

    Args:
        num_features: Input width of the first TAGConv.
    """
    super(TAG, self).__init__()
    mid_width, out_width = 8, 16
    self.conv1 = TAGConv(num_features, mid_width)
    self.conv2 = TAGConv(mid_width, out_width)
    # Two-unit head over a 2 * 16 wide input; a one-unit variant,
    # Linear(2 * 16, 1), was left commented out in the original.
    self.fc = torch.nn.Linear(2 * out_width, 2)
def __init__(self, num_feature, num_class, num_layers=2, hidden=64,
             drop=0.5, use_edge_weight=True, k=3):
    """Create two TAGConv layers followed by a linear classifier.

    Args:
        num_feature: Input width of the first TAGConv.
        num_class: Output width of the linear classifier.
        num_layers: Stored on ``self.n_layer``.
        hidden: Width of both TAGConv outputs.
        drop: Stored on ``self.drop``.
        use_edge_weight: Stored on ``self.use_edge_weight``.
        k: Forwarded to both TAGConv layers as ``K``. Added as a trailing
            parameter because the body previously read a name ``K`` that
            was neither a parameter nor a local.
            NOTE(review): if the module defines a global ``K`` with a value
            other than 3, pass it explicitly to keep the old behaviour.
    """
    super(TAG_Linear, self).__init__()
    self.conv0 = TAGConv(num_feature, hidden, K=k)
    self.conv1 = TAGConv(hidden, hidden, K=k)
    self.n_layer = num_layers
    self.linear = Linear(hidden, num_class)
    self.use_edge_weight = use_edge_weight
    self.drop = drop
def __init__(self, num_features, num_layers):
    """Create ``num_layers`` TAGConv layers, a 'cat' JumpingKnowledge and a head.

    Args:
        num_features: Input width of the first TAGConv.
        num_layers: Total conv count (``conv1`` plus ``num_layers - 1``
            entries in ``convs``); also scales the head's input width.
    """
    super(TAGWithJK, self).__init__()
    width = 8
    self.conv1 = TAGConv(num_features, width)
    self.convs = torch.nn.ModuleList(
        [TAGConv(width, width) for _ in range(num_layers - 1)])
    self.jump = JumpingKnowledge('cat')
    # Head input is 2 * num_layers * 8 wide.
    self.fc = torch.nn.Linear(2 * num_layers * width, 2)
def test_tag_conv():
    """Check TAGConv's repr and output shape with and without edge weights."""
    in_channels, out_channels = 16, 32
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    num_nodes = edge_index.max().item() + 1
    edge_weight = torch.rand(edge_index.size(1))
    x = torch.randn((num_nodes, in_channels))
    expected_shape = (num_nodes, out_channels)

    conv = TAGConv(in_channels, out_channels)
    assert repr(conv) == 'TAGConv(16, 32, K=3)'
    assert conv(x, edge_index).size() == expected_shape
    assert conv(x, edge_index, edge_weight).size() == expected_shape
def __init__(self, num_features, channels=64):
    """Create two TAGConv layers, a linear head and a learnable edge weight.

    Args:
        num_features: Input width of the first TAGConv.
        channels: Determines the edge-weight length,
            ``channels * channels - channels`` rows of one value each.
    """
    super(TAGLearn, self).__init__()
    width = 8
    self.conv1 = TAGConv(num_features, width)
    self.conv2 = TAGConv(width, width)
    # Head over a 2 * 8 wide input; a Linear(8, 2) variant was left
    # commented out in the original.
    self.fc = torch.nn.Linear(2 * width, 2)

    # Learnable per-edge weights, all initialised to 1.
    edge_count = channels * (channels - 1)
    initial = torch.FloatTensor(edge_count, 1).fill_(1)
    self.edge_weight = torch.nn.Parameter(initial, requires_grad=True)
def __init__(self, num_features, channels=64):
    """Create the TAGSortPool layers and a learnable edge weight.

    Args:
        num_features: Input width of the first TAGConv.
        channels: Determines the edge-weight length,
            ``channels * channels - channels`` rows of one value each.
    """
    super(TAGSortPool, self).__init__()
    self.k = 12
    width = 4
    self.conv1 = TAGConv(num_features, width)
    self.conv2 = TAGConv(width, width)
    # 1-D convolution whose kernel size equals self.k.
    self.conv1d = torch.nn.Conv1d(width, width, self.k)
    self.fc = torch.nn.Linear(width, 2)

    # Learnable per-edge weights, all initialised to 1.
    edge_count = channels * (channels - 1)
    initial = torch.FloatTensor(edge_count, 1).fill_(1)
    self.edge_weight = torch.nn.Parameter(initial, requires_grad=True)
def __init__(self, *args, **kwargs):
    """Build a TAGConv stack from keyword hyperparameters.

    All positional and keyword arguments are forwarded to the parent
    constructor. The keys ``num_features``, ``hidden_channels``,
    ``num_layers`` and ``num_classes`` must be present in ``kwargs``.
    """
    super().__init__(*args, **kwargs)
    self.save_hyperparameters()

    in_dim = kwargs["num_features"]
    hidden_dim = kwargs["hidden_channels"]

    # First layer, then (num_layers - 2) hidden layers, then the output layer.
    self.convs = nn.ModuleList()
    self.convs.append(TAGConv(in_dim, hidden_dim))
    self.convs.extend(
        [TAGConv(hidden_dim, hidden_dim)
         for _ in range(kwargs["num_layers"] - 2)])
    self.convs.append(TAGConv(hidden_dim, kwargs["num_classes"]))
def __init__(self, n_features, K=3):
    """Create three TAGConv layers, a three-layer linear head and five SELUs.

    Args:
        n_features: Input width of the first TAGConv.
        K: Forwarded to every TAGConv.
    """
    super(Spectral, self).__init__()
    width = 16
    self.spec1 = TAGConv(n_features, width, K=K)
    self.spec2 = TAGConv(width, width, K=K)
    self.spec3 = TAGConv(width, width, K=K)

    self.lin1 = Linear(width, 64)
    self.lin2 = Linear(64, 8)
    self.out = Linear(8, 1)

    # Five separate SELU modules, registered as s1 .. s5.
    for idx in range(1, 6):
        setattr(self, "s{}".format(idx), SELU())
def __init__(self, num_features, num_layers, channels=64):
    """Create the TAGConv/JumpingKnowledge stack plus a learnable edge weight.

    Args:
        num_features: Input width of the first TAGConv.
        num_layers: Total conv count (``conv1`` plus ``num_layers - 1``
            entries in ``convs``); also scales the head's input width.
        channels: Determines the edge-weight length,
            ``channels * channels - channels`` rows of one value each.
    """
    super(TAGWithJKLearn, self).__init__()
    width = 8
    self.conv1 = TAGConv(num_features, width)
    self.convs = torch.nn.ModuleList(
        [TAGConv(width, width) for _ in range(num_layers - 1)])
    self.jump = JumpingKnowledge('cat')
    self.fc = torch.nn.Linear(2 * num_layers * width, 2)

    # Learnable per-edge weights, all initialised to 1.
    edge_count = channels * (channels - 1)
    initial = torch.FloatTensor(edge_count, 1).fill_(1)
    self.edge_weight = torch.nn.Parameter(initial, requires_grad=True)
def __init__(self, in_feats, hidden_size, hidden_size1, num_classes, k):
    """Create three TAGConv layers and a small 2-D convolutional encoder.

    Args:
        in_feats: Input width of the first TAGConv.
        hidden_size: Output width of the first TAGConv.
        hidden_size1: Output width of the second TAGConv.
        num_classes: Output width of the third TAGConv.
        k: Forwarded to every TAGConv as ``K``.
    """
    super(GCN, self).__init__()
    self.conv1 = TAGConv(in_feats, hidden_size, K=k)
    self.conv2 = TAGConv(hidden_size, hidden_size1, K=k)
    self.conv3 = TAGConv(hidden_size1, num_classes, K=k)

    # Encoder channel plan: 1 -> 10 -> 20 -> 1.
    base_channels = 10
    encoder_layers = [
        nn.Conv2d(1, base_channels, (3, 3)),
        nn.LeakyReLU(),
        nn.Dropout2d(),
        nn.Conv2d(base_channels, 2 * base_channels, (3, 2)),
        nn.LeakyReLU(),
        nn.Dropout2d(),
        nn.Conv2d(2 * base_channels, 1, (3, 2)),
    ]
    self.encoder = nn.Sequential(*encoder_layers)
def __init__(self, k=2, preconv_ws=(32, 32, 32), highway_layers=3,
             highway_w=64, lin_ws=(32, 8)):
    """Build the frozen filtering stage, pre/highway convs and the MLP head.

    Sub-modules are registered under numbered attribute names
    (``pre_conv0``, ``pre_bn0``, ``high_conv0``, ``lin0``, ...) and are also
    collected in plain Python lists for iteration.

    Args:
        k: Third positional argument passed to every TAGConv.
        preconv_ws: Output widths of the pre-convolution layers, in order.
        highway_layers: Number of highway convolution layers.
        highway_w: Output width of every highway convolution.
        lin_ws: Output widths of the MLP layers, in order.
    """
    super(Net, self).__init__()

    # Stage described by the original as removing inactive nodes; both of
    # its weights are frozen.
    self.conv = CustomGCN(2, 1, cached=False)
    self.conv.weight.requires_grad = False
    self.topk = CustomTopK(1, min_score=0.1)
    self.topk.weight.requires_grad = False

    prev_w = 2

    # Pre-convolutions: TAGConv + BatchNorm1d per requested width.
    self.pre_convs = []
    self.pre_bns = []
    for i, w in enumerate(preconv_ws):
        conv = TAGConv(prev_w, w, k)
        bn = BatchNorm1d(w)
        setattr(self, "pre_conv{}".format(i), conv)
        setattr(self, "pre_bn{}".format(i), bn)
        self.pre_convs.append(conv)
        self.pre_bns.append(bn)
        prev_w = w

    # Highway convolutions: TAGConv + BatchNorm1d, all highway_w wide.
    self.high_convs = []
    self.high_bns = []
    for i in range(highway_layers):
        conv = TAGConv(prev_w, highway_w, k)
        bn = BatchNorm1d(highway_w)
        setattr(self, "high_conv{}".format(i), conv)
        setattr(self, "high_bn{}".format(i), bn)
        self.high_convs.append(conv)
        self.high_bns.append(bn)
        prev_w = highway_w

    # MLP: its input width is the highway width times the highway depth.
    prev_w *= highway_layers
    self.lins = []
    for i, w in enumerate(lin_ws):
        lin = Linear(prev_w, w)
        setattr(self, "lin{}".format(i), lin)
        self.lins.append(lin)
        prev_w = w

    # Final layer with three outputs.
    self.lin_final = Linear(prev_w, 3)
def test_static_tag_conv():
    """Check TAGConv's output shape for a 3-D input of shape (3, 4, 16)."""
    batch, nodes, in_channels, out_channels = 3, 4, 16, 32
    x = torch.randn(batch, nodes, in_channels)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])

    conv = TAGConv(in_channels, out_channels)
    result = conv(x, edge_index)
    assert result.size() == (batch, nodes, out_channels)
def __init__(self, k=2, w1=128, w2=128, w3=128):
    """Build a frozen filtering stage, three TAGConv+BatchNorm pairs and a head.

    Args:
        k: Third positional argument passed to every TAGConv.
        w1: Output width of the first TAGConv / BatchNorm pair.
        w2: Output width of the second pair.
        w3: Output width of the third pair and input width of the head.
    """
    super(Net, self).__init__()

    # Filtering stage; both of its weights are excluded from training.
    self.conv = CustomGCN(2, 1, cached=False)
    self.conv.weight.requires_grad = False
    self.topk = TopKPooling(1, min_score=0.1)
    self.topk.weight.requires_grad = False

    # Width plan 2 -> w1 -> w2 -> w3, registered as conv1/bn1 .. conv3/bn3.
    widths = (2, w1, w2, w3)
    for idx in range(1, len(widths)):
        setattr(self, "conv{}".format(idx),
                TAGConv(widths[idx - 1], widths[idx], k))
        setattr(self, "bn{}".format(idx), BatchNorm1d(widths[idx]))

    self.linear = Linear(w3, 3)
def __init__(self, time_samples=128, channels=64, seq_len=8, input_size=4,
             hidden_size=4, num_layers=1):
    """Build the TAGConv front-end, the LSTM, the head and per-step edge weights.

    Args:
        time_samples: Stored as ``self.T``; divided by ``seq_len`` to get the
            TAGConv input width.
        channels: Stored as ``self.C``; each edge-weight tensor has
            ``C * C - C`` rows of one value each.
        seq_len: Number of entries in ``self.gcns`` and ``self.weights``.
        input_size: LSTM input size and TAGConv output width.
        hidden_size: LSTM hidden size and input width of the linear head.
        num_layers: Number of LSTM layers.
    """
    super(TAGLSTM, self).__init__()
    self.T = time_samples
    self.C = channels
    self.seq_len = seq_len
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.num_features = self.T // self.seq_len
    self.gcn_output_size = self.input_size

    self.gcn = TAGConv(self.num_features, self.gcn_output_size)
    # NOTE(review): every entry is the same TAGConv object, so all steps
    # share one set of conv weights — confirm this is intended.
    self.gcns = torch.nn.ModuleList(
        [self.gcn for _ in range(self.seq_len)])
    self.lstm = torch.nn.LSTM(self.input_size, self.hidden_size,
                              self.num_layers, batch_first=True)
    self.fc = torch.nn.Linear(self.hidden_size, 2)

    # One learnable edge-weight tensor per step, all initialised to 1.
    # They live in a ParameterList so that every one of them is registered:
    # previously they sat in a plain list while self.edge_weight was
    # overwritten each iteration, leaving all but the last invisible to
    # parameters(), state_dict() and .to(device).
    num_edges = self.C * self.C - self.C
    self.weights = torch.nn.ParameterList([
        torch.nn.Parameter(torch.FloatTensor(num_edges, 1).fill_(1),
                           requires_grad=True)
        for _ in range(self.seq_len)
    ])
    # Kept for compatibility: as before, edge_weight is the last step's weight.
    self.edge_weight = self.weights[-1]
def __init__(self, k=2, layers=3, graph_w=128, lin_ws=(32, 8)):
    """Build the frozen filtering stage, the conv stack and the MLP head.

    Sub-modules are registered under numbered attribute names (``conv0``,
    ``bn0``, ``lin0``, ...) and are also collected in plain Python lists for
    iteration.

    Args:
        k: Third positional argument passed to every TAGConv.
        layers: Number of TAGConv + BatchNorm1d pairs.
        graph_w: Output width of every TAGConv.
        lin_ws: Output widths of the MLP layers, in order.
    """
    super(Net, self).__init__()

    # Stage described by the original as removing inactive nodes; both of
    # its weights are frozen.
    self.conv = CustomGCN(2, 1, cached=False)
    self.conv.weight.requires_grad = False
    self.topk = CustomTopK(1, min_score=0.1)
    self.topk.weight.requires_grad = False

    # Convolutions: TAGConv + BatchNorm1d, all graph_w wide.
    prev_w = 2
    self.convs = []
    self.bns = []
    for i in range(layers):
        conv = TAGConv(prev_w, graph_w, k)
        bn = BatchNorm1d(graph_w)
        setattr(self, "conv{}".format(i), conv)
        setattr(self, "bn{}".format(i), bn)
        self.convs.append(conv)
        self.bns.append(bn)
        prev_w = graph_w

    # MLP: its input width is the conv width times the number of conv layers.
    prev_w *= layers
    self.lins = []
    for i, w in enumerate(lin_ws):
        lin = Linear(prev_w, w)
        setattr(self, "lin{}".format(i), lin)
        self.lins.append(lin)
        prev_w = w

    # Final layer with three outputs.
    self.lin_final = Linear(prev_w, 3)
def __init__(self, in_channels, out_channels, aggr_type, conv_type):
    """Create the convolution selected by ``conv_type`` and set its ``aggr``.

    Args:
        in_channels: Input width of the convolution.
        out_channels: Output width of the convolution.
        aggr_type: Assigned to the convolution's ``aggr`` attribute.
        conv_type: One of 'gcn', 'sage', 'cheb', 'tag', 'arma', 'gin',
            'appnp', or a string starting with 'gat' whose characters from
            index 4 onward parse as the integer number of heads.
    """
    super(GNNLayer, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.conv_type = conv_type

    # Lazy constructors for the exact-match types; only the selected one runs.
    builders = {
        'gcn': lambda: GCNConv(in_channels, out_channels),
        'sage': lambda: SAGEConv(in_channels, out_channels),
        'cheb': lambda: ChebConv(in_channels, out_channels, K=2),
        'tag': lambda: TAGConv(in_channels, out_channels),
        'arma': lambda: ARMAConv(in_channels, out_channels),
        'gin': lambda: GINConv(nn.Sequential(
            nn.Linear(in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels))),
        'appnp': lambda: LinearConv(in_channels, out_channels),
    }

    if conv_type.startswith('gat'):
        head_count = int(conv_type[4:])
        self.conv = GATConv(in_channels, out_channels, heads=head_count,
                            concat=False)
    elif conv_type in builders:
        self.conv = builders[conv_type]()

    # NOTE(review): applied for every conv_type (an unrecognised conv_type
    # leaves self.conv unset, so this line raises) — confirm that the
    # assignment was not meant for the 'appnp' branch only.
    self.conv.aggr = aggr_type
class TAG_Net(torch.nn.Module):
    """Two-layer TAGConv network that returns per-row log-softmax scores."""

    def __init__(self, features_num, num_class, hidden, dropout):
        """Create the two TAGConv layers.

        Args:
            features_num: Input width of the first TAGConv.
            num_class: Output width of the second TAGConv.
            hidden: Width between the two layers.
            dropout: Dropout probability applied between the layers.
        """
        super(TAG_Net, self).__init__()
        self.dropout = dropout
        self.conv1 = TAGConv(features_num, hidden)
        self.conv2 = TAGConv(hidden, num_class)

    def reset_parameters(self):
        """Reset the parameters of both convolution layers."""
        for layer in (self.conv1, self.conv2):
            layer.reset_parameters()

    def forward(self, data):
        """Run conv1 -> ReLU -> dropout -> conv2 -> log-softmax over dim 1.

        Args:
            data: Object exposing ``x`` and ``edge_index`` attributes.

        Returns:
            The log-softmax (over dim 1) of the second layer's output.
        """
        features, edges = data.x, data.edge_index
        hidden = self.conv1(features, edges)
        hidden = F.relu(hidden)
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        logits = self.conv2(hidden, edges)
        return F.log_softmax(logits, dim=1)
def test_tag_conv():
    """Check TAGConv shapes and eager/scripted agreement, with and without
    ``normalize=False``."""
    in_channels, out_channels = 16, 32
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    edge_weight = torch.rand(edge_index.size(1))
    num_nodes = edge_index.max().item() + 1
    x = torch.randn((num_nodes, in_channels))
    expected_shape = (num_nodes, out_channels)

    # Default layer: eager outputs first, then the scripted copy must match.
    conv = TAGConv(in_channels, out_channels)
    assert repr(conv) == 'TAGConv(16, 32, K=3)'
    out_plain = conv(x, edge_index)
    assert out_plain.size() == expected_shape
    out_weighted = conv(x, edge_index, edge_weight)
    assert out_weighted.size() == expected_shape

    scripted = torch.jit.script(conv.jittable(x=x, edge_index=edge_index))
    assert scripted(x, edge_index).tolist() == out_plain.tolist()
    assert scripted(x, edge_index, edge_weight).tolist() == \
        out_weighted.tolist()

    # Same checks with normalize=False.
    conv = TAGConv(in_channels, out_channels, normalize=False)
    out_raw = conv(x, edge_index, edge_weight)
    assert out_raw.size() == expected_shape

    scripted = torch.jit.script(
        conv.jittable(x=x, edge_index=edge_index, edge_weight=edge_weight))
    assert scripted(x, edge_index, edge_weight).tolist() == out_raw.tolist()
def __init__(self, n_features, n_labels, classification=False, width=64,
             conv_depth=7, point_depth=1, lin_depth=7, aggr='max'):
    """Build the EnsembleNet: a TAGConv branch, an EdgeConv branch, a dense stack.

    Args:
        n_features: Input width of the first TAGConv; the first EdgeConv's
            inner network takes ``2 * n_features`` inputs.
        n_labels: Output width of both output layers.
        classification: Stored on ``self.classification``.
        width: Output width of every TAGConv / EdgeConv.
        conv_depth: Total number of TAGConv layers.
        point_depth: Total number of EdgeConv layers.
        lin_depth: Number of (Linear, Dropout) pairs in the dense stack.
        aggr: Second argument passed to every EdgeConv.
    """
    super(EnsembleNet, self).__init__()
    self.classification = classification
    self.n_features = n_features
    self.n_labels = n_labels
    self.lin_depth = lin_depth
    self.conv_depth = conv_depth
    self.width = width
    self.point_depth = point_depth
    self.aggr = aggr

    hidden = self.width

    # TAGConv branch.
    self.conv1 = TAGConv(self.n_features, hidden, 2)
    extra_convs = [TAGConv(hidden, hidden, 2)
                   for _ in range(self.conv_depth - 1)]
    self.convfkt = torch.nn.ModuleList(extra_convs)

    # EdgeConv branch.
    self.point1 = EdgeConv(LNN([2 * n_features, hidden, hidden]), self.aggr)
    extra_points = [EdgeConv(LNN([2 * hidden, hidden]), self.aggr)
                    for _ in range(self.point_depth - 1)]
    self.pointfkt = torch.nn.ModuleList(extra_points)

    # Dense stack width: two slots per layer of either branch.
    dense_width = (2 * self.conv_depth * hidden
                   + 2 * self.point_depth * hidden)
    self.dim2 = dense_width
    self.batchnorm1 = BatchNorm1d(dense_width)
    self.linearfkt = torch.nn.ModuleList(
        [torch.nn.Linear(dense_width, dense_width)
         for _ in range(self.lin_depth)])
    self.drop = torch.nn.ModuleList(
        [torch.nn.Dropout(.3) for _ in range(self.lin_depth)])

    self.out = torch.nn.Linear(dense_width, self.n_labels)
    self.out2 = torch.nn.Linear(self.n_labels, self.n_labels)
def __init__(self, num_features, channels=64):
    """Create the 'prior' and 'learn' TAGConv branches, a head and learnt edges.

    Args:
        num_features: Input width of the first TAGConv in each branch.
        channels: Stored as ``self.num_channels``; sets the edge-weight
            length (``channels * channels - channels``) and is passed to
            ``self.gen_edges_cg``.
    """
    super(TAGMerge, self).__init__()
    self.num_channels = channels

    # Both branches use the width plan num_features -> 4 -> 8.
    mid_width, out_width = 4, 8
    self.conv_prior_1 = TAGConv(num_features, mid_width)
    self.conv_prior_2 = TAGConv(mid_width, out_width)
    self.conv_learn_1 = TAGConv(num_features, mid_width)
    self.conv_learn_2 = TAGConv(mid_width, out_width)
    self.fc = torch.nn.Linear(out_width, 2)

    # Learnable per-edge weights, all initialised to 1.
    edge_count = channels * (channels - 1)
    initial = torch.FloatTensor(edge_count, 1).fill_(1)
    self.edge_weight_learn = torch.nn.Parameter(initial, requires_grad=True)

    # Edge index for the learnt branch: the array returned by gen_edges_cg,
    # converted to a long tensor (stored as a plain attribute).
    generated_edges = self.gen_edges_cg(self.num_channels)
    self.edge_index_learn = torch.from_numpy(generated_edges).long()
def __init__(self, num_classes=36, k=0, device="cuda:0"):
    """Create one TAGConv, three linear layers and the tactile graph helper.

    Layer sizes come from the module-level ``cfg_cnn``, ``cfg_s`` and
    ``cfg_fc`` tables.

    Args:
        num_classes: Output width of the last linear layer.
        k: Passed to ``TactileGraph``.
        device: Stored on ``self.device``.
    """
    super(TactileSGNet, self).__init__()

    conv_in, conv_out = cfg_cnn[0]
    self.conv1 = TAGConv(conv_in, conv_out, K=3)

    # First dense layer's input width: last cfg_s entry times the output
    # size of the last cfg_cnn entry.
    flat_width = cfg_s[-1] * cfg_cnn[-1][1]
    self.fc1 = nn.Linear(flat_width, cfg_fc[0])
    self.fc2 = nn.Linear(cfg_fc[0], cfg_fc[1])
    self.fc3 = nn.Linear(cfg_fc[1], num_classes)

    self.num_classes = num_classes
    self.graph = TactileGraph(k)
    self.device = device
def __init__(self, n_features, n_labels, classification=False):
    """Build the fixed-size ConvNet: three TAGConv layers and five linear layers.

    Args:
        n_features: Input width of the first TAGConv.
        n_labels: Output width of both output layers.
        classification: Stored on ``self.classification``.
    """
    super(ConvNet, self).__init__()
    self.classification = classification
    self.n_features = n_features
    self.n_labels = n_labels

    n_intermediate = 128
    # Width of the dense stack: six times the conv width.
    n_intermediate2 = 6 * n_intermediate

    self.conv1 = TAGConv(self.n_features, n_intermediate, 2)
    self.conv2 = TAGConv(n_intermediate, n_intermediate, 2)
    self.conv3 = TAGConv(n_intermediate, n_intermediate, 2)

    self.batchnorm1 = BatchNorm1d(n_intermediate2)
    self.linear1 = torch.nn.Linear(n_intermediate2, n_intermediate2)
    self.linear2 = torch.nn.Linear(n_intermediate2, n_intermediate2)
    self.linear3 = torch.nn.Linear(n_intermediate2, n_intermediate2)
    self.linear4 = torch.nn.Linear(n_intermediate2, n_intermediate2)
    self.linear5 = torch.nn.Linear(n_intermediate2, n_intermediate2)
    self.drop = torch.nn.Dropout(.3)

    self.out = torch.nn.Linear(n_intermediate2, self.n_labels)
    self.out2 = torch.nn.Linear(self.n_labels, self.n_labels)
def init_model(self, n_class, feature_num):
    """Create the layers, move the model to CUDA and build the optimizer.

    Reads ``num_layers``, ``hidden`` (as a power of two), ``lr``, ``K`` and
    ``use_linear`` from ``self.hyperparameters``.

    With ``use_linear`` set, the model is a linear input projection,
    ``num_layers`` hidden TAGConv layers and a linear output layer.
    Otherwise it is built from TAGConv layers only: a single
    ``feature_num -> n_class`` layer when ``num_layers == 1``, else
    ``conv1``, ``num_layers - 2`` hidden layers and ``conv2``.

    Args:
        n_class: Output width of the model.
        feature_num: Input width of the model.
    """
    hparams = self.hyperparameters
    num_layers = int(hparams['num_layers'])
    hidden_size = int(2 ** hparams['hidden'])
    lr = hparams['lr']
    K = int(hparams['K'])

    if hparams['use_linear']:
        self.input_lin = Linear(feature_num, hidden_size)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(
                TAGConv(in_channels=hidden_size, out_channels=hidden_size,
                        K=K))
        self.output_lin = Linear(hidden_size, n_class)
    elif num_layers == 1:
        self.conv1 = TAGConv(in_channels=feature_num, out_channels=n_class,
                             K=K)
    else:
        self.conv1 = TAGConv(in_channels=feature_num,
                             out_channels=hidden_size, K=K)
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 2):
            self.convs.append(
                TAGConv(in_channels=hidden_size, out_channels=hidden_size,
                        K=K))
        self.conv2 = TAGConv(in_channels=hidden_size, out_channels=n_class,
                             K=K)

    # Move the parameters to the GPU before handing them to the optimizer,
    # as torch.optim recommends. Module.to() works in place, so rebinding
    # `self` is unnecessary.
    self.to('cuda')
    self.optimizer = torch.optim.Adam(self.parameters(), lr=lr,
                                      weight_decay=5e-4)
    torch.cuda.empty_cache()
class TAGNet(nn.Module):
    """TAGConv network with optional edge weights and log-softmax output."""

    def __init__(self, num_feature, num_class, num_layers=2, k=3, hidden=64,
                 drop=0.5, use_edge_weight=True):
        """Create the three TAGConv layers.

        Args:
            num_feature: Input width of ``conv0``.
            num_class: Output width of ``conv2``.
            num_layers: Stored on ``self.n_layer``; ``forward`` applies
                ``num_layers - 1`` hidden convolutions before ``conv2``.
            k: Forwarded to every TAGConv as ``K``.
            hidden: Hidden width.
            drop: Dropout probability used in ``forward``.
            use_edge_weight: Whether ``forward`` passes edge weights to the
                convolutions.
        """
        super(TAGNet, self).__init__()
        self.conv0 = TAGConv(num_feature, hidden, K=k)
        self.conv1 = TAGConv(hidden, hidden, K=k)
        self.conv2 = TAGConv(hidden, num_class, K=k)
        self.n_layer = num_layers
        self.use_edge_weight = use_edge_weight
        self.drop = drop

    def reset_parameters(self):
        """Reset the parameters of all three convolution layers."""
        self.conv0.reset_parameters()
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, data):
        """Apply the hidden convolutions, then ``conv2`` and log-softmax.

        Args:
            data: Object exposing ``x`` and ``edge_index``; ``edge_attr`` is
                read (and squeezed on dim 1) only when ``use_edge_weight``
                is true.

        Returns:
            The log-softmax (over dim 1) of ``conv2``'s output.
        """
        x, edge_index = data.x, data.edge_index
        # Read edge_attr only when it is used, so graphs without edge
        # attributes work with use_edge_weight=False.
        if self.use_edge_weight:
            conv_args = (edge_index, data.edge_attr.squeeze(1))
        else:
            conv_args = (edge_index,)

        for i in range(self.n_layer - 1):
            # The first hidden step uses conv0; every later one reuses conv1.
            conv = self.conv0 if i == 0 else self.conv1
            x = conv(x, *conv_args)
            x = F.relu(x)
            x = F.dropout(x, p=self.drop, training=self.training)
        x = self.conv2(x, *conv_args)
        return F.log_softmax(x, dim=1)
def __init__(self, categories_nums, features_num=16, num_class=2,
             sparse=False, degree_mean=2):
    """Build the embeddings, one TAGConv layer and the linear head.

    Args:
        categories_nums: Sequence of category counts. The first entry sizes
            the id embedding; each remaining entry gets its own embedding.
        features_num: Width of the dense features; when 0 no dense
            projection is created.
        num_class: Output width of the final linear layer.
        sparse: Selects how the TAGConv ``K`` is chosen (see below).
        degree_mean: Used to derive ``K`` when ``sparse`` is true and
            ``features_num`` is 0.
    """
    super(TAGC, self).__init__()
    hidden = 32
    embed_size = 8
    id_embed_size = 16
    self.dropout_p = 0.1

    # Id embedding plus one embedding per additional categorical column.
    self.id_embedding = Embedding(categories_nums[0], id_embed_size)
    self.lin0_id_emb = Linear(id_embed_size, id_embed_size)
    self.embeddings = torch.nn.ModuleList()
    for category_count in categories_nums[1:]:
        self.embeddings.append(Embedding(category_count, embed_size))
    n = max(0, len(categories_nums) - 1)
    if n > 0:
        self.lin0_emb = Linear(embed_size * n, embed_size * n)

    # K is 3 by default; in sparse mode it is 6 when dense features exist,
    # otherwise derived from degree_mean with a floor of 7.
    if not sparse:
        K = 3
    elif features_num == 0:
        K = max(7, int(np.exp(-(degree_mean - 1) / 1.5) * 100))
    else:
        K = 6
    LOGGER.info(f'K values:{K}')

    # The conv input is the concatenated embedding width, plus `hidden`
    # when dense features are projected through lin0.
    conv_in = id_embed_size + embed_size * n
    if features_num > 0:
        self.lin0 = Linear(features_num, hidden)
        conv_in = conv_in + hidden
    self.ln0 = torch.nn.LayerNorm(conv_in)
    self.conv1 = TAGConv(conv_in, hidden, K=K)

    self.ln1 = torch.nn.LayerNorm(hidden)
    self.lin1 = Linear(hidden, num_class)
def __init__(self, num_nodes, embed_dim, gnn_in_dim, gnn_hidden_dim,
             gnn_out_dim, gnn_num_layers, mlp_in_dim, mlp_hidden_dim,
             mlp_out_dim=1, mlp_num_layers=2, dropout=0.5,
             gnn_batchnorm=False, mlp_batchnorm=False, K=2, jk_mode='max'):
    """Build the embedding, TAGConv stack, MLP, optional batch norms and JK.

    Args:
        num_nodes: Number of rows in the node embedding table.
        embed_dim: Width of the node embedding.
        gnn_in_dim: Input width of the first TAGConv.
        gnn_hidden_dim: Hidden TAGConv width.
        gnn_out_dim: Output width of the last TAGConv.
        gnn_num_layers: The stack holds a first and a last TAGConv plus
            ``gnn_num_layers - 2`` hidden ones.
        mlp_in_dim: Input width of the first linear layer.
        mlp_hidden_dim: Hidden linear width.
        mlp_out_dim: Output width of the last linear layer.
        mlp_num_layers: The MLP holds a first and a last layer plus
            ``mlp_num_layers - 2`` hidden ones.
        dropout: Stored on ``self.dropout``.
        gnn_batchnorm: Whether to create ``self.gnn_bns``.
        mlp_batchnorm: Whether to create ``self.mlp_bns``.
        K: Third positional argument passed to every TAGConv.
        jk_mode: One of 'max', 'sum', 'mean', 'lstm', 'cat'; a
            JumpingKnowledge module is created for 'max', 'lstm' and 'cat'.
    """
    super(DEA_GNN_JK, self).__init__()
    assert jk_mode in ['max','sum','mean','lstm','cat']

    # Embedding
    self.emb = torch.nn.Embedding(num_nodes, embedding_dim=embed_dim)

    # GNN: in -> hidden, hidden -> hidden (repeated), hidden -> out.
    gnn_stack = (
        [TAGConv(gnn_in_dim, gnn_hidden_dim, K)]
        + [TAGConv(gnn_hidden_dim, gnn_hidden_dim, K)
           for _ in range(gnn_num_layers - 2)]
        + [TAGConv(gnn_hidden_dim, gnn_out_dim, K)]
    )
    self.convs = torch.nn.ModuleList(gnn_stack)

    # MLP: in -> hidden, hidden -> hidden (repeated), hidden -> out.
    mlp_stack = (
        [torch.nn.Linear(mlp_in_dim, mlp_hidden_dim)]
        + [torch.nn.Linear(mlp_hidden_dim, mlp_hidden_dim)
           for _ in range(mlp_num_layers - 2)]
        + [torch.nn.Linear(mlp_hidden_dim, mlp_out_dim)]
    )
    self.lins = torch.nn.ModuleList(mlp_stack)

    # Optional batch norms.
    self.gnn_batchnorm = gnn_batchnorm
    self.mlp_batchnorm = mlp_batchnorm
    if self.gnn_batchnorm:
        self.gnn_bns = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(gnn_hidden_dim)
             for _ in range(gnn_num_layers)])
    if self.mlp_batchnorm:
        self.mlp_bns = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(mlp_hidden_dim)
             for _ in range(mlp_num_layers - 1)])

    # No JumpingKnowledge module is created for 'sum' and 'mean'.
    self.jk_mode = jk_mode
    if self.jk_mode in ['max', 'lstm', 'cat']:
        self.jk = JumpingKnowledge(mode=self.jk_mode,
                                   channels=gnn_hidden_dim,
                                   num_layers=gnn_num_layers)

    self.dropout = dropout
    self.loss_fn = torch.nn.BCEWithLogitsLoss()
    self.reset_parameters()
def __init__(self, num_feature, num_class, num_layers=2, k=3, hidden=64,
             drop=0.5, use_edge_weight=True):
    """Create the three TAGConv layers of TAGNet and store its settings.

    Args:
        num_feature: Input width of ``conv0``.
        num_class: Output width of ``conv2``.
        num_layers: Stored on ``self.n_layer``.
        k: Forwarded to every TAGConv as ``K``.
        hidden: Hidden width.
        drop: Stored on ``self.drop``.
        use_edge_weight: Stored on ``self.use_edge_weight``.
    """
    super(TAGNet, self).__init__()

    # Width plan num_feature -> hidden -> hidden -> num_class,
    # registered as conv0, conv1, conv2.
    widths = (num_feature, hidden, hidden, num_class)
    for idx in range(len(widths) - 1):
        setattr(self, "conv{}".format(idx),
                TAGConv(widths[idx], widths[idx + 1], K=k))

    self.n_layer = num_layers
    self.use_edge_weight = use_edge_weight
    self.drop = drop