def __init__(self,
             in_feat,
             out_feat,
             rel_names,
             num_bases,
             *,
             weight=True,
             bias=True,
             activation=None,
             self_loop=False,
             dropout=0.0):
    """Build a relational graph convolution layer for heterogeneous graphs.

    One ``dglnn.GraphConv`` per relation (``norm='right'``, with its own
    weight and bias disabled) is wrapped in a ``dglnn.HeteroGraphConv``.
    The learnable relation weights are owned by this layer instead.
    NOTE(review): presumably they are fed to the per-relation convolutions
    in ``forward``, which is not visible here.

    Args:
        in_feat: Input feature size.
        out_feat: Output feature size.
        rel_names: Relation (edge type) names; one sub-convolution is
            created per name.
        num_bases: Number of bases for basis decomposition of the relation
            weights. Basis decomposition is used only when
            ``0 < num_bases < len(rel_names)`` and ``weight`` is true.
            A non-positive value means "one full weight per relation".
        weight: If true, allocate learnable relation weights, either as
            ``self.basis`` (basis decomposition) or as ``self.weight``
            (one ``(in_feat, out_feat)`` matrix per relation).
        bias: If true, allocate a zero-initialized ``self.h_bias`` of size
            ``out_feat``.
        activation: Stored unchanged as ``self.activation``.
        self_loop: If true, allocate a ``(in_feat, out_feat)`` self-loop
            weight ``self.loop_weight``.
        dropout: Probability passed to the ``nn.Dropout`` in
            ``self.dropout``.
    """
    super().__init__()
    num_rels = len(rel_names)
    # A non-positive basis count still passes the `num_bases < num_rels`
    # test below, and so would be handed to WeightBasis as an invalid basis
    # size. Follow the DGL RelGraphConv convention and treat it as
    # "no decomposition, one full weight per relation".
    if num_bases <= 0:
        num_bases = num_rels

    self.in_feat = in_feat
    self.out_feat = out_feat
    self.rel_names = rel_names
    self.num_bases = num_bases
    self.bias = bias  # boolean flag; the parameter itself is `h_bias`
    self.activation = activation
    self.self_loop = self_loop

    # Sub-convolutions carry no weight/bias of their own; see class note.
    self.conv = dglnn.HeteroGraphConv({
        rel: dglnn.GraphConv(in_feat, out_feat, norm='right',
                             weight=False, bias=False)
        for rel in rel_names
    })

    self.use_weight = weight
    # Basis decomposition only pays off with fewer bases than relations.
    self.use_basis = num_bases < num_rels and weight
    if self.use_weight:
        if self.use_basis:
            self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases,
                                           num_rels)
        else:
            self.weight = nn.Parameter(
                torch.Tensor(num_rels, in_feat, out_feat))
            nn.init.xavier_uniform_(self.weight,
                                    gain=nn.init.calculate_gain('relu'))

    # bias
    if bias:
        self.h_bias = nn.Parameter(torch.Tensor(out_feat))
        nn.init.zeros_(self.h_bias)

    # weight for self loop
    if self.self_loop:
        self.loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat))
        nn.init.xavier_uniform_(self.loop_weight,
                                gain=nn.init.calculate_gain('relu'))

    self.dropout = nn.Dropout(dropout)
def __init__(
    self,
    in_feats: int,
    out_feats: int,
    rel_names: List[str],
    num_bases: int,
    norm: str = 'right',
    weight: bool = True,
    bias: bool = True,
    # NOTE(review): `activation` and `dropout` default to None, so their
    # annotations should be Optional[...]; left as-is because this file's
    # typing imports are not visible here.
    activation: Callable[[torch.Tensor], torch.Tensor] = None,
    dropout: float = None,
    self_loop: bool = False,
):
    """Build a relational graph convolution over heterogeneous graphs.

    One ``dglnn.GraphConv`` per relation (using ``norm``, with its own
    weight and bias disabled) is wrapped in a ``dglnn.HeteroGraphConv``.
    The learnable relation weights live on this module instead.
    NOTE(review): presumably they are passed to the sub-convolutions in
    ``forward``, which is not visible here.

    Args:
        in_feats: Input feature size.
        out_feats: Output feature size.
        rel_names: Relation (edge type) names; one sub-convolution each.
        num_bases: Number of bases for basis decomposition. Decomposition
            is used only when ``0 < num_bases < len(rel_names)`` and
            ``weight`` is true. A non-positive value means "one full
            weight per relation".
        norm: Normalization mode forwarded to every ``dglnn.GraphConv``.
        weight: If true, allocate learnable relation weights, either as
            ``self.basis`` or as ``self.weight`` of shape
            ``(num_rels, in_feats, out_feats)``.
        bias: If true, allocate a zero-initialized ``self.bias`` of size
            ``out_feats``.
        activation: Stored unchanged as ``self._activation``.
        dropout: Dropout probability. ``None`` stores ``self._dropout`` as
            ``None`` instead of an ``nn.Dropout`` module.
        self_loop: If true, allocate a ``(in_feats, out_feats)`` weight
            ``self.self_loop_weight``.
    """
    super().__init__()
    self._rel_names = rel_names
    self._num_rels = len(rel_names)
    # A non-positive basis count still passes `num_bases < num_rels` and
    # would be handed to WeightBasis as an invalid basis size. Follow the
    # DGL RelGraphConv convention: fall back to one weight per relation.
    if num_bases <= 0:
        num_bases = self._num_rels

    self._conv = dglnn.HeteroGraphConv({
        rel: dglnn.GraphConv(in_feats, out_feats, norm=norm,
                             weight=False, bias=False)
        for rel in rel_names
    })

    self._use_weight = weight
    # Basis decomposition only pays off with fewer bases than relations.
    self._use_basis = num_bases < self._num_rels and weight
    self._use_bias = bias
    self._activation = activation
    self._dropout = nn.Dropout(dropout) if dropout is not None else None
    self._use_self_loop = self_loop

    if weight:
        if self._use_basis:
            self.basis = dglnn.WeightBasis((in_feats, out_feats), num_bases,
                                           self._num_rels)
        else:
            self.weight = nn.Parameter(
                torch.Tensor(self._num_rels, in_feats, out_feats))
            nn.init.xavier_uniform_(self.weight,
                                    gain=nn.init.calculate_gain('relu'))

    if bias:
        self.bias = nn.Parameter(torch.Tensor(out_feats))
        nn.init.zeros_(self.bias)

    if self_loop:
        self.self_loop_weight = nn.Parameter(
            torch.Tensor(in_feats, out_feats))
        nn.init.xavier_uniform_(self.self_loop_weight,
                                gain=nn.init.calculate_gain('relu'))