コード例 #1
0
ファイル: modules.py プロジェクト: marcel-n/egnn
    def forward(self, v, k: Dict = None, q: Dict = None, G=None, **kwargs):
        """Run multi-head attention message passing over the DGL graph ``G``.

        All features are attached inside ``G.local_scope()``, so the caller's
        graph feature dicts are left as they were.

        Args:
            v: dict of value edge-features keyed by degree as a string
                (``v[str(d)]``) for each ``(m, d)`` in ``self.f_value.structure``.
            k: dict of key edge-features, stacked per head with ``fiber2head``.
            q: dict of query node-features, stacked per head with ``fiber2head``.
            G: minibatch of (homogeneous) graphs.
            **kwargs: accepted and ignored.

        Returns:
            dict mapping ``str(d)`` to a tensor viewed as ``[-1, m, 2 * d + 1]``
            for every ``(m, d)`` in ``self.f_value.structure``.
        """
        heads = self.n_heads
        with G.local_scope():
            # Value features: split the channel axis into heads, keeping the
            # (2d+1)-sized axis last (stacked representation for attention).
            for channels, degree in self.f_value.structure:
                per_head = channels // heads
                G.edata[f'v{degree}'] = v[f'{degree}'].view(
                    -1, heads, per_head, 2 * degree + 1)
            # Keys live on edges, queries on nodes, both in stacked head form.
            G.edata['k'] = fiber2head(k, heads, self.f_key, squeeze=True)
            G.ndata['q'] = fiber2head(q, heads, self.f_key, squeeze=True)

            # Attention logits: dot product of each edge's key with the query
            # stored on that edge's destination node.
            G.apply_edges(fn.e_dot_v('k', 'q', 'e'))
            logits = G.edata.pop('e')
            if self.new_dgl:
                # With newer DGL (flagged by self.new_dgl) the result carries
                # an extra dimension; flatten it back to [n_edges, n_heads].
                logits = logits.view([G.edata['k'].shape[0], heads])
            # Scale by sqrt of the key feature count, then normalise.
            scale = np.sqrt(self.f_key.n_features)
            G.edata['a'] = edge_softmax(G, logits / scale)

            # Attention-weighted aggregation, one message-passing round per
            # value degree.
            for degree in self.f_value.degrees:
                G.update_all(self.udf_u_mul_e(degree),
                             fn.sum('m', f'out{degree}'))

            return {
                f'{degree}': G.ndata[f'out{degree}'].view(
                    -1, channels, 2 * degree + 1)
                for channels, degree in self.f_value.structure
            }
コード例 #2
0
ファイル: model_sampling.py プロジェクト: yuk12/dgl
    def forward(self, g, f_feat, b_feat, u_feat, v_feat):
        """Apply one attention-based convolution to a bipartite u/v graph.

        Edge type ``'forward'`` connects u nodes to v nodes and ``'backward'``
        connects v nodes to u nodes (the feature lookups below require this).
        Every intermediate feature ('h', 'he_*', 'h_ve', 'h_ue', 'h_att_*',
        'edotv', 'sfm', 'attn', 'agg_*') is written straight onto ``g`` and
        left there; no ``local_scope`` is used.

        Args:
            g: DGL graph/block with node types 'u', 'v' and edge types
                'forward', 'backward'.
            f_feat: features of the 'forward' edges.
            b_feat: features of the 'backward' edges.
            u_feat: source-side u node features; its first
                ``g.number_of_dst_nodes(ntype='u')`` rows are used as the
                destination-side u features.
            v_feat: same as ``u_feat``, for v nodes.

        Returns:
            Tuple ``(hf, hb, hu, hv)``: new forward-edge, backward-edge,
            destination u-node and destination v-node representations.
        """
        act = self.activation
        src_u, src_v = g.srcnodes['u'].data, g.srcnodes['v'].data
        dst_u, dst_v = g.dstnodes['u'].data, g.dstnodes['v'].data
        fwd, bwd = g.edges['forward'].data, g.edges['backward'].data

        # Raw features; destination nodes take the leading rows of the
        # source-side tensors.
        src_u['h'] = u_feat
        src_v['h'] = v_feat
        dst_u['h'] = u_feat[:g.number_of_dst_nodes(ntype='u')]
        dst_v['h'] = v_feat[:g.number_of_dst_nodes(ntype='v')]
        fwd['h'] = f_feat
        bwd['h'] = b_feat

        # Formulas 3 and 4 (memory-saving form): transform the edge feature
        # and both endpoint features separately, then sum them per edge.
        src_u['he_u'] = self.u_linear(src_u['h'])
        src_v['he_v'] = self.v_linear(src_v['h'])
        dst_u['he_u'] = self.u_linear(dst_u['h'])
        dst_v['he_v'] = self.v_linear(dst_v['h'])
        fwd['he_e'] = self.e_linear(f_feat)
        bwd['he_e'] = self.e_linear(b_feat)

        def _edge_emb_backward(edges):
            # backward edges: source is a v node, destination is a u node
            return {'he': (edges.data['he_e'] + edges.dst['he_u'])
                    + edges.src['he_v']}

        def _edge_emb_forward(edges):
            # forward edges: source is a u node, destination is a v node
            return {'he': (edges.data['he_e'] + edges.src['he_u'])
                    + edges.dst['he_v']}

        g.apply_edges(_edge_emb_backward, etype='backward')
        g.apply_edges(_edge_emb_forward, etype='forward')
        hf, hb = fwd['he'], bwd['he']
        if act is not None:
            hf, hb = act(hf), act(hb)

        # Formula 6: message = concat(source-node feature, edge feature).
        def _src_edge_concat(out_key):
            def _udf(edges):
                return {out_key: th.cat([edges.src['h'], edges.data['h']], -1)}
            return _udf

        g.apply_edges(_src_edge_concat('h_ve'), etype='backward')
        g.apply_edges(_src_edge_concat('h_ue'), etype='forward')

        # Formula 7 (self-attention): map node features with W_ATTN_* so
        # they can be dotted with the concatenated messages ...
        src_u['h_att_u'] = self.W_ATTN_u(src_u['h'])
        src_v['h_att_v'] = self.W_ATTN_v(src_v['h'])
        dst_u['h_att_u'] = self.W_ATTN_u(dst_u['h'])
        dst_v['h_att_v'] = self.W_ATTN_v(dst_v['h'])

        # ... score each message against its destination node ...
        g.apply_edges(fn.e_dot_v('h_ve', 'h_att_u', 'edotv'), etype='backward')
        g.apply_edges(fn.e_dot_v('h_ue', 'h_att_v', 'edotv'), etype='forward')

        # ... normalise the scores with edge_softmax ...
        bwd['sfm'] = edge_softmax(g['backward'], bwd['edotv'])
        fwd['sfm'] = edge_softmax(g['forward'], fwd['edotv'])

        # ... weight each message by its normalised score ...
        def _scaled(msg_key):
            def _udf(edges):
                return {'attn': edges.data[msg_key] * edges.data['sfm']}
            return _udf

        g.apply_edges(_scaled('h_ve'), etype='backward')
        g.apply_edges(_scaled('h_ue'), etype='forward')

        # ... and sum the weighted messages onto the destination nodes
        # (u via 'backward', v via 'forward').
        g.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_u'),
                     etype='backward')
        g.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_v'),
                     etype='forward')

        # Formula 5: transform the aggregated neighbourhood.
        h_nu = self.W_u(dst_u['agg_u'])
        h_nv = self.W_v(dst_v['agg_v'])
        if act is not None:
            h_nu, h_nv = act(h_nu), act(h_nv)

        # Dropout, applied in the fixed order hf, hb, h_nu, h_nv.
        hf, hb, h_nu, h_nv = [self.dropout(t) for t in (hf, hb, h_nu, h_nv)]

        # Formula 8: concatenate the node's own transform with its
        # neighbourhood summary.
        hu = th.cat([self.Vu(dst_u['h']), h_nu], -1)
        hv = th.cat([self.Vv(dst_v['h']), h_nv], -1)
        return hf, hb, hu, hv