Ejemplo n.º 1
0
 def create_graph_core(pass_global_to_edge: bool,
                       pass_global_to_node: bool):
     """Build a small GraphCore with an MLP + width-1 Linear head per block.

     The edge, node and global blocks each get their own freshly built
     ``Sequential(Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
     Flex(Linear)(Flex.d(), 1))``.
     The node block aggregates edges with ``Aggregator("add")``.
     The global block aggregates both edges and nodes with ``Aggregator("add")``.

     Args:
         pass_global_to_edge: forwarded unchanged to ``GraphCore``.
         pass_global_to_node: forwarded unchanged to ``GraphCore``.

     Returns:
         The assembled ``GraphCore``.
     """

     def head():
         # A new Sequential per call, so the three blocks never share modules.
         return torch.nn.Sequential(
             Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
             Flex(torch.nn.Linear)(Flex.d(), 1),
         )

     edge_block = AggregatingEdgeBlock(head())
     node_block = AggregatingNodeBlock(
         head(),
         edge_aggregator=Aggregator("add"),
     )
     global_block = AggregatingGlobalBlock(
         head(),
         edge_aggregator=Aggregator("add"),
         node_aggregator=Aggregator("add"),
     )
     return GraphCore(
         edge_block,
         node_block,
         global_block,
         pass_global_to_edge=pass_global_to_edge,
         pass_global_to_node=pass_global_to_node,
     )
Ejemplo n.º 2
0
def test_init_agg_global_block_requires_grad():
    """All parameters of an AggregatingGlobalBlock must be trainable.

    Builds a global block around ``MLP(8, 16, 10)`` with mean edge and node
    aggregators, runs one forward pass on random inputs, and checks that the
    block has parameters and that every one has ``requires_grad`` set.
    """
    global_attr = torch.randn(10, 3)
    edge_attr = torch.randn(20, 3)
    edges = torch.randint(0, 40, torch.Size([2, 20]))
    node_attr = torch.randn(40, 2)
    node_idx = torch.randint(0, 3, torch.Size([40]))
    edge_idx = torch.randint(0, 3, torch.Size([20]))

    # 8 input features presumably = global (3) + node (2) + edge (3) columns
    # after aggregation/concatenation -- confirm against AggregatingGlobalBlock.
    global_model = AggregatingGlobalBlock(MLP(8, 16, 10), Aggregator("mean"),
                                          Aggregator("mean"))

    # One forward pass, so any lazily created parameters (if there are any)
    # exist before they are inspected. The output itself is not checked here.
    global_model(
        global_attr=global_attr,
        node_attr=node_attr,
        edge_attr=edge_attr,
        edges=edges,
        node_idx=node_idx,
        edge_idx=edge_idx,
    )

    params = list(global_model.parameters())
    # Without this guard an empty parameter list would make the loop below
    # pass without checking anything.
    assert params, "AggregatingGlobalBlock exposes no parameters"
    for p in params:
        assert p.requires_grad
Ejemplo n.º 3
0
    def _init_core(self):
        """Assemble the GraphCore described by ``self.config``.

        Keys read from the config:
          - ``config["sizes"]["latent"]``: the ``edge``/``node``/``global``
            widths and their ``*_depth`` counts.
          - ``dropout``.
          - ``node_block_aggregator``, ``global_block_to_edge_aggregator`` and
            ``global_block_to_node_aggregator``.
          - ``aggregator_activation``.
          - ``pass_global_to_edge`` and ``pass_global_to_node``.

        Every block MLP is built with ``layer_norm=True``. Every aggregator is a
        ``Flex(MultiAggregator)``.

        Returns:
            The assembled ``GraphCore``.
        """
        cfg = self.config
        latent = cfg["sizes"]["latent"]

        # Each block MLP uses `depth` identical hidden layers of the block's width.
        edge_widths = [latent["edge"]] * latent["edge_depth"]
        node_widths = [latent["node"]] * latent["node_depth"]
        global_widths = [latent["global"]] * latent["global_depth"]

        def block_mlp(widths):
            return torch.nn.Sequential(
                Flex(MLP)(Flex.d(),
                          *widths,
                          dropout=cfg["dropout"],
                          layer_norm=True), )

        def multi_agg(config_key):
            return Flex(MultiAggregator)(
                Flex.d(),
                cfg[config_key],
                activation=cfg["aggregator_activation"],
            )

        edge_block = AggregatingEdgeBlock(block_mlp(edge_widths))
        node_block = AggregatingNodeBlock(
            block_mlp(node_widths),
            multi_agg("node_block_aggregator"),
        )
        global_block = AggregatingGlobalBlock(
            block_mlp(global_widths),
            edge_aggregator=multi_agg("global_block_to_edge_aggregator"),
            node_aggregator=multi_agg("global_block_to_node_aggregator"),
        )
        return GraphCore(
            edge_block,
            node_block,
            global_block,
            pass_global_to_edge=cfg["pass_global_to_edge"],
            pass_global_to_node=cfg["pass_global_to_node"],
        )
Ejemplo n.º 4
0
    def create_graph_core_multi_agg(pass_global_to_edge: bool,
                                    pass_global_to_node: bool):
        """Build a GraphCore whose aggregations use a MultiAggregator.

        The edge, node and global blocks each get their own
        ``Sequential(Flex(MLP)(Flex.d(), 5, 5, layer_norm=True,
        activation=LeakyReLU), Flex(Linear)(Flex.d(), 1))``.
        Every aggregation slot gets its own
        ``Flex(MultiAggregator)(Flex.d(), ["add", "mean", "max", "min"])``.

        Args:
            pass_global_to_edge: forwarded unchanged to ``GraphCore``.
            pass_global_to_node: forwarded unchanged to ``GraphCore``.

        Returns:
            The assembled ``GraphCore``.
        """

        def agg():
            # A named function rather than an assigned lambda (PEP 8 E731).
            # Each call builds a fresh aggregator and a fresh list of
            # reductions, so no aggregation slot shares state with another.
            return Flex(MultiAggregator)(Flex.d(),
                                         ["add", "mean", "max", "min"])

        def head():
            # Identical architecture for all three blocks, but new modules
            # each call.
            return torch.nn.Sequential(
                Flex(MLP)(Flex.d(),
                          5,
                          5,
                          layer_norm=True,
                          activation=torch.nn.LeakyReLU),
                Flex(torch.nn.Linear)(Flex.d(), 1),
            )

        edge_block = AggregatingEdgeBlock(head())
        node_block = AggregatingNodeBlock(head(), edge_aggregator=agg())
        global_block = AggregatingGlobalBlock(
            head(),
            edge_aggregator=agg(),
            node_aggregator=agg(),
        )
        return GraphCore(
            edge_block,
            node_block,
            global_block,
            pass_global_to_edge=pass_global_to_edge,
            pass_global_to_node=pass_global_to_node,
        )
Ejemplo n.º 5
0
    def __init__(
        self,
        latent_sizes=(128, 128, 1),
        output_sizes=(1, 1, 1),
        depths=(1, 1, 1),
        layer_norm: bool = True,
        dropout: float = None,
        pass_global_to_edge: bool = True,
        pass_global_to_node: bool = True,
    ):
        """Build the encoder -> core -> decoder -> output-transform pipeline.

        Args:
            latent_sizes: ``(edge, node, global)`` latent widths.
            output_sizes: ``(edge, node, global)`` widths of the final
                ``Linear`` output layers.
            depths: ``(edge, node, global)`` number of hidden layers in each
                core block MLP, in the same order as ``latent_sizes``.
            layer_norm: forwarded to every MLP.
            dropout: forwarded to every MLP. None is passed through as-is;
                presumably it disables dropout, so confirm against ``MLP``.
            pass_global_to_edge: stored in ``self.config`` and forwarded to
                ``GraphCore``.
            pass_global_to_node: stored in ``self.config`` and forwarded to
                ``GraphCore``.
        """
        super().__init__()
        self.config = {
            "latent_size": {
                "node": latent_sizes[1],
                "edge": latent_sizes[0],
                "global": latent_sizes[2],
                # ``depths`` uses the same (edge, node, global) ordering as
                # ``latent_sizes``/``output_sizes``: index 1 is the node depth
                # and index 0 is the edge depth.
                "core_node_block_depth": depths[1],
                "core_edge_block_depth": depths[0],
                "core_global_block_depth": depths[2],
            },
            "output_size": {
                "edge": output_sizes[0],
                "node": output_sizes[1],
                "global": output_sizes[2],
            },
            "node_block_aggregator": "add",
            "global_block_to_node_aggregator": "add",
            "global_block_to_edge_aggregator": "add",
            "pass_global_to_edge": pass_global_to_edge,
            "pass_global_to_node": pass_global_to_node,
        }

        def mlp(*layer_sizes):
            # Every MLP in the model shares the same layer_norm/dropout setup.
            return Flex(MLP)(Flex.d(),
                             *layer_sizes,
                             layer_norm=layer_norm,
                             dropout=dropout)

        self.encoder = GraphEncoder(
            EdgeBlock(mlp(latent_sizes[0])),
            NodeBlock(mlp(latent_sizes[1])),
            GlobalBlock(mlp(latent_sizes[2])),
        )

        sizes = self.config["latent_size"]
        # Each core block MLP uses `depth` identical layers of the block's
        # latent width.
        edge_layers = [sizes["edge"]] * sizes["core_edge_block_depth"]
        node_layers = [sizes["node"]] * sizes["core_node_block_depth"]
        global_layers = [sizes["global"]] * sizes["core_global_block_depth"]

        self.core = GraphCore(
            AggregatingEdgeBlock(mlp(*edge_layers)),
            AggregatingNodeBlock(
                mlp(*node_layers),
                Aggregator(self.config["node_block_aggregator"])),
            AggregatingGlobalBlock(
                mlp(*global_layers),
                edge_aggregator=Aggregator(
                    self.config["global_block_to_edge_aggregator"]),
                node_aggregator=Aggregator(
                    self.config["global_block_to_node_aggregator"]),
            ),
            pass_global_to_edge=self.config["pass_global_to_edge"],
            pass_global_to_node=self.config["pass_global_to_node"],
        )

        self.decoder = GraphEncoder(
            EdgeBlock(mlp(latent_sizes[0])),
            NodeBlock(mlp(latent_sizes[1])),
            GlobalBlock(mlp(latent_sizes[2])),
        )

        self.output_transform = GraphEncoder(
            EdgeBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[0])),
            NodeBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[1])),
            GlobalBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[2])),
        )
Ejemplo n.º 6
0
    def __init__(
        self,
        latent_sizes=(128, 128, 1),
        depths=(1, 1, 1),
        dropout: float = None,
        pass_global_to_edge: bool = True,
        pass_global_to_node: bool = True,
        edge_to_node_aggregators=("add", "max", "mean", "min"),
        edge_to_global_aggregators=("add", "max", "mean", "min"),
        node_to_global_aggregators=("add", "max", "mean", "min"),
        aggregator_activation=defaults.activation,
        output_sizes=(1, 1, 1),
    ):
        """Build an encoder -> core -> decoder -> output-transform model.

        The core uses ``MultiAggregator`` for every aggregation.

        Args:
            latent_sizes: ``(edge, node, global)`` latent widths.
            depths: ``(edge, node, global)`` number of hidden layers in each
                core block MLP, in the same order as ``latent_sizes``.
            dropout: forwarded to the encoder MLPs, the core MLPs and the
                edge/node decoder MLPs.
            pass_global_to_edge: stored in ``self.config`` and forwarded to
                ``GraphCore``.
            pass_global_to_node: stored in ``self.config`` and forwarded to
                ``GraphCore``.
            edge_to_node_aggregators: reductions the node block applies to
                incoming edges.
            edge_to_global_aggregators: reductions the global block applies
                to edges.
            node_to_global_aggregators: reductions the global block applies
                to nodes.
            aggregator_activation: ``activation`` passed to every
                ``MultiAggregator``.
            output_sizes: ``(edge, node, global)`` widths of the final
                ``Linear`` output layers. The default of 1 each matches the
                widths this constructor always used before.
        """
        super().__init__()
        self.config = {
            "latent_size": {
                "node": latent_sizes[1],
                "edge": latent_sizes[0],
                "global": latent_sizes[2],
                # ``depths`` uses the same (edge, node, global) ordering as
                # ``latent_sizes``: index 1 is the node depth and index 0 is
                # the edge depth.
                "core_node_block_depth": depths[1],
                "core_edge_block_depth": depths[0],
                "core_global_block_depth": depths[2],
            },
            "node_block_aggregator": edge_to_node_aggregators,
            "global_block_to_node_aggregator": node_to_global_aggregators,
            "global_block_to_edge_aggregator": edge_to_global_aggregators,
            "aggregator_activation": aggregator_activation,
            "pass_global_to_edge": pass_global_to_edge,
            "pass_global_to_node": pass_global_to_node,
        }
        self.encoder = GraphEncoder(
            EdgeBlock(Flex(MLP)(Flex.d(), latent_sizes[0], dropout=dropout)),
            NodeBlock(Flex(MLP)(Flex.d(), latent_sizes[1], dropout=dropout)),
            GlobalBlock(Flex(MLP)(Flex.d(), latent_sizes[2], dropout=dropout)),
        )

        sizes = self.config["latent_size"]

        def core_mlp(width_key, depth_key):
            # `depth` identical hidden layers of the block's latent width,
            # with layer norm.
            layers = [sizes[width_key]] * sizes[depth_key]
            return torch.nn.Sequential(
                Flex(MLP)(Flex.d(),
                          *layers,
                          dropout=dropout,
                          layer_norm=True))

        def multi_agg(config_key):
            return Flex(MultiAggregator)(
                Flex.d(),
                self.config[config_key],
                activation=self.config["aggregator_activation"],
            )

        self.core = GraphCore(
            AggregatingEdgeBlock(core_mlp("edge", "core_edge_block_depth")),
            AggregatingNodeBlock(
                core_mlp("node", "core_node_block_depth"),
                multi_agg("node_block_aggregator"),
            ),
            AggregatingGlobalBlock(
                core_mlp("global", "core_global_block_depth"),
                edge_aggregator=multi_agg("global_block_to_edge_aggregator"),
                node_aggregator=multi_agg("global_block_to_node_aggregator"),
            ),
            pass_global_to_edge=self.config["pass_global_to_edge"],
            pass_global_to_node=self.config["pass_global_to_node"],
        )

        self.decoder = GraphEncoder(
            EdgeBlock(
                Flex(MLP)(Flex.d(),
                          latent_sizes[0],
                          latent_sizes[0],
                          dropout=dropout)),
            NodeBlock(
                Flex(MLP)(Flex.d(),
                          latent_sizes[1],
                          latent_sizes[1],
                          dropout=dropout)),
            # NOTE(review): unlike the edge/node decoders, this is a single
            # layer and gets no dropout. Confirm that is intentional.
            GlobalBlock(Flex(MLP)(Flex.d(), latent_sizes[2])),
        )

        self.output_transform = GraphEncoder(
            EdgeBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[0])),
            NodeBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[1])),
            GlobalBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[2])),
        )