Example #1
import dgl.nn as dglnn
import torch as th
import torch.nn as nn


class RelGraphConvLayer(nn.Module):
    def __init__(self,
                 in_feat,
                 out_feat,
                 rel_names,
                 num_bases,
                 *,
                 weight=True,
                 bias=True,
                 activation=None,
                 self_loop=False,
                 dropout=0.0):
        super(RelGraphConvLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rel_names = rel_names
        self.num_bases = num_bases
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        self.batchnorm = False  # hard-coded off; set to True to enable self.bn below

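        # One GraphConv per relation, combined by HeteroGraphConv.
        # Per-relation weight/bias are disabled here; the layer supplies
        # its own (shared or basis-decomposed) relation weights below.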
        self.conv = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feat,
                                 out_feat,
                                 norm='right',
                                 weight=False,
                                 bias=False)
            for rel in rel_names
        })

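        # Relation-specific weights. When num_bases < number of relations,
        # each relation's weight matrix is a learned combination of
        # num_bases shared basis matrices (R-GCN basis decomposition),
        # which cuts parameters and regularizes rare relations.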
        self.use_weight = weight
        self.use_basis = num_bases < len(self.rel_names) and weight
        if self.use_weight:
            if self.use_basis:
                self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases,
                                               len(self.rel_names))
            else:
                self.weight = nn.Parameter(
                    th.Tensor(len(self.rel_names), in_feat, out_feat))
                nn.init.xavier_uniform_(self.weight,
                                        gain=nn.init.calculate_gain('relu'))

        # bias
        if bias:
            self.h_bias = nn.Parameter(th.Tensor(out_feat))
            nn.init.zeros_(self.h_bias)

        # weight for self loop
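        # (R-GCN adds a learned self-connection: each node's own features
        # are transformed by loop_weight and added to the aggregated
        # neighbor messages.)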
        if self.self_loop:
            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(self.loop_weight,
                                    gain=nn.init.calculate_gain('relu'))
        # define batch norm layer
        if self.batchnorm:
            self.bn = nn.BatchNorm1d(out_feat)

        self.dropout = nn.Dropout(dropout)
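
For reference, this constructor follows the pattern of DGL's R-GCN entity-classification example. A minimal usage sketch is below; the toy graph, feature sizes, and the final forward(g, feats) call are illustrative assumptions, since the layer's forward method is not shown in this example.

import dgl
import torch as th

# Toy heterograph with two relation types (illustrative only).
g = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 1], [1, 2]),
    ('user', 'plays', 'game'): ([0, 1], [0, 1]),
})

layer = RelGraphConvLayer(
    in_feat=16,
    out_feat=8,
    rel_names=list(g.etypes),
    num_bases=2,
    activation=th.relu,
    self_loop=True,
    dropout=0.1,
)

feats = {ntype: th.randn(g.num_nodes(ntype), 16) for ntype in g.ntypes}
# Assuming a forward(g, inputs) method as in DGL's R-GCN example,
# the call would return a dict of per-node-type output features:
# h = layer(g, feats)   # e.g. {'user': (3, 8), 'game': (2, 8)}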

Example #2

class RelGraphConvLayer(nn.Module):
    def __init__(self,
                 in_feat,
                 out_feat,
                 rel_names,
                 ntype_names,
                 num_bases,
                 *,
                 weight=True,
                 bias=True,
                 activation=None,
                 self_loop=False,
                 dropout=0.0,
                 use_gcn_checkpoint=False,
                 **kwargs):
        super(RelGraphConvLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rel_names = rel_names
        self.ntype_names = ntype_names
        self.num_bases = num_bases
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        self.use_gcn_checkpoint = use_gcn_checkpoint

        self.use_weight = weight
        self.use_basis = num_bases < len(self.rel_names) and weight
        if self.use_weight:
            if self.use_basis:
                self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases,
                                               len(self.rel_names))
            else:
                self.weight = nn.Parameter(
                    th.Tensor(len(self.rel_names), in_feat, out_feat))
                # nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
                nn.init.xavier_normal_(self.weight)

        # TODO: consider the possibility of switching to GAT, e.g.
        # rel: dglnn.GATConv(in_feat, out_feat, num_heads=4)
        # instead of
        # rel: dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False,
        #                      bias=False, allow_zero_in_degree=True)
        self.create_conv(in_feat, out_feat, rel_names)

        # bias
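        # (unlike the first variant's single shared h_bias, this variant
        # learns a separate bias per node type)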
        if bias:
            self.bias_dict = nn.ParameterDict()
            for ntype_name in self.ntype_names:
                self.bias_dict[ntype_name] = nn.Parameter(
                    th.Tensor(1, out_feat))
                nn.init.normal_(self.bias_dict[ntype_name])
            # self.h_bias = nn.Parameter(th.Tensor(1, out_feat))
            # nn.init.normal_(self.h_bias)

        # weight for self loop
        if self.self_loop:
            # if self.use_basis:
            #     self.loop_weight_basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.ntype_names))
            # else:
            #     self.loop_weight = nn.Parameter(th.Tensor(len(self.ntype_names), in_feat, out_feat))
            #     # nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
            #     nn.init.xavier_normal_(self.loop_weight)
            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(self.loop_weight,
                                    gain=nn.init.calculate_gain('tanh'))
            # nn.init.xavier_normal_(self.loop_weight)

        self.dropout = nn.Dropout(dropout)
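
This second variant delegates the message-passing construction to a create_conv helper that is not shown here. A plausible sketch of that helper, hypothetical and modeled on the first variant's HeteroGraphConv plus the allow_zero_in_degree option from the TODO comment above:

    def create_conv(self, in_feat, out_feat, rel_names):
        # Hypothetical reconstruction: the helper is referenced but not
        # shown in this example. It presumably mirrors the first variant,
        # with allow_zero_in_degree taken from the commented GraphConv
        # line in the TODO above.
        self.conv = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feat,
                                 out_feat,
                                 norm='right',
                                 weight=False,
                                 bias=False,
                                 allow_zero_in_degree=True)
            for rel in rel_names
        })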