Example #1
    def forward(self, g, h, e):

        h_in = h  # for residual connection
        e_in = e  # for residual connection

        g.ndata['h'] = h
        g.ndata['Ah'] = self.A(h)  # self-connection projection
        g.ndata['Bh'] = self.B(h)  # neighbour projection
        g.ndata['Dh'] = self.D(h)  # source-node projection for the edge gate
        g.ndata['Eh'] = self.E(h)  # destination-node projection for the edge gate
        g.edata['e'] = e
        g.edata['Ce'] = self.C(e)  # edge-feature projection for the edge gate

        g.update_all(self.message_func, self.reduce_func)

        h = g.ndata['h']  # result of graph convolution
        e = g.edata['e']  # result of graph convolution

        if self.norm is not None:
            h = normalize(self.bn_node_h, h, g)  # keep the normalized output
            e = normalize(self.bn_node_e, e, g)

        h = F.relu(h)  # non-linear activation
        e = F.relu(e)  # non-linear activation

        if self.residual:
            h = h_in + h  # residual connection
            e = e_in + e  # residual connection

        h = F.dropout(h, self.dropout, training=self.training)
        e = F.dropout(e, self.dropout, training=self.training)

        return h, e
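
The message_func and reduce_func used by update_all are not shown in this snippet. Below is a minimal sketch, assuming the standard GatedGCN update (edge gate e_ij = Ce_ij + Dh_i + Eh_j, and a sigmoid-gated, degree-normalized sum over neighbours); the torch import and the 1e-6 stabilizer are assumptions, not taken from the snippet.

    def message_func(self, edges):
        Bh_j = edges.src['Bh']
        # edge gate input: e_ij = Ce_ij + Dh_i + Eh_j
        e_ij = edges.data['Ce'] + edges.src['Dh'] + edges.dst['Eh']
        edges.data['e'] = e_ij  # updated edge feature, read back via g.edata['e']
        return {'Bh_j': Bh_j, 'e_ij': e_ij}

    def reduce_func(self, nodes):
        Ah_i = nodes.data['Ah']
        Bh_j = nodes.mailbox['Bh_j']
        e = nodes.mailbox['e_ij']
        sigma_ij = torch.sigmoid(e)  # gates in (0, 1)
        # gated, degree-normalized sum over neighbours (1e-6 avoids division by zero)
        h = Ah_i + torch.sum(sigma_ij * Bh_j, dim=1) / (torch.sum(sigma_ij, dim=1) + 1e-6)
        return {'h': h}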
Example #2
    def forward(self, g, h):
        h_in = h  # for residual connection

        h = self.gatconv(g, h).flatten(1)  # multi-head GAT; flatten(1) concatenates the heads

        if self.norm is not None:
            h = normalize(self.batchnorm_h, h, g)  # keep the normalized output

        if self.activation:
            h = self.activation(h)

        if self.residual:
            h = h_in + h  # residual connection

        return h
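
This layer delegates to DGL's built-in GATConv, whose output carries one slice per attention head; flatten(1) concatenates the heads into a single feature vector. A small usage sketch (the toy graph and feature sizes are illustrative assumptions):

    import dgl
    import torch
    from dgl.nn.pytorch import GATConv

    g = dgl.rand_graph(10, 30)   # toy graph: 10 nodes, 30 edges
    h = torch.randn(10, 16)      # node features
    gatconv = GATConv(in_feats=16, out_feats=8, num_heads=4)
    out = gatconv(g, h)          # shape (10, 4, 8): one slice per head
    out = out.flatten(1)         # shape (10, 32): heads concatenated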
Example #3
    def forward(self, g, feature):
        h_in = feature  # to be used for residual connection

        if not self.dgl_builtin:
            g.ndata['h'] = feature
            g.update_all(msg, reduce)  # message passing and aggregation
            g.apply_nodes(func=self.apply_mod)  # per-node projection
            h = g.ndata['h']  # result of graph convolution
        else:
            h = self.conv(g, feature)  # DGL built-in convolution

        if self.norm is not None:
            h = normalize(self.batchnorm_h, h, g)  # keep the normalized output

        if self.activation:
            h = self.activation(h)

        if self.residual:
            h = h_in + h  # residual connection

        h = self.dropout(h)
        return h
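
The msg, reduce, and self.apply_mod helpers used in the non-builtin branch are not shown. A minimal sketch, assuming a vanilla GCN (copy-source message, mean aggregation, then a per-node linear projection); the NodeApplyModule name and the dimensions are assumptions:

    import torch
    import torch.nn as nn
    import dgl.function as fn

    msg = fn.copy_u('h', 'm')  # send the source feature h as message m

    def reduce(nodes):
        # mean over the incoming messages of each node
        return {'h': torch.mean(nodes.mailbox['m'], dim=1)}

    class NodeApplyModule(nn.Module):
        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.linear = nn.Linear(in_dim, out_dim)

        def forward(self, nodes):
            # project the aggregated feature; used via g.apply_nodes
            return {'h': self.linear(nodes.data['h'])}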
Example #4
    def forward(self, g, h):
        z = self.fc(h)  # linear projection of node features
        g.ndata['z'] = z
        g.update_all(self.message_func, self.reduce_func)
        h = g.ndata['h']  # result of graph convolution

        if self.norm is not None:
            h = normalize(self.batchnorm_h, h, g)

        h = F.elu(h)

        h = F.dropout(h, self.dropout, training=self.training)

        return h
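
The message_func and reduce_func are not shown, and no apply_edges call appears before update_all, so one plausible reading is that the attention logit is computed inside the message function. A sketch under that assumption; attn_fc is a hypothetical Linear(2 * out_dim, 1) layer and the torch import is assumed:

    def message_func(self, edges):
        # hypothetical: compute the attention logit from source and destination
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = F.leaky_relu(self.attn_fc(z2))  # attn_fc: assumed Linear(2 * out_dim, 1)
        return {'z': edges.src['z'], 'a': a}

    def reduce_func(self, nodes):
        # softmax over incoming edges, then attention-weighted sum
        alpha = F.softmax(nodes.mailbox['a'], dim=1)
        return {'h': torch.sum(alpha * nodes.mailbox['z'], dim=1)}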
Example #5
    def forward(self, g, h, e):
        z_h = self.fc_h(h)  # linear projection of node features
        z_e = self.fc_e(e)  # linear projection of edge features
        g.ndata['z_h'] = z_h
        g.edata['z_e'] = z_e

        g.apply_edges(self.edge_attention)  # per-edge attention scores and edge update

        g.update_all(self.message_func, self.reduce_func)

        h = g.ndata['h']  # result of graph convolution
        e = g.edata['e_proj']  # updated edge representation

        if self.norm is not None:
            h = normalize(self.batchnorm_h, h, g)
            e = normalize(self.batchnorm_e, e, g)

        h = F.elu(h)
        e = F.elu(e)

        h = F.dropout(h, self.dropout, training=self.training)
        e = F.dropout(e, self.dropout, training=self.training)

        return h, e
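
The edge_attention, message_func, and reduce_func helpers are not shown. A sketch assuming the attention logit and the projected edge representation e_proj are both computed from the concatenated edge, source, and destination projections; fc_attn and fc_proj are hypothetical layers and the torch import is assumed:

    def edge_attention(self, edges):
        # concatenate edge, source, and destination projections
        z = torch.cat([edges.data['z_e'], edges.src['z_h'], edges.dst['z_h']], dim=1)
        return {'attn': F.leaky_relu(self.fc_attn(z)),  # fc_attn: assumed Linear(3d, 1)
                'e_proj': self.fc_proj(z)}              # fc_proj: assumed Linear(3d, d)

    def message_func(self, edges):
        return {'z': edges.src['z_h'], 'attn': edges.data['attn']}

    def reduce_func(self, nodes):
        # softmax over incoming edges, then attention-weighted sum
        alpha = F.softmax(nodes.mailbox['attn'], dim=1)
        return {'h': torch.sum(alpha * nodes.mailbox['z'], dim=1)}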