Beispiel #1
0
def broadcast_sender_nodes_to_edges(
    hypergraph: HypergraphsTuple
) -> torch.Tensor:
    """
    Gather sender-node features for every hyperedge.

    Roughly the PyTorch counterpart of
    ``tf.gather(hypergraph.nodes, hypergraph.senders)``. The node feature
    table is indexed with ``hypergraph.senders``. The result is reshaped to
    ``(hypergraph.total_n_edge, -1)``, i.e. one row per hyperedge, assuming
    ``senders`` holds ``total_n_edge`` rows.

    When ``hypergraph.zero_padding`` is truthy, the table is taken from
    ``hypergraph.nodes_with_zero_last()`` instead of ``hypergraph.nodes``.
    Presumably this lets padded sender indices read a zero row; confirm
    against ``nodes_with_zero_last``.
    """
    node_table = (
        hypergraph.nodes_with_zero_last()
        if hypergraph.zero_padding
        else hypergraph.nodes
    )
    gathered = node_table[hypergraph.senders]
    return gathered.reshape(hypergraph.total_n_edge, -1)
Beispiel #2
0
    def forward(self, hypergraph: HypergraphsTuple):
        """
        Compute new global features and return a copy of ``hypergraph``.

        Up to three inputs are used, in this order and only when the matching
        flag is truthy:
        - ``self._edges_aggregator(hypergraph)``
        - ``self._nodes_aggregator(hypergraph)``
        - ``hypergraph.globals``

        They are joined along the last axis. The number of globals therefore
        stays the same and only each global's feature width grows. The joined
        tensor is passed through ``self._global_model``.

        NOTE(review): if every flag is falsy the input list is empty and
        ``torch.cat`` raises. Presumably the constructor rules that out;
        confirm.
        """
        # (enabled?, producer) pairs. Producers are deferred so that disabled
        # inputs are never computed.
        sources = (
            (self._use_edges, lambda: self._edges_aggregator(hypergraph)),
            (self._use_nodes, lambda: self._nodes_aggregator(hypergraph)),
            (self._use_globals, lambda: hypergraph.globals),
        )
        collected = [produce() for enabled, produce in sources if enabled]
        return hypergraph.replace(
            globals=self._global_model(torch.cat(collected, -1))
        )
Beispiel #3
0
    def forward(self, hypergraph: HypergraphsTuple):
        """
        Compute new node features and return a copy of ``hypergraph``.

        Inputs are taken in this order, each only when its flag is truthy:
        - aggregated received edges
        - aggregated sent edges
        - the current node features
        - the globals broadcast to nodes

        They are joined along the last axis, so the node count is kept and
        only each node's feature width grows. The result is then passed
        through ``self._node_model``.
        """
        def _enabled_inputs():
            # Yields lazily, so each flag is checked right before its input
            # is computed.
            if self._use_received_edges:
                yield self._received_edges_aggregator(hypergraph)
            if self._use_sent_edges:
                yield self._sent_edges_aggregator(hypergraph)
            if self._use_nodes:
                yield hypergraph.nodes
            if self._use_globals:
                yield broadcast_globals_to_nodes(hypergraph)

        joined = torch.cat(list(_enabled_inputs()), -1)
        return hypergraph.replace(nodes=self._node_model(joined))
Beispiel #4
0
    def forward(self, hypergraph: HypergraphsTuple):
        """
        Compute new hyperedge features and return a copy of ``hypergraph``.

        Enabled inputs, in order (each only when its flag is truthy):
        - the current edge features
        - receiver-node features broadcast to edges
        - sender-node features broadcast to edges
        - globals broadcast to edges

        They are passed through ``self._edge_model`` after concatenation.
        """
        def _enabled_inputs():
            if self._use_edges:
                yield hypergraph.edges
            if self._use_receiver_nodes:
                yield broadcast_receiver_nodes_to_edges(hypergraph)
            if self._use_sender_nodes:
                yield broadcast_sender_nodes_to_edges(hypergraph)
            if self._use_globals:
                yield broadcast_globals_to_edges(hypergraph)

        # Joining on the last axis keeps the number of edges and only widens
        # each edge's feature vector.
        joined = torch.cat(list(_enabled_inputs()), -1)
        return hypergraph.replace(edges=self._edge_model(joined))
Beispiel #5
0
 def forward(self, hypergraph: HypergraphsTuple):
     """
     Update edges, nodes and globals independently and return a copy of
     ``hypergraph`` carrying the results.

     Each field is transformed only by its own model; no information moves
     between edges, nodes and globals.
     """
     new_edges = self._edge_model(hypergraph.edges)
     new_nodes = self._node_model(hypergraph.nodes)
     new_globals = self._global_model(hypergraph.globals)
     return hypergraph.replace(
         edges=new_edges, nodes=new_nodes, globals=new_globals
     )
def _padded_endpoint_indices(
    hypergraph: HypergraphView,
    endpoint: str,
    k: int,
    pad_func: Callable[[list, int], list],
) -> torch.Tensor:
    """
    Build the index tensor for one side ("receivers" or "senders") of every
    hyperedge in ``hypergraph``.

    Row ``i`` is ``pad_func(indices, k)``. Here ``indices`` are the node
    indices (from ``hypergraph.node_to_idx``) of the atoms in
    ``getattr(hyperedge_i, endpoint)``, taken in sorted order.

    If there are no hyperedges, the result is reshaped to ``(0, k)`` so it is
    always 2-D. ``torch.LongTensor([])`` on its own is 1-D with shape
    ``(0,)``.
    """
    rows = [
        pad_func(
            [
                # FIXME: carried over from the original code; the reason for
                # this marker was not recorded.
                hypergraph.node_to_idx(atom)
                for atom in sorted(getattr(hyperedge, endpoint))
            ],
            k,
        )
        for hyperedge in hypergraph.hyperedges
    ]
    indices = torch.LongTensor(rows)
    if not rows:
        indices = indices.reshape(0, k)
    return indices


def hypergraph_view_to_hypergraphs_tuple(
    hypergraph: HypergraphView,
    receiver_k: int,
    sender_k: int,
    node_features: Optional[torch.Tensor] = None,
    edge_features: Optional[torch.Tensor] = None,
    global_features: Optional[torch.Tensor] = None,
    pad_func: Callable[[list, int], list] = pad_with_obj_up_to_k,
) -> HypergraphsTuple:
    """
    Convert a HypergraphView (e.g. of a delete-relaxation task) into a
    HypergraphsTuple holding one hypergraph, with optional node/edge/global
    features.

    :param hypergraph: HypergraphView to convert
    :param receiver_k: maximum number of receivers for a hyperedge. Each
        hyperedge's receiver index list is passed to ``pad_func`` with this
        value.
    :param sender_k: maximum number of senders for a hyperedge; used the
        same way as ``receiver_k``
    :param node_features: optional node features, validated against the
        number of nodes in ``hypergraph``
    :param edge_features: optional edge features, validated against the
        number of hyperedges in ``hypergraph``
    :param global_features: optional global features
    :param pad_func: called as ``pad_func(indices, k)`` to handle hyperedges
        with different numbers of sender/receiver nodes. The tuple's
        ``zero_padding`` flag is True only when this is
        ``pad_with_obj_up_to_k``.
    :return: HypergraphsTuple whose ``n_node``/``n_edge`` are 1-element
        LongTensors. It has one row of receiver/sender indices per hyperedge
        and carries the given features (``None`` where not provided).
    """
    # Receivers are the additive effects for each action
    receivers = _padded_endpoint_indices(
        hypergraph, "receivers", receiver_k, pad_func
    )
    # Senders are preconditions for each action
    senders = _padded_endpoint_indices(
        hypergraph, "senders", sender_k, pad_func
    )

    # Validate features
    _validate_features(node_features, len(hypergraph.nodes), "Nodes")
    _validate_features(edge_features, len(hypergraph.hyperedges), "Edges")
    if global_features is not None:
        # NOTE(review): the expected size passed here is len(global_features)
        # itself, so this call cannot detect a size mismatch. It only runs
        # whatever other checks _validate_features performs. Confirm the
        # intended expected size (e.g. one row per hypergraph).
        _validate_features(global_features, len(global_features), "Global")

    params = {
        N_NODE: torch.LongTensor([len(hypergraph.nodes)]),
        N_EDGE: torch.LongTensor([len(hypergraph.hyperedges)]),
        # Hyperedge connection information
        RECEIVERS: receivers,
        SENDERS: senders,
        # Features as passed in (None when not provided)
        NODES: node_features,
        EDGES: edge_features,
        GLOBALS: global_features,
        # True only when the default padder is in use (identity check)
        ZERO_PADDING: pad_func is pad_with_obj_up_to_k,
    }

    return HypergraphsTuple(**params)
def merge_hypergraphs_tuple(
    graphs_tuple_list: List[HypergraphsTuple]
) -> HypergraphsTuple:
    """
    Batch several HypergraphsTuples (one hypergraph each) into a single one.

    Each attribute is concatenated along dim 0 in list order. Tuples whose
    value for that attribute is ``None`` are skipped, and the merged
    attribute is ``None`` when every tuple has ``None``. Merged
    receivers/senders/nodes/edges/globals that come out 1-D are reshaped
    into a single column; ``n_node``/``n_edge`` stay 1-D.

    Assertions check that:
    - the list is non-empty
    - all tuples share the same ``zero_padding``
    - the merged sizes agree with ``n_node``/``n_edge`` and the number of
      tuples

    NOTE(review): receivers/senders are concatenated as-is; no per-graph
    node offset is added. If they are per-graph local node indices, indexing
    the batched node tensor with them would read another graph's nodes.
    Confirm whether offsets are applied elsewhere.
    """
    assert len(graphs_tuple_list) > 0

    def _concat_attr(attr_name, as_matrix=True):
        """Concatenate the non-None ``attr_name`` tensors, or return None."""
        present = []
        for h_tup in graphs_tuple_list:
            value = getattr(h_tup, attr_name)
            if value is not None:
                present.append(value)
        if not present:
            return None
        merged = torch.cat(present)
        # Optionally promote a 1-D result to a single-column matrix.
        if as_matrix and merged.dim() == 1:
            merged = merged.reshape(-1, 1)
        return merged

    n_node = _concat_attr(N_NODE, as_matrix=False)
    n_edge = _concat_attr(N_EDGE, as_matrix=False)
    receivers = _concat_attr(RECEIVERS)
    senders = _concat_attr(SENDERS)
    nodes = _concat_attr(NODES)
    edges = _concat_attr(EDGES)
    globals_ = _concat_attr(GLOBALS)

    # Every hypergraph must use the same padding convention.
    padding_modes = {h.zero_padding for h in graphs_tuple_list}
    assert len(padding_modes) == 1
    zero_padding = graphs_tuple_list[0].zero_padding

    # Sanity-check that the batched sizes add up.
    assert len(n_node) == len(n_edge) == len(graphs_tuple_list)
    total_edges = torch.sum(n_edge)
    assert receivers.shape[0] == senders.shape[0] == total_edges
    if edges is not None:
        assert edges.shape[0] == total_edges
    if nodes is not None:
        assert nodes.shape[0] == torch.sum(n_node)
    if globals_ is not None:
        assert globals_.shape[0] == len(graphs_tuple_list)

    merged_fields = {
        N_NODE: n_node,
        N_EDGE: n_edge,
        # Hyperedge connection information
        RECEIVERS: receivers,
        SENDERS: senders,
        # Batched feature tensors (None when no input tuple had them)
        NODES: nodes,
        EDGES: edges,
        GLOBALS: globals_,
        ZERO_PADDING: zero_padding,
    }
    return HypergraphsTuple(**merged_fields)