Exemplo n.º 1
0
    def forward(self, graph: dgl.DGLHeteroGraph, **kwargs):
        """Run one round of per-relation mean-aggregation message passing.

        For each canonical edge type ``(src_type, rel, dst_type)``:
          - every source node's ``'data'`` feature is passed through the module
            registered under ``rel`` in ``self.edges_modules``;
          - each destination node's ``'data'`` feature is replaced by the mean
            of its incoming messages.

        Relations are processed one after another, so a relation handled later
        sees features already written by an earlier one.

        Args:
            graph: heterograph whose nodes carry a ``'data'`` feature.
            **kwargs: ignored; accepted for interface compatibility.

        Returns:
            The same ``graph`` object, with its ``'data'`` node features
            updated in place.
        """

        def message_func(edges):
            # The transform is selected by the relation name (middle element).
            rel_name = edges.canonical_etype[1]
            return {'m': self.edges_modules[rel_name](edges.src['data'])}

        def reduce_func(nodes):
            # Mean over the incoming messages collected in each node's mailbox.
            return {'data': torch.mean(nodes.mailbox['m'], dim=1)}

        # Deliberately NOT inside graph.local_scope(): a local scope discards
        # every feature write when it exits, which made the previous version
        # return `graph` with its features unchanged.
        # Iterate canonical (src, rel, dst) triples instead of bare names:
        # graph.etypes repeats a name shared by several node-type pairs, and
        # slicing graph[name] with such a name is ambiguous.
        for canonical_etype in graph.canonical_etypes:
            graph.update_all(
                message_func=message_func,
                reduce_func=reduce_func,
                etype=canonical_etype,
            )
        # NOTE(review): each update rewrites 'data' on its destination node
        # type, so when several relations point at the same node type the
        # later ones may overwrite the earlier aggregates. If a combined
        # result was intended, consider graph.multi_update_all(...) with a
        # cross-type reducer.
        return graph
Exemplo n.º 2
0
    def _propagate_user_to_item(self, block: DGLHeteroGraph) -> th.Tensor:
        """Aggregate per-rating messages onto destination items and fuse them.

        For every relation ``rel`` in ``self._rating_set``:
          - edge messages are computed by ``self._compute_message_user_to_item``,
            which is expected to produce edge feature ``'m'`` (read by the
            ``copy_e`` below);
          - the messages are mean-reduced into destination feature ``h_<rel>``.

        The destination items' ``'item_features'`` are then concatenated with
        every ``h_<rel>``, in ``self._rating_set`` order. When a relation left
        no ``h_<rel>``, a zero block of width ``self._embedding_dim`` takes its
        place. The result goes through ``self._item_aggregate_layer`` and then
        ``self._agg_activation``.

        All temporary features live inside ``block.local_scope()`` and are
        discarded on return.
        """
        with block.local_scope():
            # Step 1: one apply_edges + mean-aggregation pass per rating relation.
            for rel in self._rating_set:
                def _edge_udf(edges, rel=rel):
                    return self._compute_message_user_to_item(edges, rel)

                block.apply_edges(_edge_udf, etype=rel)
                block.update_all(
                    dgl_fn.copy_e("m", f"m_{rel}"),
                    dgl_fn.mean(f"m_{rel}", f"h_{rel}"),
                    etype=rel,
                )

            # Step 2: [item_features | h_<r1> | h_<r2> | ...] along dim 1.
            dst_data = block.dstnodes["item"].data
            base: th.Tensor = dst_data["item_features"]
            pieces = [base]
            for rel in self._rating_set:
                key = f"h_{rel}"
                if key not in dst_data:
                    pieces.append(
                        th.zeros(
                            base.shape[0],
                            self._embedding_dim,
                            dtype=base.dtype,
                            device=base.device,
                        )
                    )
                else:
                    pieces.append(dst_data[key])

            fused = th.cat(pieces, dim=1)
            return self._agg_activation(self._item_aggregate_layer(fused))
Exemplo n.º 3
0
    def forward(self, graph: dgl.DGLHeteroGraph):
        """Encode each graph in a (possibly batched) graph into one output vector.

        Pipeline:
          1. optionally embed the raw node ``'data'`` with ``self.embed``;
          2. transform it with ``self.embed_modules``;
          3. run ``self.gnn_modules`` on the graph;
          4. concatenate the pre-GNN embedding with the GNN output per node;
          5. readout:
             - ``self.score_mlp`` scores are softmax-normalised over the nodes
               of each graph (``dgl.softmax_nodes``);
             - those weights multiply ``self.transform_mlp`` outputs;
             - the products are summed per graph (``dgl.sum_nodes``);
          6. project the pooled vectors with ``self.out_linear``.

        Every node feature written here is inside ``graph.local_scope()``, so
        the caller's ``graph.ndata`` is left untouched. The previous version
        overwrote ``graph.ndata['data']`` in place. A second forward pass on
        the same graph object would then have re-embedded features that were
        already embedded.

        Args:
            graph: graph whose nodes carry a ``'data'`` feature.

        Returns:
            The output of ``self.out_linear`` applied to the pooled per-graph
            vectors.
        """
        with graph.local_scope():
            # Embedding block.
            if self.embed is not None:
                graph.ndata['data'] = self.embed(graph.ndata['data'])
            src_embed = graph.ndata['data']  # author's shape note: [B*N, 100]
            graph.ndata['data'] = self.embed_modules(src_embed)  # author's shape note: [B*N, 64]

            # MLP-GNN. If it writes onto this same graph object, those writes
            # are scoped as well.
            graph = self.gnn_modules(graph)

            # Skip connection: keep the pre-GNN embedding next to the GNN output.
            node_feats = torch.cat((src_embed, graph.ndata['data']), dim=-1)  # author's shape note: [B*N, 164]

            # Attention-weighted sum readout.
            graph.ndata['scoring_out'] = self.score_mlp(node_feats)
            weights = dgl.softmax_nodes(graph, 'scoring_out')
            graph.ndata['node_embed'] = weights * self.transform_mlp(node_feats)  # author's shape note: [B*N, 8]
            pooled = dgl.sum_nodes(graph, 'node_embed')
        return self.out_linear(pooled)
Exemplo n.º 4
0
    def _propagate_item_to_user(self, block: DGLHeteroGraph) -> th.Tensor:
        """Aggregate per-rating reverse messages onto destination users and fuse them.

        Steps:
          1. Store ``self._item_id_embedding_layer(<original item ids>)`` on the
             source items as ``'item_id_embedding'`` (presumably consumed by
             ``self._compute_message_item_to_user`` — confirm).
          2. For every reverse relation ``rev-<rating>``:
             - compute edge messages with ``self._compute_message_item_to_user``
               (expected to produce edge feature ``'m'``, read by ``copy_e``);
             - mean-reduce them into ``h_rev-<rating>`` on destination users.
          3. Concatenate the users' ``'user_features'`` with every
             ``h_rev-<rating>``, in ``self._rating_set`` order. A zero block of
             width ``self._embedding_dim`` stands in for a missing one.
          4. Pass the result through ``self._user_aggregate_layer`` and then
             ``self._agg_activation``.

        All temporary features live inside ``block.local_scope()``.
        """
        with block.local_scope():
            item_src = block.srcnodes["item"].data
            item_src["item_id_embedding"] = self._item_id_embedding_layer(item_src[dgl.NID])

            reverse_etypes = [f"rev-{rating}" for rating in self._rating_set]
            for rev_etype in reverse_etypes:
                def _edge_udf(edges, rev_etype=rev_etype):
                    return self._compute_message_item_to_user(edges, rev_etype)

                block.apply_edges(_edge_udf, etype=rev_etype)
                block.update_all(
                    dgl_fn.copy_e("m", f"m_{rev_etype}"),
                    dgl_fn.mean(f"m_{rev_etype}", f"h_{rev_etype}"),
                    etype=rev_etype,
                )

            user_dst = block.dstnodes["user"].data
            own_features = user_dst["user_features"]
            chunks = [own_features]
            for rating in self._rating_set:
                key = f"h_rev-{rating}"
                chunks.append(
                    user_dst[key]
                    if key in user_dst
                    else th.zeros(
                        own_features.shape[0],
                        self._embedding_dim,
                        dtype=own_features.dtype,
                        device=own_features.device,
                    )
                )

            stacked = th.cat(chunks, dim=1)
            return self._agg_activation(self._user_aggregate_layer(stacked))
Exemplo n.º 5
0
 def forward(self, decode_graph: dgl.DGLHeteroGraph, node_representations: Tensor):
     """Score every edge of ``decode_graph`` by a source/destination dot product.

     ``node_representations`` is attached as node feature ``"h"``.
     ``u_dot_v`` writes the edge feature ``"logits"``, which is returned.
     The temporary features are scoped, so ``decode_graph`` is left unchanged.
     """
     edge_dot = dgl.function.u_dot_v("h", "h", "logits")
     with decode_graph.local_scope():
         decode_graph.ndata["h"] = node_representations
         decode_graph.apply_edges(edge_dot)
         edge_logits = decode_graph.edata["logits"]
     return edge_logits
Exemplo n.º 6
0
 def forward(self, g: dgl.DGLHeteroGraph, h, etype='interact'):
     """Predict a score for every edge of relation ``etype``.

     ``h['user']`` and ``h['item']`` each pass through ``self.dropout`` and are
     stored as node feature ``'h'``. ``self.apply_edges`` is then run as the
     edge function; it is expected to write edge feature ``'score'``. That
     score is returned multiplied by 5. All features are scoped to
     ``g.local_scope()``.
     """
     with g.local_scope():
         # Same order as before (user, then item) so dropout draws are unchanged.
         for ntype in ('user', 'item'):
             g.nodes[ntype].data['h'] = self.dropout(h[ntype])
         g.apply_edges(self.apply_edges, etype=etype)
         raw_score = g.edges[etype].data['score']
     # NOTE(review): 5 presumably maps the raw score onto a 0-5 rating scale — confirm.
     return raw_score * 5