def forward(self, g, x):
    g = g.local_var()
    g.nodes['n'].data['h'] = x
    # tag every edge with the integer relation id from self.rel_dict
    for i in g.etypes:
        type_ids = (th.ones(g.number_of_edges(etype=i), dtype=th.long) *
                    self.rel_dict[i])
        if self.use_cuda:
            g.edges[i].data['type'] = type_ids.cuda()
        else:
            g.edges[i].data['type'] = type_ids
    if self.self_loop:
        loop_message = utils.matmul_maybe_select(x, self.loop_weight)
    # message passing: one pass per edge type, summing messages into 'h'
    for i in g.etypes:
        g.update_all(self.message_func,
                     fn.sum(msg='msg', out='h'),
                     etype=i)
    # apply bias and activation
    node_repr = g.nodes['n'].data['h']
    if self.bias:
        node_repr = node_repr + self.h_bias
    if self.self_loop:
        node_repr = node_repr + loop_message
    if self.activation:
        node_repr = self.activation(node_repr)
    node_repr = self.dropout(node_repr)
    return node_repr
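
A minimal usage sketch (not part of the original module): the toy heterograph, the rel_dict mapping, and the fn.copy_u message below are assumptions, added only to show how the per-edge-type tagging and the per-relation update_all loop above behave on a graph with a single node type 'n'.

import torch as th
import dgl
import dgl.function as fn

# toy heterograph with one node type 'n' and two relations
g = dgl.heterograph({
    ('n', 'follows', 'n'): (th.tensor([0, 1]), th.tensor([1, 2])),
    ('n', 'likes', 'n'): (th.tensor([2]), th.tensor([0])),
})
rel_dict = {etype: i for i, etype in enumerate(g.etypes)}  # stands in for self.rel_dict

g.nodes['n'].data['h'] = th.randn(g.number_of_nodes('n'), 4)
for etype in g.etypes:
    # one integer relation id per edge, as in the loop above
    g.edges[etype].data['type'] = th.full(
        (g.number_of_edges(etype=etype),), rel_dict[etype], dtype=th.long)
    # the real layer's message_func would select a per-relation weight via 'type';
    # a plain feature copy keeps this sketch self-contained
    g.update_all(fn.copy_u('h', 'msg'), fn.sum('msg', 'h'), etype=etype)
print(g.nodes['n'].data['h'].shape)  # torch.Size([3, 4])
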
Example #2
def forward(self, g, x, etypes, norm=None):
    g = g.local_var()
    g.ndata['h'] = x
    # relation ids (and optional per-edge normalizers) are passed in per edge
    g.edata['type'] = etypes
    if norm is not None:
        g.edata['norm'] = norm
    if self.self_loop:
        loop_message = utils.matmul_maybe_select(x, self.loop_weight)
    # message passing
    g.update_all(self.message_func, self.reduce_func(msg='msg', out='h'))
    # apply bias and activation
    node_repr = g.ndata['h']
    if self.bias:
        node_repr = node_repr + self.h_bias
    if self.self_loop:
        node_repr = node_repr + loop_message
    if self.activation:
        node_repr = self.activation(node_repr)
    node_repr = self.dropout(node_repr)
    return node_repr
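
A usage sketch for the homogeneous-graph variant above: the toy graph, the relation-id tensor, and the in-degree normalizer are assumptions showing how x, etypes, and norm are typically prepared before calling the layer.

import torch as th
import dgl

g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))  # 3 nodes, 3 edges
x = th.randn(g.number_of_nodes(), 4)
etypes = th.tensor([0, 1, 0])  # one relation id per edge
# a common choice for 'norm' is 1 / in-degree of each edge's destination node
deg = g.in_degrees().float().clamp(min=1)
norm = (1.0 / deg)[g.edges()[1]].unsqueeze(1)
# layer = ...                   # an instance of the module defined above
# h = layer(g, x, etypes, norm)
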
Example #3

def forward(self, g):
    assert g.is_homograph(), \
        "not a homograph; convert it with to_homo and pass in the edge type as argument"
    with g.local_scope():
        # node features are expected to be pre-populated in g.ndata['h']
        if self.self_loop:
            loop_message = dgl_utils.matmul_maybe_select(
                g.ndata['h'], self.loop_weight)

        # message passing
        g.update_all(self.message_func, fn.sum(msg='msg', out='h'))

        # apply bias and activation
        node_repr = g.ndata['h']
        if self.bias:
            node_repr = node_repr + self.h_bias
        if self.self_loop:
            node_repr = node_repr + loop_message
        if self.activation:
            node_repr = self.activation(node_repr)
        node_repr = self.dropout(node_repr)
        return node_repr
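
A usage sketch for the homogeneous-only variant above: the toy heterograph and the dgl.to_homogeneous call are assumptions illustrating the conversion the assert message refers to (older DGL releases expose it as dgl.to_homo); the flattened graph keeps the node feature 'h' and stores per-edge type ids in g.edata[dgl.ETYPE] for the layer's message_func to read.

import torch as th
import dgl

hg = dgl.heterograph({
    ('user', 'follows', 'user'): (th.tensor([0, 1]), th.tensor([1, 2])),
    ('user', 'likes', 'item'): (th.tensor([0, 2]), th.tensor([0, 1])),
})
hg.nodes['user'].data['h'] = th.randn(3, 4)
hg.nodes['item'].data['h'] = th.randn(2, 4)
g = dgl.to_homogeneous(hg, ndata=['h'])  # edge type ids land in g.edata[dgl.ETYPE]
# layer = ...   # an instance of the module defined above
# h = layer(g)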