def broadcast_sender_nodes_to_edges(
    hypergraph: HypergraphsTuple,
) -> torch.Tensor:
    """
    Gather the sender-node features of every hyperedge onto that hyperedge.

    Roughly the PyTorch counterpart of
    ``tf.gather(hypergraph.nodes, hypergraph.senders)``; the gathered features
    are then flattened to ``hypergraph.total_n_edge`` rows.

    :param hypergraph: HypergraphsTuple whose ``senders`` index into its nodes
    :return: tensor of shape (total_n_edge, -1) holding the gathered sender
        features of each hyperedge
    """
    # When zero padding is enabled, index into ``nodes_with_zero_last()``
    # rather than the raw node features (presumably so padded sender slots
    # pick up a zero row -- confirm against HypergraphsTuple).
    node_table = (
        hypergraph.nodes_with_zero_last()
        if hypergraph.zero_padding
        else hypergraph.nodes
    )
    gathered = node_table[hypergraph.senders]
    return gathered.reshape(hypergraph.total_n_edge, -1)
def forward(self, hypergraph: HypergraphsTuple):
    """
    Compute updated global features for ``hypergraph``.

    The enabled inputs are aggregated edges, aggregated nodes and the current
    globals, taken in that order. They are concatenated along the last
    dimension, so the number of globals is unchanged and only their width
    grows. The result is passed through ``self._global_model``.

    :param hypergraph: input HypergraphsTuple
    :return: ``hypergraph.replace(globals=...)`` with the model output
    """
    # Thunks keep disabled sources from ever being evaluated.
    sources = (
        (self._use_edges, lambda: self._edges_aggregator(hypergraph)),
        (self._use_nodes, lambda: self._nodes_aggregator(hypergraph)),
        (self._use_globals, lambda: hypergraph.globals),
    )
    global_inputs = torch.cat(
        [build() for enabled, build in sources if enabled], dim=-1
    )
    return hypergraph.replace(globals=self._global_model(global_inputs))
def forward(self, hypergraph: HypergraphsTuple):
    """
    Compute updated node features for ``hypergraph``.

    The enabled inputs are aggregated received edges, aggregated sent edges,
    the current node features and the globals broadcast to nodes, taken in
    that order. They are concatenated along the last dimension, so the number
    of nodes is unchanged and only each node's feature width grows. The
    result is fed to ``self._node_model``.

    :param hypergraph: input HypergraphsTuple
    :return: ``hypergraph.replace(nodes=...)`` with the model output
    """
    # Thunks keep disabled sources from ever being evaluated.
    sources = (
        (
            self._use_received_edges,
            lambda: self._received_edges_aggregator(hypergraph),
        ),
        (
            self._use_sent_edges,
            lambda: self._sent_edges_aggregator(hypergraph),
        ),
        (self._use_nodes, lambda: hypergraph.nodes),
        (self._use_globals, lambda: broadcast_globals_to_nodes(hypergraph)),
    )
    node_inputs = torch.cat(
        [build() for enabled, build in sources if enabled], dim=-1
    )
    return hypergraph.replace(nodes=self._node_model(node_inputs))
def forward(self, hypergraph: HypergraphsTuple):
    """
    Compute updated hyperedge features for ``hypergraph``.

    The enabled inputs are the current edge features, receiver-node features
    broadcast to edges, sender-node features broadcast to edges and the
    globals broadcast to edges, taken in that order. They are concatenated
    along the last dimension, so the number of edges is unchanged and only
    each edge's feature width grows. The result is fed to
    ``self._edge_model``.

    :param hypergraph: input HypergraphsTuple
    :return: ``hypergraph.replace(edges=...)`` with the model output
    """
    # Thunks keep disabled sources from ever being evaluated.
    sources = (
        (self._use_edges, lambda: hypergraph.edges),
        (
            self._use_receiver_nodes,
            lambda: broadcast_receiver_nodes_to_edges(hypergraph),
        ),
        (
            self._use_sender_nodes,
            lambda: broadcast_sender_nodes_to_edges(hypergraph),
        ),
        (self._use_globals, lambda: broadcast_globals_to_edges(hypergraph)),
    )
    edge_inputs = torch.cat(
        [build() for enabled, build in sources if enabled], dim=-1
    )
    return hypergraph.replace(edges=self._edge_model(edge_inputs))
def forward(self, hypergraph: HypergraphsTuple):
    """
    Update edges, nodes and globals independently of one another.

    Each feature set is passed through its own model with no information
    exchanged between them: ``self._edge_model``, ``self._node_model`` and
    ``self._global_model``, evaluated in that order.

    :param hypergraph: input HypergraphsTuple
    :return: ``hypergraph.replace(...)`` with all three feature sets updated
    """
    new_edges = self._edge_model(hypergraph.edges)
    new_nodes = self._node_model(hypergraph.nodes)
    new_globals = self._global_model(hypergraph.globals)
    return hypergraph.replace(
        edges=new_edges,
        nodes=new_nodes,
        globals=new_globals,
    )
def _pad_hyperedge_node_indices(
    hypergraph: HypergraphView,
    attr_name: str,
    k: int,
    pad_func: Callable[[list, int], list],
) -> torch.Tensor:
    """
    Build the padded node-index matrix for one side of every hyperedge.

    For each hyperedge in ``hypergraph.hyperedges``, the nodes stored in the
    attribute ``attr_name`` (``"receivers"`` or ``"senders"``) are sorted,
    mapped to node indices with ``hypergraph.node_to_idx``, and padded with
    ``pad_func(indices, k)``.

    :param hypergraph: HypergraphView whose hyperedges are read
    :param attr_name: name of the hyperedge attribute holding the nodes
    :param k: target length handed to ``pad_func``
    :param pad_func: pads a list of node indices up to ``k`` entries
    :return: LongTensor with one row per hyperedge (rows are as long as
        ``pad_func`` makes them -- presumably exactly ``k``; confirm for
        custom pad functions). With no hyperedges the shape is (0, k).
    """
    rows = [
        pad_func(
            [
                # FIXME: carried over from the original code (reason not
                # recorded there)
                hypergraph.node_to_idx(atom)
                for atom in sorted(getattr(hyperedge, attr_name))
            ],
            k,
        )
        for hyperedge in hypergraph.hyperedges
    ]
    indices = torch.LongTensor(rows)
    if not rows:
        # torch.LongTensor([]) is 1-D with shape (0,). Give it the
        # (0, k) matrix shape so the rank matches the non-empty case
        # when tuples are later concatenated into a batch.
        indices = indices.reshape(0, k)
    return indices


def hypergraph_view_to_hypergraphs_tuple(
    hypergraph: HypergraphView,
    receiver_k: int,
    sender_k: int,
    node_features: Optional[torch.Tensor] = None,
    edge_features: Optional[torch.Tensor] = None,
    global_features: Optional[torch.Tensor] = None,
    pad_func: Callable[[list, int], list] = pad_with_obj_up_to_k,
) -> HypergraphsTuple:
    """
    Convert a Delete-Relaxation Task (given as a HypergraphView) to a
    HypergraphsTuple, optionally attaching node/edge/global features.

    :param hypergraph: HypergraphView to convert
    :param receiver_k: maximum number of receivers for a hyperedge; each
        hyperedge's receiver indices are padded to this length with
        ``pad_func``
    :param sender_k: maximum number of senders for a hyperedge; each
        hyperedge's sender indices are padded to this length with
        ``pad_func``
    :param node_features: optional node features, checked by
        ``_validate_features`` against the number of nodes
    :param edge_features: optional edge features, checked by
        ``_validate_features`` against the number of hyperedges
    :param global_features: optional global features
    :param pad_func: function padding a list of node indices up to k entries
        (handles hyperedges with differing numbers of senders/receivers)
    :return: HypergraphsTuple for this single hypergraph. ZERO_PADDING is
        True exactly when ``pad_func`` is ``pad_with_obj_up_to_k``.
    """
    # Receivers are the additive effects for each action
    receivers = _pad_hyperedge_node_indices(
        hypergraph, "receivers", receiver_k, pad_func
    )
    # Senders are preconditions for each action
    senders = _pad_hyperedge_node_indices(
        hypergraph, "senders", sender_k, pad_func
    )

    # Validate features
    _validate_features(node_features, len(hypergraph.nodes), "Nodes")
    _validate_features(edge_features, len(hypergraph.hyperedges), "Edges")
    if global_features is not None:
        # NOTE(review): this compares global_features against its own length,
        # so it cannot catch a wrong number of global rows.
        # merge_hypergraphs_tuple expects one global row per hypergraph, so
        # the intended count is probably 1. Confirm _validate_features'
        # semantics before tightening this check.
        _validate_features(global_features, len(global_features), "Global")

    params = {
        N_NODE: torch.LongTensor([len(hypergraph.nodes)]),
        N_EDGE: torch.LongTensor([len(hypergraph.hyperedges)]),
        # Hyperedge connection information
        RECEIVERS: receivers,
        SENDERS: senders,
        # Optional features (each may be None)
        NODES: node_features,
        EDGES: edge_features,
        GLOBALS: global_features,
        ZERO_PADDING: pad_func is pad_with_obj_up_to_k,
    }
    return HypergraphsTuple(**params)
def merge_hypergraphs_tuple(
    graphs_tuple_list: List[HypergraphsTuple],
) -> HypergraphsTuple:
    """
    Batch several HypergraphsTuples (each holding one hypergraph) into one.

    Each attribute is concatenated along dim 0 across the inputs. Inputs
    whose value for that attribute is None are skipped, and an attribute
    that is None everywhere stays None. Stacked results that come out 1-D
    are turned into column vectors, except for the N_NODE / N_EDGE counts.

    :param graphs_tuple_list: non-empty list of single-hypergraph tuples that
        all share the same ``zero_padding`` setting
    :return: merged HypergraphsTuple
    :raises AssertionError: if the list is empty, padding settings differ, or
        the merged sizes are inconsistent
    """
    assert len(graphs_tuple_list) > 0

    def _concat(attr_name, as_matrix):
        """Concatenate the non-None ``attr_name`` values, or return None."""
        present = []
        for h_tup in graphs_tuple_list:
            value = getattr(h_tup, attr_name)
            if value is not None:
                present.append(value)
        if not present:
            return None
        merged = torch.cat(present)
        if as_matrix and merged.dim() == 1:
            # Promote a 1-D result to a column vector
            merged = merged.reshape(-1, 1)
        return merged

    n_node = _concat(N_NODE, False)
    n_edge = _concat(N_EDGE, False)
    # NOTE(review): indices are concatenated as-is, with no per-graph node
    # offset added. Presumably they are already global, or HypergraphsTuple
    # accounts for this -- confirm.
    receivers = _concat(RECEIVERS, True)
    senders = _concat(SENDERS, True)
    nodes = _concat(NODES, True)
    edges = _concat(EDGES, True)
    globals_ = _concat(GLOBALS, True)

    # All inputs must agree on the padding scheme
    padding_flags = {h.zero_padding for h in graphs_tuple_list}
    assert len(padding_flags) == 1
    zero_padding = graphs_tuple_list[0].zero_padding

    # Sanity-check the merged sizes. The first assert implies each input
    # contributes exactly one N_NODE / N_EDGE entry.
    assert len(n_node) == len(n_edge) == len(graphs_tuple_list)
    assert receivers.shape[0] == senders.shape[0] == torch.sum(n_edge)
    if edges is not None:
        assert edges.shape[0] == torch.sum(n_edge)
    if nodes is not None:
        assert nodes.shape[0] == torch.sum(n_node)
    if globals_ is not None:
        assert globals_.shape[0] == len(graphs_tuple_list)

    merged_params = {
        N_NODE: n_node,
        N_EDGE: n_edge,
        # Hyperedge connection information
        RECEIVERS: receivers,
        SENDERS: senders,
        # Features, turn them to tensors
        NODES: nodes,
        EDGES: edges,
        GLOBALS: globals_,
        ZERO_PADDING: zero_padding,
    }
    return HypergraphsTuple(**merged_params)