Example #1
    def forward(self, g, h, h_en):
        """Forward computation."""
        with g.local_scope():
            h_src, h_dst = expand_as_pair(h)
            h_src_en, h_dst_en = expand_as_pair(h_en)

            g.srcdata['x'] = h_src
            g.dstdata['x'] = h_dst

            g.srcdata['en'] = h_src_en
            g.dstdata['en'] = h_dst_en

            if not self.batch_norm:
                #g.update_all(self.message, fn.mean('e', 'x'))
                g.apply_edges(self.message)
                g.update_all(fn.copy_e('e', 'e'), fn.max('e', 'x'))
                g.update_all(fn.copy_e('e_en', 'e_en'), fn.mean('e_en', 'en'))
            else:
                g.apply_edges(self.message)

                g.edata['e'] = self.bn(g.edata['e'])

                g.update_all(fn.copy_e('e', 'e'), fn.max('e', 'x'))

                g.update_all(fn.copy_e('e_en', 'e_en'), fn.mean('e_en', 'en'))

            return g.dstdata['x'], g.dstdata['en']  #+  h_en
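A minimal, runnable sketch of the copy_e/max aggregation used in both branches above, on a toy graph (the graph, feature name, and sizes are illustrative assumptions, not from the original module):

import dgl
import dgl.function as fn
import torch

# a toy graph in which every node has at least one incoming edge
g = dgl.graph((torch.tensor([0, 1, 2, 0]), torch.tensor([1, 2, 0, 2])))
g.edata['e'] = torch.randn(g.num_edges(), 4)
# element-wise max over each node's incoming edge features,
# the same copy_e/max pair used above
g.update_all(fn.copy_e('e', 'e'), fn.max('e', 'x'))
print(g.ndata['x'].shape)  # torch.Size([3, 4])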
Example #2
 def _pull_nodes(nodes):
     # compute ground truth
     g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc)
     o1 = g.ndata.pop('o1')
     g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc)
     o2 = g.ndata.pop('o2')
     g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc)
     o3 = g.ndata.pop('o3')
     # v2v spmv
     g.pull(nodes, fn.src_mul_edge(src='h', edge='w1', out='m1'),
                  fn.sum(msg='m1', out='o1'),
                  _afunc)
     assert U.allclose(o1, g.ndata.pop('o1'))
     # v2v fallback to e2v
     g.pull(nodes, fn.src_mul_edge(src='h', edge='w2', out='m2'),
                  fn.sum(msg='m2', out='o2'),
                  _afunc)
     assert U.allclose(o2, g.ndata.pop('o2'))
     # v2v fallback to degree bucketing
     g.pull(nodes, fn.src_mul_edge(src='h', edge='w1', out='m1'),
                  fn.max(msg='m1', out='o3'),
                  _afunc)
     assert U.allclose(o3, g.ndata.pop('o3'))
     # multi builtins, both v2v spmv
     g.pull(nodes,
            [fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w1', out='m2')],
            [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
            _afunc)
     assert U.allclose(o1, g.ndata.pop('o1'))
     assert U.allclose(o1, g.ndata.pop('o2'))
     # multi builtins, one v2v spmv, one fallback to e2v
     g.pull(nodes,
            [fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w2', out='m2')],
            [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
            _afunc)
     assert U.allclose(o1, g.ndata.pop('o1'))
     assert U.allclose(o2, g.ndata.pop('o2'))
     # multi builtins, one v2v spmv, one fallback to e2v, one fallback to degree-bucketing
     g.pull(nodes,
            [fn.src_mul_edge(src='h', edge='w1', out='m1'),
             fn.src_mul_edge(src='h', edge='w2', out='m2'),
             fn.src_mul_edge(src='h', edge='w1', out='m3')],
            [fn.sum(msg='m1', out='o1'),
             fn.sum(msg='m2', out='o2'),
             fn.max(msg='m3', out='o3')],
            _afunc)
     assert U.allclose(o1, g.ndata.pop('o1'))
     assert U.allclose(o2, g.ndata.pop('o2'))
     assert U.allclose(o3, g.ndata.pop('o3'))
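The checks above exercise DGL's legacy pull fallback paths (v2v SpMV, e2v, degree bucketing). A compact standalone version of the same builtin-versus-UDF equivalence, written against the current API in which fn.src_mul_edge is spelled fn.u_mul_e (the graph and shapes are assumptions):

import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 1, 2, 0]), torch.tensor([1, 2, 0, 2])))
g.ndata['h'] = torch.randn(g.num_nodes(), 5)
g.edata['w'] = torch.randn(g.num_edges(), 5)

# builtin pair: multiply source features by edge weights, max-reduce per node
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.max('m', 'o'))
o_builtin = g.ndata.pop('o')

# the same computation with user-defined message/reduce functions
def mfunc(edges):
    return {'m': edges.src['h'] * edges.data['w']}

def rfunc(nodes):
    return {'o': torch.max(nodes.mailbox['m'], dim=1)[0]}

g.update_all(mfunc, rfunc)
assert torch.allclose(o_builtin, g.ndata.pop('o'))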
Example #3
    def forward(self, graph, feat, e_feat):
        r"""Compute GraphSAGE layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor
            The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
            is the size of the input feature and :math:`N` is the number of nodes.
        e_feat : torch.Tensor
            The edge feature used to weight messages along each edge.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
            is the size of the output feature.
        """
        graph = graph.local_var()
        feat = self.feat_drop(feat)
        h_self = feat
        graph.edata['e'] = e_feat
        if self._aggre_type == 'sum':
            graph.ndata['h'] = feat
            graph.update_all(fn.u_mul_e('h', 'e', 'm'), fn.sum('m', 'neigh'))
            h_neigh = graph.ndata['neigh']
        elif self._aggre_type == 'mean':
            graph.ndata['h'] = feat
            graph.update_all(fn.u_mul_e('h', 'e', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.ndata['neigh']
        elif self._aggre_type == 'gcn':
            graph.ndata['h'] = feat
            graph.update_all(fn.u_mul_e('h', 'e', 'm'), fn.sum('m', 'neigh'))
            # divide in_degrees
            degs = graph.in_degrees().float()
            degs = degs.to(feat.device)
            h_neigh = (graph.ndata['neigh'] +
                       graph.ndata['h']) / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'pool':
            graph.ndata['h'] = F.relu(self.fc_pool(feat))
            graph.update_all(fn.u_mul_e('h', 'e', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.ndata['neigh']
        elif self._aggre_type == 'lstm':
            graph.ndata['h'] = feat
            graph.update_all(fn.u_mul_e('h', 'e', 'm'), self._lstm_reducer)
            h_neigh = graph.ndata['neigh']
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(
                self._aggre_type))
        # GraphSAGE GCN does not require fc_self.
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
        # activation
        if self.activation is not None:
            rst = self.activation(rst)
        # normalization
        if self.norm is not None:
            rst = self.norm(rst)
        return rst
Example #4
    def forward(self, g, h, e):
        h_in = h  # for residual connection

        if not self.dgl_builtin:
            h = self.dropout(h)
            g.ndata['h'] = h
            #g.update_all(fn.copy_src(src='h', out='m'),
            #             self.aggregator,
            #             self.nodeapply)
            if self.aggregator_type == 'maxpool':
                g.ndata['h'] = self.aggregator.linear(g.ndata['h'])
                g.ndata['h'] = self.aggregator.activation(g.ndata['h'])
                g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'c'),
                             self.nodeapply)
            elif self.aggregator_type == 'lstm':
                g.update_all(fn.copy_src(src='h', out='m'), self.aggregator,
                             self.nodeapply)
            else:
                g.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'c'),
                             self.nodeapply)
            h = g.ndata['h']
        else:
            # For original graphs
            # h = self.sageconv(g, h)
            # For reduced graphs
            h = self.sageconv(g, h, edge_weight=e)

        if self.batch_norm:
            h = self.batchnorm_h(h)

        if self.residual:
            h = h_in + h  # residual connection

        return h
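The lstm branches above hand a callable aggregator to update_all in place of a builtin reducer, but the aggregator itself is not shown. A minimal sketch of one, modeled on the LSTM reducer in DGL's SAGEConv (the hidden size and the output key 'c' are assumptions inferred from the surrounding branches):

import torch
import torch.nn as nn

class LSTMAggregator(nn.Module):
    def __init__(self, in_feats):
        super().__init__()
        self._in_feats = in_feats
        self.lstm = nn.LSTM(in_feats, in_feats, batch_first=True)

    def forward(self, nodes):
        m = nodes.mailbox['m']                     # (B, deg, D): one bucket of messages
        batch_size = m.shape[0]
        h = (m.new_zeros((1, batch_size, self._in_feats)),
             m.new_zeros((1, batch_size, self._in_feats)))
        _, (rst, _) = self.lstm(m, h)
        return {'c': rst.squeeze(0)}               # same key the builtin reducers write above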
Example #5
    def sample_frontier(self, block_id, g, seed_nodes):
        fanout = self.fanouts[block_id] if self.fanouts is not None else None
        # List of neighbors to sample per edge type for each GNN layer, starting from the first layer.
        g = dgl.in_subgraph(g, seed_nodes)
        g.remove_edges(torch.where(g.edata['timestamp'] > self.ts)[0])
        if self.args.valid_path:
            if block_id != self.args.n_layer - 1:
                g.dstdata['sample_time'] = self.frontiers[block_id + 1].srcdata['sample_time']
                g.apply_edges(self.sample_prob)
                g.remove_edges(torch.where(g.edata['timespan'] < 0)[0])
            g_re = dgl.reverse(g, copy_edata=True, copy_ndata=True)
            g_re.update_all(self.sample_time, fn.max('st', 'sample_time'))
            g = dgl.reverse(g_re, copy_edata=True, copy_ndata=True)

        if fanout is None:
            frontier = g
        else:
            if block_id == self.args.n_layer - 1:

                if self.args.bandit:
                    frontier = dgl.sampling.sample_neighbors(g, seed_nodes, fanout, prob='q_ij')
                else:
                    frontier = dgl.sampling.sample_neighbors(g, seed_nodes, fanout)

            else:
                frontier = dgl.sampling.sample_neighbors(g, seed_nodes, fanout)

        self.frontiers[block_id] = frontier
        return frontier
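sample_frontier relies on two callbacks that are not shown: self.sample_prob, which must write edata['timespan'], and self.sample_time, whose message 'st' is max-reduced into 'sample_time'. A hypothetical sketch consistent with how they are used above (everything beyond the visible field names is an assumption):

def sample_time(self, edges):
    # each edge proposes its own timestamp; fn.max('st', 'sample_time')
    # then keeps the latest one per node of the reversed graph
    return {'st': edges.data['timestamp']}

def sample_prob(self, edges):
    # a negative timespan marks an edge newer than the time at which its
    # destination was sampled in the next layer; such edges are removed above
    return {'timespan': edges.dst['sample_time'] - edges.data['timestamp']}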
Example #6
 def agg(self, x, B):
     h = x
     x = self.dropout(x)
     for i in range(self.K):
         if i == 0:
             if self.aggregator == 'pool':
                 x = torch.matmul(x, self.weight_pool_in)
                 if self.bias:
                     x = x + self.bias_in
             if self.aggregator == 'gcn':
                 B[i].srcdata['h'] = torch.matmul(x, self.weight_gcn_in)
             else:
                 B[i].srcdata['h'] = x
             B[i].dstdata['h'] = x[:B[i].number_of_dst_nodes()]
         else:
             if self.aggregator == 'pool':
                 hh = torch.matmul(B[i - 1].dstdata['h'], self.weight_pool_hid)
                 if self.bias:
                     hh = hh + self.bias_hid
             else:
                 hh = B[i - 1].dstdata['h']
             if self.aggregator == 'gcn':
                 B[i].srcdata['h'] = torch.matmul(hh, self.weight_gcn_hid)
             else:
                 B[i].srcdata['h'] = hh
             B[i].dstdata['h'] = hh[:B[i].number_of_dst_nodes()]
         if self.aggregator == 'gcn':
             B[i].update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
         elif self.aggregator == 'mean':
             B[i].update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
         elif self.aggregator == 'lstm':
             B[i].update_all(fn.copy_src('h', 'm'), self.lstm_reducer_in if i == 0 else self.lstm_reducer_hid)
         else:
             B[i].update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
         h_neigh = B[i].dstdata['neigh']
         if i == 0:
             h = torch.matmul(B[i].dstdata['h'], self.weight_in[0, :, :]) \
                 + (torch.matmul(h_neigh, self.weight_in[1, :, :]) if self.aggregator != 'gcn' else 0)
             if self.bias:
                 h = h + self.bias_in_k[0, :] + (self.bias_in_k[1, :] if self.aggregator != 'gcn' else 0)
         elif i == self.K - 1:
             h = torch.matmul(B[i].dstdata['h'], self.weight_out[0, :, :])\
                 + (torch.matmul(h_neigh, self.weight_out[1, :, :]) if self.aggregator != 'gcn' else 0)
             if self.bias:
                 h = h + self.bias_out_k[0, :] + (self.bias_out_k[1, :] if self.aggregator != 'gcn' else 0)
         else:
             h = torch.matmul(B[i].dstdata['h'], self.weight_hid[i - 1, 0, :, :])\
                 + (torch.matmul(h_neigh, self.weight_hid[i - 1, 1, :, :]) if self.aggregator != 'gcn' else 0)
             if self.bias:
                 h = h + self.bias_hid_k[0, :] + (self.bias_hid_k[1, :] if self.aggregator != 'gcn' else 0)
         if self.activation and i != self.K - 1:
             h = self.activation(h, inplace=False)
         if i != self.K - 1:
             h = self.dropout(h)
         if self.norm:
             norm = torch.norm(h, dim=1)
             norm = norm + (norm == 0).long()
             h = h / norm.unsqueeze(-1)
         B[i].dstdata['h'] = h
     return h
Example #7
    def fit(self, train_labels, train_mask):
        """Trains the model.

        Parameters
        ----------
        train_labels: torch.LongTensor
            Tensor of target data of size n_train_nodes.

        train_mask: torch.ByteTensor
            Boolean mask of size n_nodes indicating the nodes used in training.
        """
        # Add initial node labels
        if train_labels.is_cuda:
            init_labels = torch.cuda.FloatTensor(self.graph.number_of_nodes()).fill_(0)
        else:
            init_labels = torch.zeros(self.graph.number_of_nodes(), dtype=torch.float)
        init_labels[train_mask] = train_labels.float()
        self.graph.ndata["l"] = init_labels

        # Propagate
        self.graph.update_all(
            message_func=fn.copy_src(src="l", out="m"),
            reduce_func=fn.max(msg="m", out="l"),
        )

        # Put back positive seed nodes
        self.graph.ndata["l"] = torch.max(self.graph.ndata["l"], init_labels)

        self.predictions = self.graph.ndata["l"]
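The propagation step in fit is a single copy-then-max pass over neighbors. A toy run of just that step, using the current copy_u spelling of copy_src (the three-node path graph and seed placement are illustrative):

import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
init_labels = torch.tensor([1.0, 0.0, 0.0])   # node 0 is a positive seed
g.ndata['l'] = init_labels
g.update_all(fn.copy_u('l', 'm'), fn.max('m', 'l'))
# nodes with no incoming message come back zero-filled, so the seeds are
# put back afterwards, exactly as fit does above
g.ndata['l'] = torch.max(g.ndata['l'], init_labels)
print(g.ndata['l'])  # tensor([1., 1., 0.]): the label reached node 1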
Example #8
    def collate(self, items):
        '''
        items: edge ids in graph g.
        We sample iteratively k times and batch the results into a single subgraph.
        '''
        current_ts = self.g.edata['timestamp'][
            items[0]]  # only sample edges before the current timestamp
        self.graph_sampler.ts = current_ts  # hand the current timestamp to the graph sampler

        # for link prediction, we use a negative_sampler to generate a negative graph for the loss computation.
        if self.negative_sampler is None:
            neg_pair_graph = None
            input_nodes, pair_graph, blocks = self._collate(items)
        else:
            input_nodes, pair_graph, neg_pair_graph, blocks = self._collate_with_negative_sampling(
                items)

        # we sample the k-hop subgraph and batch it into one graph
        for i in range(self.n_layer - 1):
            self.graph_sampler.frontiers[0].add_edges(
                *self.graph_sampler.frontiers[i + 1].edges())
        frontier = self.graph_sampler.frontiers[0]
        # compute each node's last-update timestamp
        frontier.update_all(fn.copy_e('timestamp', 'ts'),
                            fn.max('ts', 'timestamp'))

        return input_nodes, pair_graph, neg_pair_graph, [frontier]
Example #9
def track_time(graph_name, format, feat_size, msg_type, reduce_type):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph_name, format)
    graph = graph.to(device)
    graph.ndata['h'] = torch.randn((graph.num_nodes(), feat_size),
                                   device=device)
    graph.edata['e'] = torch.randn((graph.num_edges(), 1), device=device)

    msg_builtin_dict = {
        'copy_u': fn.copy_u('h', 'x'),
        'u_mul_e': fn.u_mul_e('h', 'e', 'x'),
    }

    reduce_builtin_dict = {
        'sum': fn.sum('x', 'h_new'),
        'mean': fn.mean('x', 'h_new'),
        'max': fn.max('x', 'h_new'),
    }

    # dry run
    graph.update_all(msg_builtin_dict[msg_type],
                     reduce_builtin_dict[reduce_type])

    # timing

    with utils.Timer() as t:
        for i in range(3):
            graph.update_all(msg_builtin_dict[msg_type],
                             reduce_builtin_dict[reduce_type])

    return t.elapsed_secs / 3
Example #10
    def forward(self, g, h):
        g = g.local_var()
        if not self.use_pp or not self.training:
            norm = self.get_norm(g)

            # g.ndata['h'] = h
            # g.update_all(fn.copy_src(src='h', out='m'),
            #              fn.sum(msg='m', out='h'))
            # ah = g.ndata.pop('h')

            if self._aggre_type == 'mean':
                g.ndata['h'] = h
                g.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'h'))
                ah = g.ndata.pop('h')
            elif self._aggre_type == 'gcn':
                g.ndata['h'] = h
                g.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'))
                # divide in_degrees
                # degs = graph.in_degrees().float()
                # degs = degs.to(feat.device)
                # h_neigh = (graph.ndata['neigh'] + graph.ndata['h']) / (degs.unsqueeze(-1) + 1)
                ah = g.ndata.pop('h')
                ah = ah * norm
            elif self._aggre_type == 'pool':
                g.ndata['h'] = F.relu(self.fc_pool(h))
                g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'h'))
                ah = g.ndata['h']
            elif self._aggre_type == 'lstm':
                g.ndata['h'] = h
                g.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
                ah = g.ndata['h']
            elif self._aggre_type == 'attn':
                feat = self.fc_attn(h).view(-1, self.num_heads, self._in_feats)
                el = (feat * self.attn_l).sum(dim=-1).unsqueeze(-1)
                er = (feat * self.attn_r).sum(dim=-1).unsqueeze(-1)
                g.ndata.update({'ft': feat, 'el': el, 'er': er})
                g.apply_edges(fn.u_add_v('el', 'er', 'e'))
                e = self.leaky_relu(g.edata.pop('e'))
                g.edata['a'] = edge_softmax(g, e)
                g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
                ah = g.ndata['ft']
                ah = ah.squeeze(1)
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(
                    self._aggre_type))

            h = self.concat(h, ah, norm)
        if self.dropout:
            h = self.dropout(h)
        # GraphSAGE GCN does not require fc_self.
        # if self._aggre_type == 'gcn':
        #     rst = self.fc_neigh(ah)
        # else:
        #     rst = self.fc_self(h) + self.fc_neigh(ah)
        h = self.linear(h)
        h = self.lynorm(h)
        if self.activation:
            h = self.activation(h)
        return h
Example #11
def get_current_ts(pos_graph, neg_graph):
    with pos_graph.local_scope():
        pos_graph_ = dgl.add_reverse_edges(pos_graph, copy_edata=True)
        pos_graph_.update_all(fn.copy_e('timestamp', 'times'),
                              fn.max('times', 'ts'))
        current_ts = pos_ts = pos_graph_.ndata['ts']
        num_pos_nodes = pos_graph_.num_nodes()
    with neg_graph.local_scope():
        neg_graph_ = dgl.add_reverse_edges(neg_graph)
        neg_graph_.edata['timestamp'] = pos_graph_.edata['timestamp']
        neg_graph_.update_all(fn.copy_e('timestamp', 'times'),
                              fn.max('times', 'ts'))
        num_pos_nodes = torch.where(pos_graph_.ndata['ts'] > 0)[0].shape[0]
        pos_ts = pos_graph_.ndata['ts'][:num_pos_nodes]
        neg_ts = neg_graph_.ndata['ts'][num_pos_nodes:]
        current_ts = torch.cat([pos_ts, neg_ts])
    return current_ts, pos_ts, num_pos_nodes
Example #12
 def __init__(self, pool_type):
     super(GraphPooling, self).__init__()
     self.pool_type = pool_type
     if pool_type == 'mean':
         self.reduce_func = fn.mean(msg='m', out='h')
     elif pool_type == 'max':
         self.reduce_func = fn.max(msg='m', out='h')
     elif pool_type == 'min':
         self.reduce_func = fn.min(msg='m', out='h')
     else:
         raise KeyError('Pooling type {} not recognized.'.format(pool_type))
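The class above only stores a reduce function. A hypothetical forward that pairs it with a copy message and reads the pooled feature back (the message and field names are assumptions, chosen to match the reducers' msg='m', out='h'):

def forward(self, g, feat):
    with g.local_scope():
        g.ndata['h'] = feat
        g.update_all(fn.copy_u('h', 'm'), self.reduce_func)
        return g.ndata['h']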
Example #13
    def pool_agg(self, g):
        x = g.ndata['x']
        x = self.dropout(x)
        h = torch.matmul(x, self.weight_pool_in)
        if self.bias:
            h = h + self.bias_in
        g.srcdata['h'] = h
        for i in range(self.K):
            if i == 0:
                g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = g.dstdata['neigh']
                h = torch.matmul(torch.cat([g.srcdata['h'], h_neigh], dim=1), self.weight_in)
                if self.activation:
                    h = self.activation(h, inplace=False)
                norm = torch.norm(h, dim=1)
                h = h / (norm.unsqueeze(-1) + 0.05)
                g.srcdata['h'] = h
            elif i == self.K - 1:
                h = torch.matmul(g.srcdata['h'], self.weight_pool_hid)
                if self.bias:
                    h = h + self.bias_hid
                g.srcdata['h'] = h
                g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = g.dstdata['neigh']
                h = torch.matmul(torch.cat([g.srcdata['h'], h_neigh], dim=1), self.weight_out)

                norm = torch.norm(h, dim=1)
                h = h / (norm.unsqueeze(-1) + 0.05)
                g.ndata['z'] = h
            else:
                h = torch.matmul(g.srcdata['h'], self.weight_pool_hid)
                if self.bias:
                    h = h + self.bias_hid
                g.srcdata['h'] = h
                g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = g.dstdata['neigh']
                h = torch.matmul(torch.cat([g.srcdata['h'], h_neigh], dim=1), self.weight_hid[i-1, :, :])
                if self.activation:
                    h = self.activation(h, inplace=False)
                norm = torch.norm(h, dim=1)
                h = h / (norm.unsqueeze(-1) + 0.05)
                g.srcdata['h'] = h
        return g
Example #14
 def edge_softmax(self, g):
     # compute the max
     g.update_all(fn.copy_edge('a', 'a'), fn.max('a', 'a_max'))
     # minus the max and exp
     g.apply_edges(lambda edges:
                   {'a': torch.exp(edges.data['a'] - edges.dst['a_max'])})
     # compute dropout
     g.apply_edges(
         lambda edges: {'a_drop': self.attn_drop(edges.data['a'])})
     # compute normalizer
     g.update_all(fn.copy_edge('a', 'a'), fn.sum('a', 'z'))
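A standalone check of the same shift-exp-normalize recipe against DGL's builtin edge softmax, with the dropout step omitted so the two agree exactly (toy graph; written with the current copy_e spelling):

import dgl
import dgl.function as fn
import torch
from dgl.nn.functional import edge_softmax

g = dgl.graph((torch.tensor([0, 1, 2, 0]), torch.tensor([1, 2, 0, 2])))
a = torch.randn(g.num_edges(), 1)

g.edata['a'] = a
g.update_all(fn.copy_e('a', 'a'), fn.max('a', 'a_max'))      # per-destination max
g.apply_edges(lambda edges: {'a': torch.exp(edges.data['a'] - edges.dst['a_max'])})
g.update_all(fn.copy_e('a', 'm'), fn.sum('m', 'z'))          # normalizer
g.apply_edges(lambda edges: {'a_norm': edges.data['a'] / edges.dst['z']})

assert torch.allclose(g.edata['a_norm'], edge_softmax(g, a))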
Example #15
    def _test(fld):
        def message_func(edges):
            return {'m': edges.src[fld]}

        def message_func_edge(edges):
            if len(edges.src[fld].shape) == 1:
                return {'m': edges.src[fld] * edges.data['e1']}
            else:
                return {'m': edges.src[fld] * edges.data['e2']}

        def reduce_func(nodes):
            return {fld: mx.nd.max(nodes.mailbox['m'], axis=1)}

        def apply_func(nodes):
            return {fld: 2 * nodes.data[fld]}

        g = simple_graph()
        # update all
        v1 = g.ndata[fld]
        g.update_all(fn.copy_src(src=fld, out='m'), fn.max(msg='m', out=fld),
                     apply_func)
        v2 = g.ndata[fld]
        g.set_n_repr({fld: v1})
        g.update_all(message_func, reduce_func, apply_func)
        v3 = g.ndata[fld]
        assert np.allclose(v2.asnumpy(), v3.asnumpy(), rtol=1e-05, atol=1e-05)
        # update all with edge weights
        v1 = g.ndata[fld]
        g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
                     fn.max(msg='m', out=fld), apply_func)
        v2 = g.ndata[fld]
        g.set_n_repr({fld: v1})
        g.update_all(fn.src_mul_edge(src=fld, edge='e2', out='m'),
                     fn.max(msg='m', out=fld), apply_func)
        v3 = g.ndata[fld].squeeze()
        g.set_n_repr({fld: v1})
        g.update_all(message_func_edge, reduce_func, apply_func)
        v4 = g.ndata[fld]
        assert np.allclose(v2.asnumpy(), v3.asnumpy(), rtol=1e-05, atol=1e-05)
        assert np.allclose(v3.asnumpy(), v4.asnumpy(), rtol=1e-05, atol=1e-05)
Example #16
 def forward(self, graph, x):
     num_nan = 0
     for item in self.fc_msg:
         if isinstance(item, nn.Linear):
             num_nan += item.weight.isnan().sum()
     for item in self.fc_udt:
         if isinstance(item, nn.Linear):
             num_nan += item.weight.isnan().sum()
     if num_nan > 0:
         print("nan is found in model parameters.")
     graph.ndata['in_feats'] = x
     graph.update_all(self.message, fn.max('m', 'r'))
     return self.fc_udt(
         torch.cat([graph.dstdata['in_feats'], graph.dstdata['r']], dim=1))
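forward consumes a message UDF, self.message, that must emit the 'm' field reduced by fn.max('m', 'r'). A hypothetical minimal version (projecting source features alone through fc_msg is an assumption):

def message(self, edges):
    # project each source node's features into the message that
    # fn.max('m', 'r') aggregates on the destination side
    return {'m': self.fc_msg(edges.src['in_feats'])}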
Example #17
 def forward(self, g, x):
     with g.local_scope():
         g.ndata['x'] = x
         g.ndata['z'] = self.gate_m(x)
         g.update_all(fn.copy_u('x', 'x'), fn.mean('x', 'mean_z'))
         g.update_all(fn.copy_u('z', 'z'), fn.max('z', 'max_z'))
         nft = torch.cat([g.ndata['x'], g.ndata['max_z'],
                          g.ndata['mean_z']], dim=1)
         gate = self.gate_fn(nft).sigmoid()
         attn_out = self.gatlayer(g, x)
         node_num = g.num_nodes()
         gated_out = ((gate.view(-1)*attn_out.view(-1, self.out_feats).T).T).view(
             node_num, self.num_heads, self.out_feats)
         gated_out = gated_out.mean(1)
         merge = self.merger_layer(torch.cat([x, gated_out], dim=1))
         return merge
Example #18
 def forward(self, graph, feat):
     graph = graph.local_var()
     if isinstance(feat, tuple):
         feat_src, feat_dst = feat
     else:
         feat_src = feat_dst = feat
     h_self = feat_dst
      # DIN attention: take the two vectors, their difference, and their elementwise
      # product, map each through an MLP to n_hidden, add them, then map through a second MLP to 1
      ## compute the difference and product of the two vectors
     graph.srcdata.update({'e_src': feat_src})
     graph.dstdata.update({'e_dst': feat_dst})
     graph.apply_edges(fn.u_sub_v('e_src', 'e_dst', 'e_sub'))
     graph.apply_edges(fn.u_mul_v('e_src', 'e_dst', 'e_mul'))
      ## apply the MLPs separately
     graph.srcdata["e_src"] = self.atten_src(feat_src)
     graph.dstdata["e_dst"] = self.atten_dst(feat_dst)
     graph.edata["e_sub"] = self.atten_sub(graph.edata["e_sub"])
     graph.edata["e_mul"] = self.atten_mul(graph.edata["e_mul"])
      ## "MLP each part, then add" replaces "concat, then MLP"
     graph.edata["e"] = graph.edata.pop("e_sub") + graph.edata.pop("e_mul")
     graph.apply_edges(fn.e_add_u('e', 'e_src', 'e'))
     graph.apply_edges(fn.e_add_v('e', 'e_dst', 'e'))
     graph.srcdata.pop("e_src")
     graph.dstdata.pop("e_dst")
      ## first-layer activation
     graph.edata["e"] = F.gelu(graph.edata["e"])
      ## second layer: MLP down to 1
     graph.edata["e"] = self.leaky_relu(self.atten_out(graph.edata["e"]))
     # max pool
     graph.srcdata['h'] = F.gelu(self.fc_pool(feat_src))
     graph.apply_edges(fn.e_mul_u('e', 'h', 'h'))
     graph.update_all(fn.copy_e('h', 'm'), fn.max('m', 'neigh'))
     h_neigh = graph.dstdata['neigh']
     # mean pool
     graph.srcdata['h'] = F.gelu(self.fc_pool2(feat_src))
     graph.apply_edges(fn.e_mul_u('e', 'h', 'h'))
     graph.update_all(fn.copy_e('h', 'm'), fn.mean('m', 'neigh'))
     h_neigh2 = graph.dstdata['neigh']
     # concat
     rst = self.fc_self(h_self) + self.fc_neigh(h_neigh) + self.fc_neigh2(h_neigh2)
     # mlps
     if len(self.out_mlp) > 0:
         for layer in self.out_mlp:
             o = layer(F.gelu(rst))
             rst = rst + o
     return rst
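The comment above swaps "concat, then MLP" for "MLP each part, then add". For a single linear layer the two are exactly equivalent, because a weight applied to a concatenation splits column-wise; a quick numerical check (sizes are arbitrary):

import torch
import torch.nn as nn

x, y = torch.randn(8, 16), torch.randn(8, 16)
fc = nn.Linear(32, 4, bias=False)
w_x, w_y = fc.weight[:, :16], fc.weight[:, 16:]

joint = fc(torch.cat([x, y], dim=1))   # concat, then linear
split = x @ w_x.T + y @ w_y.T          # linear each part, then add
assert torch.allclose(joint, split, atol=1e-6)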
Example #19
    def forward(self, nf, logits):
        r"""Compute edge softmax.
        Parameters
        ----------
        nf : NodeFlow
        logits : torch.Tensor
            The input edge feature
        Returns
        -------
        Unnormalized scores : torch.Tensor
            This part gives :math:`\exp(z_{ij})`'s
        Normalizer : torch.Tensor
            This part gives :math:`\sum_{j\in\mathcal{N}(i)}\exp(z_{ij})`
        Notes
        -----
            * Input shape: :math:`(N, *, 1)` where * means any number of additional
              dimensions, :math:`N` is the number of edges.
            * Unnormalized scores shape: :math:`(N, *, 1)` where all but the last
              dimension are the same shape as the input.
            * Normalizer shape: :math:`(M, *, 1)` where :math:`M` is the number of
              nodes and all but the first and the last dimensions are the same as
              the input.
        """
        self._logits_name = get_edata_name(nf, self.index, self._logits_name)
        self._max_logits_name = get_ndata_name(nf, self.index + 1, self._max_logits_name)
        self._normalizer_name = get_ndata_name(nf, self.index + 1, self._normalizer_name)

        nf.blocks[self.index].data[self._logits_name] = logits

        # compute the softmax
        nf.block_compute(self.index, fn.copy_edge(self._logits_name, self._logits_name),
                         fn.max(self._logits_name, self._max_logits_name))
        # minus the max and exp
        nf.apply_block(self.index, lambda edges: {
            self._logits_name : torch.exp(edges.data[self._logits_name] - edges.dst[self._max_logits_name])})

        # pop out temporary feature _max_logits, otherwise get_ndata_name could have huge overhead
        nf.layers[self.index + 1].data.pop(self._max_logits_name)
        # compute normalizer
        nf.block_compute(self.index, fn.copy_edge(self._logits_name, self._logits_name),
                         fn.sum(self._logits_name, self._normalizer_name))

        return nf.blocks[self.index].data.pop(self._logits_name), \
               nf.layers[self.index + 1].data.pop(self._normalizer_name)
Example #20
File: model.py Project: junxincai/HMSG
    def forward(self, g, feat):
        with g.local_scope():
            if self.aggre_type == 'attention':
                h_src = self.feat_drop(feat[0]).view(-1, self.num_heads, self.in_size)
                h_dst = self.feat_drop(feat[1]).view(-1, self.num_heads, self.in_size)
                el = (h_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
                # er = (h_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
                g.srcdata.update({'ft': h_src, 'el': el})
                # g.srcdata.update({'ft': h_src, 'er': er})
                g.apply_edges(fn.copy_u('el', 'e'))
                # g.apply_edges(fn.u_add_v('el', 'er', 'e'))
                e = self.leaky_relu(g.edata.pop('e'))

                g.edata['a'] = self.attn_drop(edge_softmax(g, e))
                g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
                rst = g.dstdata['ft'].flatten(1)
                if self.residual:
                    rst = rst + h_dst.flatten(1)
                if self.activation:
                    rst = self.activation(rst)

            elif self.aggre_type == 'mean':
                h_src = self.feat_drop(feat[0]).view(-1, self.in_size*self.num_heads)
                h_dst = self.feat_drop(feat[1]).view(-1, self.in_size * self.num_heads)
                g.srcdata['ft'] = h_src
                g.update_all(fn.copy_u('ft', 'm'), fn.mean('m', 'ft'))
                rst = g.dstdata['ft'] # + h_dst


            elif self.aggre_type == 'pool':
                h_src = self.feat_drop(feat[0]).view(-1, self.in_size*self.num_heads)
                h_dst = self.feat_drop(feat[1]).view(-1, self.in_size * self.num_heads)
                g.srcdata['ft'] = F.relu(self.fc_pool(h_src))
                g.update_all(fn.copy_u('ft', 'm'), fn.max('m', 'ft'))
                rst = g.dstdata['ft']  # + h_dst
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self.aggre_type))
            return rst
Example #21
    def forward(self, graph: dgl.DGLGraph,
                feats: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Args:
            graph: the graph
            feats: node features with node type as key and the corresponding
                features as value. Each tensor is of shape (N, D) where N is the number
                of nodes of the corresponding node type, and D is the feature size.

        Returns:
            updated node features. Each tensor is of shape (N, D) where N is the number
            of nodes of the corresponding node type, and D is the feature size.

        """
        graph = graph.local_var()

        # assign data
        for nt, ft in feats.items():
            graph.nodes[nt].data.update({"ft": ft})

        for et in self.etypes:
            # option 1
            graph[et].update_all(fn.copy_u("ft", "m"),
                                 fn.mean("m", "mean"),
                                 etype=et)
            graph[et].update_all(fn.copy_u("ft", "m"),
                                 fn.max("m", "max"),
                                 etype=et)

            nt = et[2]
            graph.apply_nodes(self._concatenate_node_feat, ntype=nt)

            # copy update feature from new_ft to ft
            graph.nodes[nt].data.update({"ft": graph.nodes[nt].data["new_ft"]})

        return {nt: graph.nodes[nt].data["ft"] for nt in feats}
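apply_nodes above calls self._concatenate_node_feat, which is not shown. A hypothetical implementation consistent with the fields written by the two update_all calls (the projection self.fc is an assumption):

def _concatenate_node_feat(self, nodes):
    # 'ft' is the current feature; 'mean' and 'max' were just aggregated
    ft = torch.cat([nodes.data['ft'], nodes.data['mean'], nodes.data['max']], dim=1)
    return {'new_ft': self.fc(ft)}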
Example #22
    def forward(self, g, h):
        h_in = h  # for residual connection

        if not self.dgl_builtin:
            h = self.dropout(h)
            g.ndata['h'] = h
            #g.update_all(fn.copy_src(src='h', out='m'),
            #             self.aggregator,
            #             self.nodeapply)
            if self.aggregator_type == 'maxpool':
                g.ndata['h'] = self.aggregator.linear(g.ndata['h'])
                g.ndata['h'] = self.aggregator.activation(g.ndata['h'])
                g.update_all(fn.copy_src('h', 'm'), fn.max('m', 'c'),
                             self.nodeapply)
            elif self.aggregator_type == 'lstm':
                g.update_all(fn.copy_src(src='h', out='m'), self.aggregator,
                             self.nodeapply)
            elif self.aggregator_type == 'sumpool':
                P = torch.clamp(self.P, 1, 100)
                g.ndata['h_pow'] = torch.abs(g.ndata['h']).pow(P)
                g.update_all(fn.copy_src('h_pow', 'm'), fn.sum('m', 'c'),
                             self.nodeapply)
            else:
                g.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'c'),
                             self.nodeapply)
            h = g.ndata['h']
        else:
            h = self.sageconv(g, h)

        if self.batch_norm:
            h = self.batchnorm_h(h)

        if self.residual:
            h = h_in + h  # residual connection

        return h
Example #23
  def forward(self, nf):
    if self.preprocess:
      for i in range(nf.num_layers):
        h = nf.layers[i].data.pop('features')
        neigh = nf.layers[i].data.pop('neigh')
        if self.dropout:
          h = self.dropout(h)
        h = self.fc_self(h) + self.fc_neigh(neigh)
        skip_start = (0 == self.n_layers - 1)
        if skip_start:
          h = torch.cat((h, self.activation(h)), dim=1)
        else:
          h = self.activation(h)
        nf.layers[i].data['h'] = h
    else:
      for lid in range(nf.num_layers):
        nf.layers[lid].data['h'] = nf.layers[lid].data.pop('features')

    for lid, layer in enumerate(self.layers):
      for i in range(lid, nf.num_layers - 1):
        h = nf.layers[i].data.pop('h')
        h = self.dropout(h)
        nf.layers[i].data['h'] = h
        if self.aggregator_type == 'mean':
          nf.block_compute(i,
                           fn.copy_src(src='h', out='m'),
                           fn.mean('m', 'neigh'),
                           layer)
        elif self.aggregator_type == 'gcn':
          nf.block_compute(i,
                           fn.copy_src(src='h', out='m'),
                           fn.sum('m', 'neigh'),
                           layer)
        elif self.aggregator_type == 'pool':
          nf.block_compute(i,
                           fn.copy_src(src='h', out='m'),
                           fn.max('m', 'neigh'),
                           layer)
        elif self.aggregator_type == 'lstm':
          reducer = self.reducer[i]
          def _reducer(nodes):  # closure over self; DGL calls this with a single NodeBatch
            m = nodes.mailbox['m'] # (B, L, D)
            batch_size = m.shape[0]
            h = (m.new_zeros((1, batch_size, self._in_feats)),
                 m.new_zeros((1, batch_size, self._in_feats)))
            _, (rst, _) = reducer(m, h)
            return {'neigh': rst.squeeze(0)}

          nf.block_compute(i,
                           fn.copy_src(src='h', out='m'),
                           _reducer,
                           layer)
        else:
          raise KeyError('Aggregator type {} not recognized.'.format(self.aggregator_type))
      # set up new feat
      for i in range(lid + 1, nf.num_layers):
        h = nf.layers[i].data.pop('activation')
        nf.layers[i].data['h'] = h

    h = nf.layers[nf.num_layers - 1].data.pop('h')
    return h
Example #24
    def forward(self, graph, feat):
        r"""

        Description
        -----------
        Compute GraphSAGE layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it represents the input feature of shape
            :math:`(N, D_{in})`
            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
            is size of output feature.
        """
        with graph.local_scope():
            if isinstance(feat, tuple):
                feat_src = self.feat_drop(feat[0])
                feat_dst = self.feat_drop(feat[1])
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]

            h_self = feat_dst

            # Handle the case of graphs without edges
            if graph.number_of_edges() == 0:
                graph.dstdata['neigh'] = torch.zeros(
                    feat_dst.shape[0], self._in_src_feats).to(feat_dst)

            if self._aggre_type == 'mean':
                graph.srcdata['h'] = feat_src
                graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst  # same as above if homogeneous
                graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
                # divide in_degrees
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] +
                           graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
            elif self._aggre_type == 'pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'lstm':
                graph.srcdata['h'] = feat_src
                graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'ginmean':
                graph.srcdata['h'] = feat_src
                graph.update_all(fn.copy_src('h', 'm'),
                                 self._gin_reducer('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'cheb':

                def unnLaplacian(feat, D_invsqrt_left, D_invsqrt_right, graph):
                    """Compute Feat * D^-1/2 A D^-1/2; as a matrix product: D^-1/2 A D^-1/2 Feat."""
                    #tmp = torch.zeros((D_invsqrt.shape[0],D_invsqrt.shape[0])).to(graph.device)
                    # sparse tensors do not broadcast; this also relies on the src nodes being laid out contiguously from 0 in feat
                    #print("adj : ",graph.adj(transpose=False,ctx = graph.device).shape)
                    #graph.srcdata['h'] = (torch.mm((graph.adj(transpose=False,ctx = graph.device)),(feat * D_invsqrt)))*D_invsqrt[::graph.number_of_dst_nodes()]
                    #graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'))
                    #return graph.srcdata['h']
                    graph.srcdata[
                        'h'] = feat * D_invsqrt_right  # feat is srcfeat
                    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                    return graph.dstdata.pop('h') * D_invsqrt_left

                D_invsqrt_right = torch.pow(
                    graph.out_degrees().float().clamp(min=1),
                    -0.5).unsqueeze(-1)
                D_invsqrt_left = torch.pow(
                    graph.in_degrees().float().clamp(min=1),
                    -0.5).unsqueeze(-1)
                #print("D_invsqrt shape: ",D_invsqrt.shape)
                #print(graph.__dict__)
                #print(dir(graph))
                #graph.srcdata['h']=feat_src
                #graph.dstdata['h']=feat_dst
                #g = dgl.to_homogeneous(graph,ndata=['h'])
                #dgl._ffi.base.DGLError: Expect number of features to match number of nodes (len(u)). Got 70 and 76 instead.
                #print(g)
                # the block differs on every call, so it is safe to call DGL's method each time instead of precomputing lambda_max
                try:
                    lambda_max = laplacian_lambda_max(graph)
                except BaseException:
                    # if the largest eigenvalue is not found
                    dgl_warning(
                        "Largest eigenvalue not found, using default value 2 for lambda_max",
                        RuntimeWarning)
                    lambda_max = torch.tensor(2)  # .to(feat.device)
                if isinstance(lambda_max, list):
                    lambda_max = torch.tensor(lambda_max)  # .to(feat.device)
                if lambda_max.dim() == 1:
                    lambda_max = lambda_max.unsqueeze(-1)  # (B,) to (B, 1)
                # broadcast from (B, 1) to (N, 1)
                # lambda_max = lambda_max * torch.ones((feat.shape[0],1))
                #re_norm = (2 / lambda_max ) * torch.ones((graph.number_of_dst_nodes(),1)).to(graph.device)
                re_norm = (2 / lambda_max.to(graph.device)) * torch.ones(
                    (graph.number_of_dst_nodes(), 1), device=graph.device)
                self._cheb_Xt = X_0 = feat_dst
                graph.srcdata[
                    'h'] = feat_src * D_invsqrt_right  # feat is srcfeat
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                X_1 = -re_norm * graph.dstdata['h'] * D_invsqrt_left + X_0 * (
                    re_norm - 1)
                self._cheb_Xt = torch.cat((self._cheb_Xt, X_1.float()), 1)
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(
                    self._aggre_type))

            # GraphSAGE GCN does not require fc_self.
            if self._aggre_type == 'gcn':
                rst = self.fc_neigh(h_neigh)
            elif self._aggre_type == 'ginmean':
                rst = (1 + self.eps) * h_self + h_neigh
                rst = self.fc_gin(rst)
                if self.norm is not None:
                    rst = self.norm(rst)
                return rst
            elif self._aggre_type == 'cheb':
                rst = self._cheb_linear(self._cheb_Xt)
            else:
                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst
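For reference, the 'cheb' branch computes the first Chebyshev step on the rescaled graph Laplacian. Writing \hat{A} = D^{-1/2} A D^{-1/2} and re_norm = 2/\lambda_{max}, the X_1 above is

X_1 = -\frac{2}{\lambda_{max}} \hat{A} X_0 + \Big(\frac{2}{\lambda_{max}} - 1\Big) X_0
    = \Big(\frac{2}{\lambda_{max}} (I - \hat{A}) - I\Big) X_0
    = \tilde{L} X_0,

where L = I - \hat{A} is the normalized Laplacian and \tilde{L} = (2/\lambda_{max}) L - I is its rescaling to the spectrum [-1, 1].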
Example #25
File: gcn.py Project: zeta1999/DiscoBERT
# Graph Conv and Relational Graph Conv
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F
from allennlp.common import FromParams

gcn_msg = fn.copy_src(src='h', out='m')
gcn_reduce_sum = fn.sum(msg='m', out='h')
gcn_reduce_max = fn.max(msg='m', out='h')
# gcn_reduce_u_mul_v = fn.u_mul_v('m', 'h')
from depricated.archival_gnns import GraphEncoder


class NodeApplyModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h': h}


class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        # body truncated in the source; a plausible completion wired to the
        # message/reduce builtins and NodeApplyModule defined above
        g.ndata['h'] = feature
        g.update_all(gcn_msg, gcn_reduce_sum)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')
Example #26
    def forward(self, g, features):
        '''        
        Inputs:
            g: 
                The graph
            features: 
                H^{l}, i.e. Node features with shape [num_nodes, features_per_node]
                
        Returns:
            rst:
                H^{l+1}, i.e. Node embeddings of the l+1 layer (depth) with the 
                shape [num_nodes, hidden_per_node]
                
        Variables:
            msg_func: 
                Message function, i.e. What to be aggregated 
                (e.g. Sending node embeddings)
            reduce_func: 
                Reduce function, i.e. How to aggregate 
                (e.g. Summing neighbor embeddings)
                
        Notice: 'h' means node feature/embedding itself, 'm' means node's mailbox
        '''
        # create an independent instance of the graph to manipulate
        g = g.local_var()

        # H^{k-1}_{v}
        h_self = features

        # calculate H^{k}_{N(v)} in line 4 of Algorithm 1
        # based on different aggregators
        if self._aggre_type == 'mean':
            g.ndata['h'] = features
            msg_func = fn.copy_src('h', 'm')
            reduce_func = fn.mean('m', 'neigh')
            g.update_all(msg_func, reduce_func)
            # h_neigh is H^{k}_{N(v)}
            h_neigh = g.ndata.pop('neigh')
        elif self._aggre_type == 'gcn':
            # part of equation (2) in the paper
            g.ndata['h'] = features
            msg_func = fn.copy_src('h', 'm')
            reduce_func = fn.sum('m', 'neigh')
            g.update_all(msg_func, reduce_func)
            h_neigh = g.ndata.pop('neigh')
            # H^{k-1}_{v} U H^{k-1}_{u} in equation (2)
            # g.ndata.pop('neigh') represents {H^{k-1}_{u} for u /belongs N(v)}
            # g.dstdata['h'] represents {H^{k-1}_{v}}
            h_neigh = h_neigh + g.ndata.pop('h')
            # divide in_degrees: MEAN() operation in equation (2)
            degs = g.in_degrees().to(features)
            # Notice: h_neigh is more than H^{k}_{N(u)}
            h_neigh = h_neigh / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'pool':
            g.ndata['h'] = F.relu(self.fc_pool(features))
            msg_func = fn.copy_src('h', 'm')
            reduce_func = fn.max('m', 'neigh')
            g.update_all(msg_func, reduce_func)
            # h_neigh is H^{k}_{N(v)}
            h_neigh = g.ndata.pop('neigh')
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(
                self._aggre_type))

        # calculate H^{k}_{v} in line 5 of Algorithm 1
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

        # activation
        if self._activation_func is not None:
            rst = self._activation_func(rst)

        # normalization in line 7 of Algorithm 1
        # l2_norm = torch.norm(rst, p=2, dim=1)
        # l2_norm = l2_norm.unsqueeze(1)
        # rst = torch.div(rst, l2_norm)

        return rst
Example #27
def test_update_all_multi_fallback():
    # create a graph with zero in degree nodes
    g = dgl.DGLGraph()
    g.add_nodes(10)
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    g.ndata['h'] = th.randn(10, D)
    g.edata['w1'] = th.randn(16,)
    g.edata['w2'] = th.randn(16, D)
    def _mfunc_hxw1(edges):
        return {'m1' : edges.src['h'] * th.unsqueeze(edges.data['w1'], 1)}
    def _mfunc_hxw2(edges):
        return {'m2' : edges.src['h'] * edges.data['w2']}
    def _rfunc_m1(nodes):
        return {'o1' : th.sum(nodes.mailbox['m1'], 1)}
    def _rfunc_m2(nodes):
        return {'o2' : th.sum(nodes.mailbox['m2'], 1)}
    def _rfunc_m1max(nodes):
        return {'o3' : th.max(nodes.mailbox['m1'], 1)[0]}
    def _afunc(nodes):
        ret = {}
        for k, v in nodes.data.items():
            if k.startswith('o'):
                ret[k] = 2 * v
        return ret
    # compute ground truth
    g.update_all(_mfunc_hxw1, _rfunc_m1, _afunc)
    o1 = g.ndata.pop('o1')
    g.update_all(_mfunc_hxw2, _rfunc_m2, _afunc)
    o2 = g.ndata.pop('o2')
    g.update_all(_mfunc_hxw1, _rfunc_m1max, _afunc)
    o3 = g.ndata.pop('o3')
    # v2v spmv
    g.update_all(fn.src_mul_edge(src='h', edge='w1', out='m1'),
                 fn.sum(msg='m1', out='o1'),
                 _afunc)
    assert U.allclose(o1, g.ndata.pop('o1'))
    # v2v fallback to e2v
    g.update_all(fn.src_mul_edge(src='h', edge='w2', out='m2'),
                 fn.sum(msg='m2', out='o2'),
                 _afunc)
    assert U.allclose(o2, g.ndata.pop('o2'))
    # v2v fallback to degree bucketing
    g.update_all(fn.src_mul_edge(src='h', edge='w1', out='m1'),
                 fn.max(msg='m1', out='o3'),
                 _afunc)
    assert U.allclose(o3, g.ndata.pop('o3'))
    # multi builtins, both v2v spmv
    g.update_all([fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w1', out='m2')],
                 [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
                 _afunc)
    assert U.allclose(o1, g.ndata.pop('o1'))
    assert U.allclose(o1, g.ndata.pop('o2'))
    # multi builtins, one v2v spmv, one fallback to e2v
    g.update_all([fn.src_mul_edge(src='h', edge='w1', out='m1'), fn.src_mul_edge(src='h', edge='w2', out='m2')],
                 [fn.sum(msg='m1', out='o1'), fn.sum(msg='m2', out='o2')],
                 _afunc)
    assert U.allclose(o1, g.ndata.pop('o1'))
    assert U.allclose(o2, g.ndata.pop('o2'))
    # multi builtins, one v2v spmv, one fallback to e2v, one fallback to degree-bucketing
    g.update_all([fn.src_mul_edge(src='h', edge='w1', out='m1'),
                  fn.src_mul_edge(src='h', edge='w2', out='m2'),
                  fn.src_mul_edge(src='h', edge='w1', out='m3')],
                 [fn.sum(msg='m1', out='o1'),
                  fn.sum(msg='m2', out='o2'),
                  fn.max(msg='m3', out='o3')],
                 _afunc)
    assert U.allclose(o1, g.ndata.pop('o1'))
    assert U.allclose(o2, g.ndata.pop('o2'))
    assert U.allclose(o3, g.ndata.pop('o3'))
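The same max-reduction in the current API, where src_mul_edge was renamed u_mul_e and the scalar edge weight needs an explicit trailing dimension to broadcast against the node features (a standalone sketch with an assumed toy graph):

import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.ndata['h'] = torch.randn(3, 4)
g.edata['w1'] = torch.randn(3).unsqueeze(1)    # (E,) -> (E, 1)
g.update_all(fn.u_mul_e('h', 'w1', 'm'), fn.max('m', 'o3'))
print(g.ndata['o3'].shape)  # torch.Size([3, 4])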
Example #28
    def forward(self, agg_graph: dgl.DGLGraph, prop_graph: dgl.DGLGraph,
                traversal_order, new_node_ids) -> torch.Tensor:
        tg = agg_graph.local_var()
        pg = prop_graph.local_var()

        nfeat = tg.ndata["nfeat"]
        # h_self = nfeat
        h_self = self.encode_time(nfeat, tg.ndata["timestamp"])
        tg.ndata["nfeat"] = h_self
        tg.edata["efeat"] = self.fc_edge(tg.edata["efeat"])
        # efeat = tg.edata["efeat"]
        # tg.apply_edges(lambda edges: {
        #     "efeat":
        #     torch.cat((edges.src["nfeat"], edges.data["efeat"]), dim=1)
        # })
        # tg.edata["efeat"] = self.encode_time(tg.edata["efeat"], tg.edata["timestamp"])
        degs = tg.ndata["degree"]

        # agg_graph aggregation
        if self._agg_type == "pool":
            tg.edata["efeat"] = F.relu(self.fc_pool(tg.edata["efeat"]))
            tg.update_all(fn.u_add_e("nfeat", "efeat", "m"),
                          fn.max("m", "neigh"))
            h_neigh = tg.ndata["neigh"]
        elif self._agg_type in ["mean", "gcn", "lstm"]:
            tg.update_all(fn.u_add_e("nfeat", "efeat", "m"),
                          fn.sum("m", "neigh"))
            h_neigh = tg.ndata["neigh"]
        else:
            raise KeyError("Aggregator type {} not recognized.".format(
                self._agg_type))

        pg.ndata["neigh"] = h_neigh
        # prop_graph propagation
        if False:
            if self._agg_type == "mean":
                pg.prop_nodes(traversal_order,
                              message_func=fn.copy_src("neigh", "tmp"),
                              reduce_func=fn.sum("tmp", "acc"))
                h_neigh = h_neigh + pg.ndata["acc"]
                h_neigh = h_neigh / degs.unsqueeze(-1)
            elif self._agg_type == "gcn":
                pg.prop_nodes(traversal_order,
                              message_func=fn.copy_src("neigh", "tmp"),
                              reduce_func=fn.sum("tmp", "acc"))
                h_neigh = h_neigh + pg.ndata["acc"]
                h_neigh = (h_self + h_neigh) / (degs.unsqueeze(-1) + 1)
            elif self._agg_type == "pool":
                pg.prop_nodes(traversal_order,
                              message_func=fn.copy_src("neigh", "tmp"),
                              reduce_func=fn.max("tmp", "acc"))
                h_neigh = torch.max(h_neigh, pg.ndata["acc"])
            elif self._agg_type == "lstm":
                h_neighs = [
                    self._lstm_reducer(h_neigh[ids]) for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]
        else:
            if self._agg_type == "mean":
                h_neighs = [
                    torch.cumsum(h_neigh[ids], dim=0) for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]
                h_neigh = h_neigh / degs.unsqueeze(-1)
            elif self._agg_type == "gcn":
                h_neighs = [
                    torch.cumsum(h_neigh[ids], dim=0) for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]
                h_neigh = (h_self + h_neigh) / (degs.unsqueeze(-1) + 1)
            elif self._agg_type == "pool":
                h_neighs = [
                    torch.cummax(h_neigh[ids], dim=0).values for ids in new_node_ids  # cummax returns (values, indices)
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]
            elif self._agg_type == "lstm":
                h_neighs = [
                    self._lstm_reducer(h_neigh[ids]) for ids in new_node_ids
                ]
                h_neighs = torch.cat(h_neighs, dim=0)
                ridx = torch.arange(h_neighs.shape[0])
                ridx[np.concatenate(new_node_ids)] = torch.arange(
                    h_neighs.shape[0])
                h_neigh = h_neighs[ridx]

        if self._agg_type == "gcn":
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
        return rst
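The cumulative branch orders nodes along traversal chains (new_node_ids), takes a running sum or max, then scatters the results back to the original node order through ridx. A toy demonstration of that reindexing trick, and of why torch.cummax needs .values (all data here is illustrative):

import numpy as np
import torch

h = torch.tensor([[3.], [1.], [2.], [5.]])
new_node_ids = [np.array([2, 0]), np.array([3, 1])]   # two traversal chains

# torch.cummax returns a (values, indices) namedtuple; keep the values
chunks = [torch.cummax(h[ids], dim=0).values for ids in new_node_ids]
out = torch.cat(chunks, dim=0)
ridx = torch.arange(out.shape[0])
ridx[np.concatenate(new_node_ids)] = torch.arange(out.shape[0])
out = out[ridx]
print(out.squeeze())  # tensor([3., 5., 2., 5.]): running max along each chain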
Example #29
# Sends a message of node feature h
# Equivalent to => return {'m': edges.src['h']}
# msg = fn.copy_src(src='h', out='m')
#
# def reduce(nodes):
#     accum = torch.mean(nodes.mailbox['m'], 1)
#     return {'h': accum}

# def msg_func(edges):
#     return {'m': torch.mul(edges.data['feat'], edges.src['h'])}

msg_func = fn.u_mul_e('h', 'feat', 'm')
reduce_mean = fn.mean('m', 'h')
reduce_sum = fn.sum('m', 'h')
reduce_max = fn.max('m', 'h')

# def reduce(nodes):
#     accum = torch.sum(nodes.mailbox['m'], 1)
#     return {'h': accum}


class NodeApplyModule(nn.Module):
    # Update node feature h_v with (Wh_v+b)
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, node):
        h = self.linear(node.data['h'])
        return {'h': h}
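A quick check that the builtin message function above matches the commented-out UDF (the toy graph and sizes are assumptions):

import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.ndata['h'] = torch.randn(3, 4)
g.edata['feat'] = torch.randn(3, 4)

g.update_all(fn.u_mul_e('h', 'feat', 'm'), fn.max('m', 'out'))
builtin = g.ndata.pop('out')

g.update_all(lambda edges: {'m': torch.mul(edges.data['feat'], edges.src['h'])},
             lambda nodes: {'out': torch.max(nodes.mailbox['m'], 1)[0]})
assert torch.allclose(builtin, g.ndata.pop('out'))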
Example #30
    def forward(self, g, features):
        '''        
        Inputs:
            g: 
                The graph
            features: 
                H^{l}, BLOCK.SRC and BLOCK.DST features in tuple with shape
                [N_{src}, D_{in_{src}}] and [N_{dst}, D_{in_{dst}}]
                where 'D_{in}' is size of input feature
                
        Returns:
            rst:
                H^{l+1}, Node embeddings of the l+1 layer (depth) with the 
                shape [N_{dst}, D_{out}]
                
        Variables:
            msg_func: 
                Message function, i.e. What to be aggregated 
                (e.g. Sending node embeddings)
            reduce_func: 
                Reduce function, i.e. How to aggregate 
                (e.g. Summing neighbor embeddings)
                
        Notice: 'h' means node feature/embedding itself, 'm' means node's mailbox
        '''

        # create an independent instance of the graph to manipulate
        g = g.local_var()

        # split (feature_src, feature_dst)
        feat_src = features[0]
        feat_dst = features[1]

        # H^{k-1}_{u}
        h_self = feat_dst

        # calculate H^{k}_{N(u)} in line 11 of Algorithm 2
        # different aggregators: aggregate neighbor (block.src) information
        # in this case, g.srcdata and g.dstdata will be more convenient, they
        # should be identical to g.ndata
        if self._aggre_type == 'mean':
            g.srcdata['h'] = feat_src
            msg_func = fn.copy_src('h', 'm')
            reduce_func = fn.mean('m', 'neigh')
            g.update_all(msg_func, reduce_func)
            # h_neigh is H^{k}_{N(u)}
            h_neigh = g.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            # check whether feat_src and feat_dst has the same shape
            # otherwise we can't sum later
            dgl.utils.check_eq_shape(features)
            # part of equation (2) in the paper
            g.srcdata['h'] = feat_src
            g.dstdata['h'] = feat_dst
            msg_func = fn.copy_src('h', 'm')
            reduce_func = fn.sum('m', 'neigh')
            g.update_all(msg_func, reduce_func)
            h_neigh = g.dstdata['neigh']
            # H^{k-1}_{v} U H^{k-1}_{u} in equation (2)
            # g.dstdata['neigh'] represents BLOCK.DST with aggregation from SRC
            # g.dstdata['h'] represents original BLOCK.DST without aggregation
            h_neigh = h_neigh + g.dstdata['h']
            # divide in_degrees: MEAN() operation in equation (2)
            degs = g.in_degrees().to(feat_dst)
            # Notice: h_neigh is more than H^{k}_{N(u)}
            h_neigh = h_neigh / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'pool':
            # equation (3) in the paper
            g.srcdata['h'] = self.relu(self.fc_pool(feat_src))
            msg_func = fn.copy_src('h', 'm')
            reduce_func = fn.max('m', 'neigh')
            g.update_all(msg_func, reduce_func)
            # h_neigh is H^{k}_{N(u)}
            h_neigh = g.dstdata['neigh']
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(
                self._aggre_type))

        # calculate H^{k}_{v} in line 11 of Algorithm 2
        # Notice: the GCN aggregator differs from the others, see equation (2)
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            # line 12 of Algorithm 2
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

        # activation
        if self._activation_func is not None:
            rst = self._activation_func(rst)

        # normalization in line 13 of Algorithm 2
        # l2_norm = torch.norm(rst, p=2, dim=1)
        # l2_norm = l2_norm.unsqueeze(1)
        # rst = torch.div(rst, l2_norm)

        return rst
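The normalization commented out above divides by the L2 norm with no guard against zero rows; torch.nn.functional.normalize performs the same operation with an eps (a drop-in sketch, not from the original project):

import torch
import torch.nn.functional as F

rst = torch.randn(5, 8)
rst = F.normalize(rst, p=2.0, dim=1)   # row-wise L2 normalization with eps guard
print(rst.norm(dim=1))                 # ~1.0 for every row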