def __init__(self, in_feat, out_feat, rel_names, num_bases, *,
             weight=True, bias=True, activation=None, self_loop=False,
             dropout=0.0):
    """Build a relational graph-convolution layer for a heterograph.

    A weight-less, bias-less ``dglnn.GraphConv`` (``norm='right'``) is
    created for every relation and grouped in a ``dglnn.HeteroGraphConv``.
    The relation weights live on this layer instead. They are either one
    dense ``(len(rel_names), in_feat, out_feat)`` parameter or, when
    ``num_bases < len(rel_names)``, a ``dglnn.WeightBasis``.

    Args:
        in_feat: Input feature size.
        out_feat: Output feature size.
        rel_names: Relation names; one GraphConv is built per entry.
        num_bases: Number of bases. Basis decomposition is used only when
            this is smaller than ``len(rel_names)`` and ``weight`` is truthy.
        weight: Whether relation weights are created at all.
        bias: Whether to create a zero-initialised ``h_bias`` of size
            ``out_feat``.
        activation: Stored unchanged on ``self.activation``.
        self_loop: Whether to create an ``(in_feat, out_feat)``
            ``loop_weight`` parameter.
        dropout: Probability passed to ``nn.Dropout``.
    """
    super(RelGraphConvLayer, self).__init__()

    # Plain (non-parameter) configuration.
    self.in_feat, self.out_feat = in_feat, out_feat
    self.rel_names = rel_names
    self.num_bases = num_bases
    self.bias = bias
    self.activation = activation
    self.self_loop = self_loop
    # Batch norm is hard-wired off; the matching branch further down only
    # runs if this flag is flipped by hand.
    self.batchnorm = False

    # Pure message passing per relation: the projection weights and bias
    # are owned by this layer, not by the individual GraphConv modules.
    per_relation = {}
    for rel in rel_names:
        per_relation[rel] = dglnn.GraphConv(
            in_feat, out_feat, norm='right', weight=False, bias=False)
    self.conv = dglnn.HeteroGraphConv(per_relation)

    num_rels = len(self.rel_names)
    relu_gain = nn.init.calculate_gain('relu')

    self.use_weight = weight
    self.use_basis = num_bases < num_rels and weight
    if self.use_weight and self.use_basis:
        # Fewer bases than relations: share weights via basis decomposition.
        self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, num_rels)
    elif self.use_weight:
        # One dense weight matrix per relation.
        self.weight = nn.Parameter(th.Tensor(num_rels, in_feat, out_feat))
        nn.init.xavier_uniform_(self.weight, gain=relu_gain)

    # Single output bias shared by all relations.
    if bias:
        self.h_bias = nn.Parameter(th.Tensor(out_feat))
        nn.init.zeros_(self.h_bias)

    # Extra projection for each node's own features.
    if self.self_loop:
        self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
        nn.init.xavier_uniform_(self.loop_weight, gain=relu_gain)

    if self.batchnorm:
        self.bn = nn.BatchNorm1d(out_feat)

    self.dropout = nn.Dropout(dropout)
def __init__(self, in_feat, out_feat, rel_names, ntype_names, num_bases, *,
             weight=True, bias=True, activation=None, self_loop=False,
             dropout=0.0, use_gcn_checkpoint=False, **kwargs):
    """Build a relational graph-convolution layer with per-node-type biases.

    Relation weights are held on this layer. They are either one dense
    ``(len(rel_names), in_feat, out_feat)`` parameter (Xavier-normal init)
    or, when ``num_bases < len(rel_names)``, a ``dglnn.WeightBasis``. The
    per-relation convolution modules are created by ``self.create_conv``.

    Args:
        in_feat: Input feature size.
        out_feat: Output feature size.
        rel_names: Relation names; passed to ``create_conv`` and used to
            size the relation weights.
        ntype_names: Node-type names; with ``bias=True`` one
            ``(1, out_feat)`` bias parameter is created per entry.
        num_bases: Number of bases. Basis decomposition is used only when
            this is smaller than ``len(rel_names)`` and ``weight`` is truthy.
        weight: Whether relation weights are created at all.
        bias: Whether to create ``self.bias_dict`` (an ``nn.ParameterDict``
            keyed by node-type name).
        activation: Stored unchanged on ``self.activation``.
        self_loop: Whether to create a single ``(in_feat, out_feat)``
            ``loop_weight`` shared by all node types.
        dropout: Probability passed to ``nn.Dropout``.
        use_gcn_checkpoint: Stored on ``self.use_gcn_checkpoint``.
            NOTE(review): presumably read by ``forward``/``create_conv`` to
            enable gradient checkpointing; confirm.
        **kwargs: Accepted for call-site compatibility and ignored.
    """
    super(RelGraphConvLayer, self).__init__()
    self.in_feat = in_feat
    self.out_feat = out_feat
    self.rel_names = rel_names
    self.ntype_names = ntype_names
    self.num_bases = num_bases
    self.bias = bias
    self.activation = activation
    self.self_loop = self_loop
    self.use_gcn_checkpoint = use_gcn_checkpoint

    self.use_weight = weight
    self.use_basis = num_bases < len(self.rel_names) and weight
    if self.use_weight:
        if self.use_basis:
            # Fewer bases than relations: share weights via basis decomposition.
            self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases,
                                           len(self.rel_names))
        else:
            self.weight = nn.Parameter(
                th.Tensor(len(self.rel_names), in_feat, out_feat))
            nn.init.xavier_normal_(self.weight)

    # TODO: consider dglnn.GATConv (e.g. num_heads=4) as an alternative to
    # the per-relation GraphConv built inside create_conv.
    self.create_conv(in_feat, out_feat, rel_names)

    if bias:
        # One bias row per node type. Use the module's `th` alias: the bare
        # `torch` name is not guaranteed to be imported alongside it.
        self.bias_dict = nn.ParameterDict()
        for ntype_name in self.ntype_names:
            self.bias_dict[ntype_name] = nn.Parameter(th.Tensor(1, out_feat))
            # NOTE(review): nn.init.normal_ defaults to mean 0, std 1; a
            # unit-variance bias init is unusual, so confirm it is intended.
            nn.init.normal_(self.bias_dict[ntype_name])

    # A single self-loop projection shared across all node types.
    if self.self_loop:
        self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
        nn.init.xavier_uniform_(self.loop_weight,
                                gain=nn.init.calculate_gain('tanh'))

    self.dropout = nn.Dropout(dropout)