def __init__(self, args):
    """Create the two GCN layers from the fields of *args*.

    Reads ``num_nodes``, ``num_features``, ``num_classes`` and ``device``
    from *args*. The single hidden layer is 16 units wide, and *args*
    itself is kept as ``self.args``.
    """
    super(GCN, self).__init__()
    hidden_units = 16
    self.num_nodes = args.num_nodes
    # Construction order is kept as conv1, then conv2.
    self.conv1 = GCNConv(args.num_features, hidden_units)
    self.conv2 = GCNConv(hidden_units, args.num_classes)
    self.device, self.args = args.device, args
def __init__(self, dataset, hidden, num_feat_layers=1, num_conv_layers=3,
             num_fc_layers=2, gfn=False, collapse=False, residual=False,
             res_branch="BNConvReLU", global_pool="sum", dropout=0,
             edge_norm=True):
    """Build the layers of a residual GCN for graph-level prediction.

    Layers are constructed in a fixed order. ``Linear`` layers draw their
    initial weights from the global torch RNG, so reordering the
    construction below would change the initial parameters.

    Args:
        dataset: Graph dataset. ``dataset[0]`` is checked for an ``xg``
            (graph-level feature) attribute; ``dataset.num_features`` and
            ``dataset.num_classes`` size the input and output layers.
        hidden: Width of the hidden layers and of both projection heads.
        num_feat_layers: Only ``1`` is supported (asserted).
        num_conv_layers: Number of convolution stages. A ``"resnet"`` stage
            holds three convolutions; any other stage holds one.
        num_fc_layers: Number of fully connected layers, counting the final
            ``lin_class`` classifier.
        gfn: Passed to every convolution built through ``GConv``.
        collapse: If true, no graph convolutions are built; the model is
            only batch norm plus fully connected layers on the raw node
            features.
        residual: Stored as ``self.conv_residual``.
        res_branch: ``"resnet"`` selects the three-convolution stage.
        global_pool: Must contain ``"sum"`` (-> ``global_add_pool``) or
            ``"mean"`` (-> ``global_mean_pool``); ``"sum"`` wins if both
            appear. If it also contains ``"gating"``, a sigmoid gating MLP
            is built as ``self.gating``; otherwise ``self.gating`` is None.
        dropout: Stored as ``self.dropout``.
        edge_norm: Passed to every convolution built through ``GConv``.
    """
    super(ResGCN, self).__init__()
    assert num_feat_layers == 1, "more feat layers are not now supported"
    self.conv_residual = residual
    self.fc_residual = False  # no skip-connections for fc layers.
    self.res_branch = res_branch
    self.collapse = collapse
    assert "sum" in global_pool or "mean" in global_pool, global_pool
    if "sum" in global_pool:
        self.global_pool = global_add_pool
    else:
        self.global_pool = global_mean_pool
    self.dropout = dropout
    use_gating = "gating" in global_pool
    # Hidden-to-hidden graph convolutions share these keyword arguments.
    GConv = partial(GCNConv, edge_norm=edge_norm, gfn=gfn)

    if "xg" in dataset[0]:  # Utilize graph level features.
        self.use_xg = True
        self.bn1_xg = BatchNorm1d(dataset[0].xg.size(1))
        self.lin1_xg = Linear(dataset[0].xg.size(1), hidden)
        self.bn2_xg = BatchNorm1d(hidden)
        self.lin2_xg = Linear(hidden, hidden)
    else:
        self.use_xg = False

    hidden_in = dataset.num_features
    if collapse:
        # No convolutions: batch norm + fully connected layers only.
        self.bn_feat = BatchNorm1d(hidden_in)
        self.bns_fc = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()
        if use_gating:
            self.gating = torch.nn.Sequential(
                Linear(hidden_in, hidden_in), torch.nn.ReLU(),
                Linear(hidden_in, 1), torch.nn.Sigmoid())
        else:
            self.gating = None
        for _ in range(num_fc_layers - 1):
            self.bns_fc.append(BatchNorm1d(hidden_in))
            self.lins.append(Linear(hidden_in, hidden))
            hidden_in = hidden  # only the first fc layer sees the raw width
        self.lin_class = Linear(hidden_in, dataset.num_classes)
    else:
        self.bn_feat = BatchNorm1d(hidden_in)
        feat_gfn = True  # set true so GCNConv is feat transform
        self.conv_feat = GCNConv(hidden_in, hidden, gfn=feat_gfn)
        if use_gating:
            self.gating = torch.nn.Sequential(
                Linear(hidden, hidden), torch.nn.ReLU(),
                Linear(hidden, 1), torch.nn.Sigmoid())
        else:
            self.gating = None
        self.bns_conv = torch.nn.ModuleList()
        self.convs = torch.nn.ModuleList()
        if self.res_branch == "resnet":
            # Each stage: feature-transform conv, graph conv, then another
            # feature-transform conv, each paired with its own batch norm.
            for _ in range(num_conv_layers):
                self.bns_conv.append(BatchNorm1d(hidden))
                self.convs.append(GCNConv(hidden, hidden, gfn=feat_gfn))
                self.bns_conv.append(BatchNorm1d(hidden))
                self.convs.append(GConv(hidden, hidden))
                self.bns_conv.append(BatchNorm1d(hidden))
                self.convs.append(GCNConv(hidden, hidden, gfn=feat_gfn))
        else:
            for _ in range(num_conv_layers):
                self.bns_conv.append(BatchNorm1d(hidden))
                self.convs.append(GConv(hidden, hidden))
        self.bn_hidden = BatchNorm1d(hidden)
        self.bns_fc = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()
        for _ in range(num_fc_layers - 1):
            self.bns_fc.append(BatchNorm1d(hidden))
            self.lins.append(Linear(hidden, hidden))
        self.lin_class = Linear(hidden, dataset.num_classes)

    # BN initialization: weight 1, bias 1e-4 for every BatchNorm1d above.
    for m in self.modules():
        if isinstance(m, torch.nn.BatchNorm1d):
            torch.nn.init.constant_(m.weight, 1)
            torch.nn.init.constant_(m.bias, 0.0001)

    # Projection heads. Their width follows `hidden` rather than a hard-coded
    # 128, so they match the hidden-width layers for any `hidden` (shapes are
    # unchanged when hidden == 128).
    # NOTE(review): assumes the heads consume a hidden-width embedding in
    # forward (not visible here). In collapse mode with num_fc_layers == 1
    # the pre-classifier width is dataset.num_features instead; confirm.
    self.proj_head1 = torch.nn.Sequential(
        Linear(hidden, hidden), torch.nn.ReLU(inplace=True),
        Linear(hidden, hidden))
    self.proj_head2 = torch.nn.Sequential(
        Linear(hidden, hidden), torch.nn.ReLU(inplace=True),
        Linear(hidden, hidden))