Example #1
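These snippets omit their imports. A plausible common prologue, with the library's module paths left as placeholders since they do not appear in the source:

import torch
from typing import Type

# Placeholder import paths; substitute the actual modules that provide these
# names in your installation of the graph-network library:
# from <library> import Flex, MLP, Aggregator, MultiAggregator
# from <library> import EdgeBlock, NodeBlock, GlobalBlock
# from <library> import AggregatingEdgeBlock, AggregatingNodeBlock, AggregatingGlobalBlock
# from <library> import GraphEncoder, GraphCore, GraphData, GraphDataLoader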
def test_flex_block_chain(x):
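    # Flex.d() is a placeholder input dimension; each wrapped Linear infers its
    # in_features from the data it first receives.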

    model = torch.nn.Sequential(
        Flex(torch.nn.Linear)(Flex.d(), 16),
        Flex(torch.nn.Linear)(Flex.d(), 32),
        Flex(torch.nn.Linear)(Flex.d(), 64),
    )

    data = torch.randn((10, x))
    out = model(data)
    assert out.shape[1] == 64
Example #2
def test_flexible_agg_node_block_mult_agg():
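    # AggregatingNodeBlock updates node features with an MLP applied after
    # aggregating incoming edge features (here via the "add" aggregator).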
    net = AggregatingNodeBlock(
        Flex(MLP)(Flex.d(), 25),
        edge_aggregator=Flex(MultiAggregator)(Flex.d(), aggregators=["add"]),
    )
    n_edges = 10
    e_feat = 5
    n_nodes = 20
    n_feat = 7
    edata = torch.randn((n_edges, e_feat), dtype=torch.float)
    ndata = torch.randn((n_nodes, n_feat), dtype=torch.float)
    edges = torch.randint(0, 2, (2, n_edges), dtype=torch.long)
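    # COO-style edge index of shape (2, n_edges); endpoints are drawn only
    # from nodes {0, 1}, which is enough for a shape/smoke test.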
    net(node_attr=ndata, edge_attr=edata, edges=edges)
Example #3
    def __init__(
        self,
        layers,
        dropout,
        layer_norm,
        aggregator,
        aggregator_activation: Type[torch.nn.Module],
    ):
        super().__init__(
            Flex(MLP)(Flex.d(), *layers, dropout=dropout, layer_norm=layer_norm),
            Flex(MultiAggregator)(
                Flex.d(), aggregators=aggregator, activation=aggregator_activation
            ),
        )
Example #4
def test_train_shortest_path():
    graphs = [
        generate_shorest_path_example(100, 0.01, 1000) for _ in range(10)
    ]
    input_data = [
        GraphData.from_networkx(g, feature_key="_features") for g in graphs
    ]
    target_data = [
        GraphData.from_networkx(g, feature_key="_target") for g in graphs
    ]

    loader = GraphDataLoader(input_data,
                             target_data,
                             batch_size=32,
                             shuffle=True)

    agg = lambda: Flex(MultiAggregator)(Flex.d(),
                                        ["add", "mean", "max", "min"])

    network = Network()
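    # The loop below runs a single dry-run batch so the Flex layers materialize
    # their parameters; network.parameters() would otherwise be empty when the
    # optimizer is constructed.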

    for input_batch, _ in loader:
        network(input_batch, 10)
        break

    loss_fn = torch.nn.BCELoss()
    optimizer = torch.optim.AdamW(network.parameters())
    for _ in range(10):
        for input_batch, target_batch in loader:
            optimizer.zero_grad()
            output = network(input_batch, 10)[0]
            x, y = output.x, target_batch.x
            loss = loss_fn(x.flatten(), y[:, 0].flatten())
            loss.backward()
            print(loss.detach())
            optimizer.step()
Example #5
def test_flex_block():
    flex_linear = Flex(torch.nn.Linear)
    model = flex_linear(Flex.d(), 11)
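    # At this point the printed repr still shows an unresolved input dimension.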
    print(model)
    print(repr(model))
    x = torch.randn((30, 55))
    model(x)
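    # After one forward pass with a (30, 55) input, the Linear layer has been
    # materialized with in_features=55 and out_features=11.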
    print(model)
    print(repr(model))
Example #6
    def __init__(
        self,
        layers,
        dropout,
        layer_norm,
        edge_aggregator,
        node_aggregator,
        aggregator_activation,
    ):
        super().__init__(
            mlp=Flex(MLP)(Flex.d(), *layers, dropout=dropout, layer_norm=layer_norm),
            edge_aggregator=Flex(MultiAggregator)(
                Flex.d(), aggregators=edge_aggregator, activation=aggregator_activation
            ),
            node_aggregator=Flex(MultiAggregator)(
                Flex.d(), aggregators=node_aggregator, activation=aggregator_activation
            ),
        )
Example #7
    def _init_encoder(self):
        return GraphEncoder(
            EdgeBlock(
                Flex(MLP)(
                    Flex.d(),
                    self.config["sizes"]["latent"]["edge"],
                    dropout=self.config["dropout"],
                )),
            NodeBlock(
                Flex(MLP)(
                    Flex.d(),
                    self.config["sizes"]["latent"]["node"],
                    dropout=self.config["dropout"],
                )),
            GlobalBlock(
                Flex(MLP)(
                    Flex.d(),
                    self.config["sizes"]["latent"]["global"],
                    dropout=self.config["dropout"],
                )),
        )
Example #8
def test_flex_block_custom_position(x):
    class FooBlock(torch.nn.Module):
        def __init__(self, a, b):
            super().__init__()
            self.block = torch.nn.Linear(a, b)

        def forward(self, steps, data):
            return self.block(data)

    model = Flex(FooBlock)(Flex.d(1), 16)
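    # Flex.d(1) infers the missing dimension from positional argument 1 of
    # forward() (data) rather than argument 0 (steps), so a non-tensor first
    # argument is fine.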
    data = torch.randn((10, x))
    model("arg0", data)
Example #9
def test_aggregating_node_block_grad_fn():
    net = AggregatingNodeBlock(
        Flex(MLP)(Flex.d(), 25),
        edge_aggregator=Flex(MultiAggregator)(Flex.d(), aggregators=["add"]),
    )
    n_edges = 10
    e_feat = 5
    n_nodes = 20
    n_feat = 7
    edata = torch.randn((n_edges, e_feat), dtype=torch.float)
    ndata = torch.randn((n_nodes, n_feat), dtype=torch.float)
    edges = torch.randint(0, 2, (2, n_edges), dtype=torch.long)
    out = net(node_attr=ndata, edge_attr=edata, edges=edges)

    print(out)
    var = out
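    # Walk the autograd graph backward from the output to confirm the block's
    # operations registered grad_fns.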
    output_nodes = (
        (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)
    )
    print(output_nodes)
    next_fns = output_nodes[0].next_functions
Example #10
    def _init_out_transform(self):
        return GraphEncoder(
            EdgeBlock(
                torch.nn.Sequential(
                    Flex(torch.nn.Linear)(Flex.d(),
                                          self.config["sizes"]["out"]["edge"]),
                    self.config["sizes"]["out"]["activation"](),
                )),
            NodeBlock(
                torch.nn.Sequential(
                    Flex(torch.nn.Linear)(Flex.d(),
                                          self.config["sizes"]["out"]["node"]),
                    self.config["sizes"]["out"]["activation"](),
                )),
            GlobalBlock(
                torch.nn.Sequential(
                    Flex(torch.nn.Linear)(
                        Flex.d(), self.config["sizes"]["out"]["global"]),
                    self.config["sizes"]["out"]["activation"](),
                )),
        )
Example #11
    def create_graph_core(pass_global_to_edge: bool,
                          pass_global_to_node: bool):
        return GraphCore(
            AggregatingEdgeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                )),
            AggregatingNodeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                ),
                edge_aggregator=Aggregator("add"),
            ),
            AggregatingGlobalBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                ),
                edge_aggregator=Aggregator("add"),
                node_aggregator=Aggregator("add"),
            ),
            pass_global_to_edge=pass_global_to_edge,
            pass_global_to_node=pass_global_to_node,
        )
Example #12
    def _init_core(self):
        edge_layers = [self.config["sizes"]["latent"]["edge"]
                       ] * self.config["sizes"]["latent"]["edge_depth"]
        node_layers = [self.config["sizes"]["latent"]["node"]
                       ] * self.config["sizes"]["latent"]["node_depth"]
        global_layers = [self.config["sizes"]["latent"]["global"]
                         ] * self.config["sizes"]["latent"]["global_depth"]

        return GraphCore(
            AggregatingEdgeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(),
                              *edge_layers,
                              dropout=self.config["dropout"],
                              layer_norm=True), )),
            AggregatingNodeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(),
                              *node_layers,
                              dropout=self.config["dropout"],
                              layer_norm=True), ),
                Flex(MultiAggregator)(
                    Flex.d(),
                    self.config["node_block_aggregator"],
                    activation=self.config["aggregator_activation"],
                ),
            ),
            AggregatingGlobalBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(),
                              *global_layers,
                              dropout=self.config["dropout"],
                              layer_norm=True), ),
                edge_aggregator=Flex(MultiAggregator)(
                    Flex.d(),
                    self.config["global_block_to_edge_aggregator"],
                    activation=self.config["aggregator_activation"],
                ),
                node_aggregator=Flex(MultiAggregator)(
                    Flex.d(),
                    self.config["global_block_to_node_aggregator"],
                    activation=self.config["aggregator_activation"],
                ),
            ),
            pass_global_to_edge=self.config["pass_global_to_edge"],
            pass_global_to_node=self.config["pass_global_to_node"],
        )
Example #13
    def create_graph_core_multi_agg(pass_global_to_edge: bool,
                                    pass_global_to_node: bool):
        agg = lambda: Flex(MultiAggregator)(Flex.d(),
                                            ["add", "mean", "max", "min"])

        return GraphCore(
            AggregatingEdgeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(),
                              5,
                              5,
                              layer_norm=True,
                              activation=torch.nn.LeakyReLU),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                )),
            AggregatingNodeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(),
                              5,
                              5,
                              layer_norm=True,
                              activation=torch.nn.LeakyReLU),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                ),
                edge_aggregator=agg(),
            ),
            AggregatingGlobalBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(),
                              5,
                              5,
                              layer_norm=True,
                              activation=torch.nn.LeakyReLU),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                ),
                edge_aggregator=agg(),
                node_aggregator=agg(),
            ),
            pass_global_to_edge=pass_global_to_edge,
            pass_global_to_node=pass_global_to_node,
        )
Example #14
    def __init__(self, size: int, activation: Type[torch.nn.Module]):
        super().__init__()
        layers = [Flex(torch.nn.Linear)(Flex.d(), size)]
        if activation is not None:
            layers.append(activation())
        self.layer = torch.nn.Sequential(*layers)
Example #15
    def __init__(
        self,
        latent_sizes=(16, 16, 1),
        out_sizes=(1, 1, 1),
        latent_depths=(1, 1, 1),
        dropout: float = None,
        pass_global_to_edge: bool = True,
        pass_global_to_node: bool = True,
        activation=defaults.activation,
        out_activation=defaults.activation,
        edge_to_node_aggregators=("add", "max", "mean", "min"),
        edge_to_global_aggregators=("add", "max", "mean", "min"),
        node_to_global_aggregators=("add", "max", "mean", "min"),
        aggregator_activation=defaults.activation,
    ):
        super().__init__()
        self.config = {
            "sizes": {
                "latent": {
                    "edge": latent_sizes[0],
                    "node": latent_sizes[1],
                    "global": latent_sizes[2],
                    "edge_depth": latent_depths[0],
                    "node_depth": latent_depths[1],
                    "global_depth": latent_depths[2],
                },
                "out": {
                    "edge": out_sizes[0],
                    "node": out_sizes[1],
                    "global": out_sizes[2],
                    "activation": out_activation,
                },
            },
            "activation": activation,
            "dropout": dropout,
            "node_block_aggregator": edge_to_node_aggregators,
            "global_block_to_node_aggregator": node_to_global_aggregators,
            "global_block_to_edge_aggregator": edge_to_global_aggregators,
            "aggregator_activation": aggregator_activation,
            "pass_global_to_edge": pass_global_to_edge,
            "pass_global_to_node": pass_global_to_node,
        }

        ###########################
        # encode -> core -> decode
        ###########################

        self.encoder = self._init_encoder()
        self.core = self._init_core()
        self.decoder = self._init_encoder()

        self.output_transform = GraphEncoder(
            EdgeBlock(
                torch.nn.Sequential(
                    Flex(torch.nn.Linear)(Flex.d(), 1), torch.nn.Sigmoid())),
            NodeBlock(
                torch.nn.Sequential(
                    Flex(torch.nn.Linear)(Flex.d(), 1), torch.nn.Sigmoid())),
            GlobalBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
        )
Example #16
    def __init__(self, layers, dropout: float, layer_norm: bool):
        super().__init__(
            Flex(MLP)(Flex.d(), *layers, dropout=dropout, layer_norm=layer_norm)
        )
Example #17
    def mlp(*layer_sizes):
        return Flex(MLP)(Flex.d(),
                         *layer_sizes,
                         layer_norm=layer_norm,
                         dropout=dropout)
Example #18
    def __init__(self, size: int, dropout: float, activation: Type[torch.nn.Module]):
        super().__init__(
            Flex(MLP)(Flex.d(), size, dropout=dropout, activation=activation)
        )
Example #19
    def __init__(
        self,
        latent_sizes=(128, 128, 1),
        output_sizes=(1, 1, 1),
        depths=(1, 1, 1),
        layer_norm: bool = True,
        dropout: float = None,
        pass_global_to_edge: bool = True,
        pass_global_to_node: bool = True,
    ):
        super().__init__()
        self.config = {
            "latent_size": {
                "node": latent_sizes[1],
                "edge": latent_sizes[0],
                "global": latent_sizes[2],
                "core_node_block_depth": depths[0],
                "core_edge_block_depth": depths[1],
                "core_global_block_depth": depths[2],
            },
            "output_size": {
                "edge": output_sizes[0],
                "node": output_sizes[1],
                "global": output_sizes[2],
            },
            "node_block_aggregator": "add",
            "global_block_to_node_aggregator": "add",
            "global_block_to_edge_aggregator": "add",
            "pass_global_to_edge": pass_global_to_edge,
            "pass_global_to_node": pass_global_to_node,
        }

        def mlp(*layer_sizes):
            return Flex(MLP)(Flex.d(),
                             *layer_sizes,
                             layer_norm=layer_norm,
                             dropout=dropout)

        self.encoder = GraphEncoder(
            EdgeBlock(mlp(latent_sizes[0])),
            NodeBlock(mlp(latent_sizes[1])),
            GlobalBlock(mlp(latent_sizes[2])),
        )

        edge_layers = [self.config["latent_size"]["edge"]
                       ] * self.config["latent_size"]["core_edge_block_depth"]
        node_layers = [self.config["latent_size"]["node"]
                       ] * self.config["latent_size"]["core_node_block_depth"]
        global_layers = [
            self.config["latent_size"]["global"]
        ] * self.config["latent_size"]["core_global_block_depth"]

        self.core = GraphCore(
            AggregatingEdgeBlock(mlp(*edge_layers)),
            AggregatingNodeBlock(
                mlp(*node_layers),
                Aggregator(self.config["node_block_aggregator"])),
            AggregatingGlobalBlock(
                mlp(*global_layers),
                edge_aggregator=Aggregator(
                    self.config["global_block_to_edge_aggregator"]),
                node_aggregator=Aggregator(
                    self.config["global_block_to_node_aggregator"]),
            ),
            pass_global_to_edge=self.config["pass_global_to_edge"],
            pass_global_to_node=self.config["pass_global_to_node"],
        )

        self.decoder = GraphEncoder(
            EdgeBlock(mlp(latent_sizes[0])),
            NodeBlock(mlp(latent_sizes[1])),
            GlobalBlock(mlp(latent_sizes[2])),
        )

        self.output_transform = GraphEncoder(
            EdgeBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[0])),
            NodeBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[1])),
            GlobalBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[2])),
        )
Example #20
    def __init__(
        self,
        latent_sizes=(128, 128, 1),
        depths=(1, 1, 1),
        dropout: float = None,
        pass_global_to_edge: bool = True,
        pass_global_to_node: bool = True,
        edge_to_node_aggregators=("add", "max", "mean", "min"),
        edge_to_global_aggregators=("add", "max", "mean", "min"),
        node_to_global_aggregators=("add", "max", "mean", "min"),
        aggregator_activation=defaults.activation,
    ):
        super().__init__()
        self.config = {
            "latent_size": {
                "node": latent_sizes[1],
                "edge": latent_sizes[0],
                "global": latent_sizes[2],
                "core_node_block_depth": depths[0],
                "core_edge_block_depth": depths[1],
                "core_global_block_depth": depths[2],
            },
            "node_block_aggregator": edge_to_node_aggregators,
            "global_block_to_node_aggregator": node_to_global_aggregators,
            "global_block_to_edge_aggregator": edge_to_global_aggregators,
            "aggregator_activation": aggregator_activation,
            "pass_global_to_edge": pass_global_to_edge,
            "pass_global_to_node": pass_global_to_node,
        }
        self.encoder = GraphEncoder(
            EdgeBlock(Flex(MLP)(Flex.d(), latent_sizes[0], dropout=dropout)),
            NodeBlock(Flex(MLP)(Flex.d(), latent_sizes[1], dropout=dropout)),
            GlobalBlock(Flex(MLP)(Flex.d(), latent_sizes[2], dropout=dropout)),
        )

        edge_layers = [self.config["latent_size"]["edge"]
                       ] * self.config["latent_size"]["core_edge_block_depth"]
        node_layers = [self.config["latent_size"]["node"]
                       ] * self.config["latent_size"]["core_node_block_depth"]
        global_layers = [
            self.config["latent_size"]["global"]
        ] * self.config["latent_size"]["core_global_block_depth"]

        self.core = GraphCore(
            AggregatingEdgeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(),
                              *edge_layers,
                              dropout=dropout,
                              layer_norm=True),
                    # Flex(torch.nn.Linear)(Flex.d(), edge_layers[-1])
                )),
            AggregatingNodeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(),
                              *node_layers,
                              dropout=dropout,
                              layer_norm=True),
                    # Flex(torch.nn.Linear)(Flex.d(), node_layers[-1])
                ),
                Flex(MultiAggregator)(
                    Flex.d(),
                    self.config["node_block_aggregator"],
                    activation=self.config["aggregator_activation"],
                ),
            ),
            AggregatingGlobalBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(),
                              *global_layers,
                              dropout=dropout,
                              layer_norm=True),
                    # Flex(torch.nn.Linear)(Flex.d(), global_layers[-1])
                ),
                edge_aggregator=Flex(MultiAggregator)(
                    Flex.d(),
                    self.config["global_block_to_edge_aggregator"],
                    activation=self.config["aggregator_activation"],
                ),
                node_aggregator=Flex(MultiAggregator)(
                    Flex.d(),
                    self.config["global_block_to_node_aggregator"],
                    activation=self.config["aggregator_activation"],
                ),
            ),
            pass_global_to_edge=self.config["pass_global_to_edge"],
            pass_global_to_node=self.config["pass_global_to_node"],
        )

        self.decoder = GraphEncoder(
            EdgeBlock(
                Flex(MLP)(Flex.d(),
                          latent_sizes[0],
                          latent_sizes[0],
                          dropout=dropout)),
            NodeBlock(
                Flex(MLP)(Flex.d(),
                          latent_sizes[1],
                          latent_sizes[1],
                          dropout=dropout)),
            GlobalBlock(Flex(MLP)(Flex.d(), latent_sizes[2])),
        )

        self.output_transform = GraphEncoder(
            EdgeBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
            NodeBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
            GlobalBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
        )
Example #21
class Networks:
    """Networks that will be used in the tests."""

    n = NamedNetwork

    linear_block = n(
        "linear",
        lambda: torch.nn.Sequential(torch.nn.Linear(5, 16), torch.nn.ReLU(),
                                    torch.nn.Linear(16, 1)),
    )

    mlp_block = n(
        "mlp",
        lambda: torch.nn.Sequential(
            Flex(MLP)(Flex.d(), 16),
            Flex(torch.nn.Linear)(Flex.d(), 1)),
    )

    node_block = n(
        "node_block",
        lambda: torch.nn.Sequential(
            NodeBlock(Flex(MLP)(Flex.d(), 25, 25, layer_norm=False)),
            Flex(torch.nn.Linear)(Flex.d(), 1),
        ),
    )

    edge_block = n(
        "edge_block",
        lambda: torch.nn.Sequential(
            EdgeBlock(Flex(MLP)(Flex.d(), 25, 25, layer_norm=False)),
            Flex(torch.nn.Linear)(Flex.d(), 1),
        ),
    )

    global_block = n(
        "global_block",
        lambda: torch.nn.Sequential(
            GlobalBlock(Flex(MLP)(Flex.d(), 25, 25, layer_norm=False)),
            Flex(torch.nn.Linear)(Flex.d(), 1),
        ),
    )

    graph_encoder = n(
        "graph_encoder",
        lambda: GraphEncoder(
            EdgeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                )),
            NodeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                )),
            GlobalBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                )),
        ),
    )

    def create_graph_core(pass_global_to_edge: bool,
                          pass_global_to_node: bool):
        return GraphCore(
            AggregatingEdgeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                )),
            AggregatingNodeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                ),
                edge_aggregator=Aggregator("add"),
            ),
            AggregatingGlobalBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                ),
                edge_aggregator=Aggregator("add"),
                node_aggregator=Aggregator("add"),
            ),
            pass_global_to_edge=pass_global_to_edge,
            pass_global_to_node=pass_global_to_node,
        )

    graph_core = n("graph_core", create_graph_core)

    def create_graph_core_multi_agg(pass_global_to_edge: bool,
                                    pass_global_to_node: bool):
        agg = lambda: Flex(MultiAggregator)(Flex.d(),
                                            ["add", "mean", "max", "min"])

        return GraphCore(
            AggregatingEdgeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(),
                              5,
                              5,
                              layer_norm=True,
                              activation=torch.nn.LeakyReLU),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                )),
            AggregatingNodeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(),
                              5,
                              5,
                              layer_norm=True,
                              activation=torch.nn.LeakyReLU),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                ),
                edge_aggregator=agg(),
            ),
            AggregatingGlobalBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(),
                              5,
                              5,
                              layer_norm=True,
                              activation=torch.nn.LeakyReLU),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                ),
                edge_aggregator=agg(),
                node_aggregator=agg(),
            ),
            pass_global_to_edge=pass_global_to_edge,
            pass_global_to_node=pass_global_to_node,
        )

    graph_core_multi_agg = n("graph_core(multiagg)",
                             create_graph_core_multi_agg)

    @staticmethod
    def reset(net: torch.nn.Module):
        def weight_reset(model):
            for layer in model.children():
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()

        net.apply(weight_reset)
Example #22
def test_flexible_multiaggregator():
    net = Flex(MultiAggregator)(Flex.d(), aggregators=["add"])
    data = torch.randn((10, 5), dtype=torch.float)
    idx = torch.randint(0, 2, (10, ), dtype=torch.long)
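    # Scatter-aggregate the 10 rows of `data` into 20 buckets along dim 0;
    # the feature dimension (5) is resolved on this first call.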
    net(data, idx, dim=0, dim_size=20)
    print(net)
Example #23
    def __init__(self):
        super().__init__()
        self.layers = Flex(torch.nn.Linear)(Flex.d(), 5)