def test_flex_block_chain(x):
    model = torch.nn.Sequential(
        Flex(torch.nn.Linear)(Flex.d(), 16),
        Flex(torch.nn.Linear)(Flex.d(), 32),
        Flex(torch.nn.Linear)(Flex.d(), 64),
    )
    data = torch.randn((10, x))
    out = model(data)
    assert out.shape[1] == 64
def test_flexible_agg_node_block_mult_agg():
    net = AggregatingNodeBlock(
        Flex(MLP)(Flex.d(), 25),
        edge_aggregator=Flex(MultiAggregator)(Flex.d(), aggregators=["add"]),
    )
    n_edges = 10
    e_feat = 5
    n_nodes = 20
    n_feat = 7
    edata = torch.randn((n_edges, e_feat), dtype=torch.float)
    ndata = torch.randn((n_nodes, n_feat), dtype=torch.float)
    edges = torch.randint(0, 2, (2, n_edges), dtype=torch.long)
    net(node_attr=ndata, edge_attr=edata, edges=edges)
def __init__(
    self,
    layers,
    dropout,
    layer_norm,
    aggregator,
    aggregator_activation: Type[torch.nn.Module],
):
    super().__init__(
        Flex(MLP)(Flex.d(), *layers, dropout=dropout, layer_norm=layer_norm),
        Flex(MultiAggregator)(
            Flex.d(), aggregators=aggregator, activation=aggregator_activation
        ),
    )
def test_train_shortest_path():
    graphs = [generate_shorest_path_example(100, 0.01, 1000) for _ in range(10)]
    input_data = [GraphData.from_networkx(g, feature_key="_features") for g in graphs]
    target_data = [GraphData.from_networkx(g, feature_key="_target") for g in graphs]
    loader = GraphDataLoader(input_data, target_data, batch_size=32, shuffle=True)

    agg = lambda: Flex(MultiAggregator)(Flex.d(), ["add", "mean", "max", "min"])
    network = Network()

    # run a single batch through the network so all Flex dimensions are resolved
    # before the optimizer collects network.parameters()
    for input_batch, _ in loader:
        network(input_batch, 10)
        break

    loss_fn = torch.nn.BCELoss()
    optimizer = torch.optim.AdamW(network.parameters())

    for _ in range(10):
        for input_batch, target_batch in loader:
            optimizer.zero_grad()  # clear gradients accumulated from the previous step
            output = network(input_batch, 10)[0]
            x, y = output.x, target_batch.x
            loss = loss_fn(x.flatten(), y[:, 0].flatten())
            loss.backward()
            print(loss.detach())
            optimizer.step()
def test_flex_block():
    flex_linear = Flex(torch.nn.Linear)
    model = flex_linear(Flex.d(), 11)
    print(str(model))
    print(repr(model))

    # the flexible input dimension is resolved on the first forward call
    x = torch.randn((30, 55))
    model(x)

    print(str(model))
    print(repr(model))
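# A minimal sketch (not from the original tests), assuming only the Flex API already
# used above: a Flex-wrapped module is constructed with Flex.d() in place of its input
# dimension, and that dimension is filled in from the data on the first forward call,
# so the same definition works for inputs of any width.
def example_flex_lazy_resolution():
    layer = Flex(torch.nn.Linear)(Flex.d(), 8)  # in_features not yet known
    out = layer(torch.randn((4, 21)))           # resolved against a width-21 input here
    assert out.shape[1] == 8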
def __init__(
    self,
    layers,
    dropout,
    layer_norm,
    edge_aggregator,
    node_aggregator,
    aggregator_activation,
):
    super().__init__(
        mlp=Flex(MLP)(Flex.d(), *layers, dropout=dropout, layer_norm=layer_norm),
        edge_aggregator=Flex(MultiAggregator)(
            Flex.d(), aggregators=edge_aggregator, activation=aggregator_activation
        ),
        node_aggregator=Flex(MultiAggregator)(
            Flex.d(), aggregators=node_aggregator, activation=aggregator_activation
        ),
    )
def _init_encoder(self):
    return GraphEncoder(
        EdgeBlock(
            Flex(MLP)(
                Flex.d(),
                self.config["sizes"]["latent"]["edge"],
                dropout=self.config["dropout"],
            )
        ),
        NodeBlock(
            Flex(MLP)(
                Flex.d(),
                self.config["sizes"]["latent"]["node"],
                dropout=self.config["dropout"],
            )
        ),
        GlobalBlock(
            Flex(MLP)(
                Flex.d(),
                self.config["sizes"]["latent"]["global"],
                dropout=self.config["dropout"],
            )
        ),
    )
def test_flex_block_custom_position(x):
    class FooBlock(torch.nn.Module):
        def __init__(self, a, b):
            super().__init__()
            self.block = torch.nn.Linear(a, b)

        def forward(self, steps, data):
            return self.block(data)

    # Flex.d(1) resolves the flexible dimension from the forward argument at
    # position 1 ("data") rather than the first argument.
    model = Flex(FooBlock)(Flex.d(1), 16)
    data = torch.randn((10, x))
    model("arg0", data)
def test_():
    net = AggregatingNodeBlock(
        Flex(MLP)(Flex.d(), 25),
        edge_aggregator=Flex(MultiAggregator)(Flex.d(), aggregators=["add"]),
    )
    n_edges = 10
    e_feat = 5
    n_nodes = 20
    n_feat = 7
    edata = torch.randn((n_edges, e_feat), dtype=torch.float)
    ndata = torch.randn((n_nodes, n_feat), dtype=torch.float)
    edges = torch.randint(0, 2, (2, n_edges), dtype=torch.long)

    out = net(node_attr=ndata, edge_attr=edata, edges=edges)
    print(out)

    # inspect the autograd graph attached to the block's output
    output_nodes = (
        (out.grad_fn,) if not isinstance(out, tuple) else tuple(v.grad_fn for v in out)
    )
    print(output_nodes)
    next_functions = output_nodes[0].next_functions
def _init_out_transform(self):
    return GraphEncoder(
        EdgeBlock(
            torch.nn.Sequential(
                Flex(torch.nn.Linear)(Flex.d(), self.config["sizes"]["out"]["edge"]),
                self.config["sizes"]["out"]["activation"](),
            )
        ),
        NodeBlock(
            torch.nn.Sequential(
                Flex(torch.nn.Linear)(Flex.d(), self.config["sizes"]["out"]["node"]),
                self.config["sizes"]["out"]["activation"](),
            )
        ),
        GlobalBlock(
            torch.nn.Sequential(
                Flex(torch.nn.Linear)(Flex.d(), self.config["sizes"]["out"]["global"]),
                self.config["sizes"]["out"]["activation"](),
            )
        ),
    )
def create_graph_core(pass_global_to_edge: bool, pass_global_to_node: bool):
    return GraphCore(
        AggregatingEdgeBlock(
            torch.nn.Sequential(
                Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                Flex(torch.nn.Linear)(Flex.d(), 1),
            )
        ),
        AggregatingNodeBlock(
            torch.nn.Sequential(
                Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                Flex(torch.nn.Linear)(Flex.d(), 1),
            ),
            edge_aggregator=Aggregator("add"),
        ),
        AggregatingGlobalBlock(
            torch.nn.Sequential(
                Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                Flex(torch.nn.Linear)(Flex.d(), 1),
            ),
            edge_aggregator=Aggregator("add"),
            node_aggregator=Aggregator("add"),
        ),
        pass_global_to_edge=pass_global_to_edge,
        pass_global_to_node=pass_global_to_node,
    )
def _init_core(self):
    edge_layers = [self.config["sizes"]["latent"]["edge"]] * self.config["sizes"]["latent"]["edge_depth"]
    node_layers = [self.config["sizes"]["latent"]["node"]] * self.config["sizes"]["latent"]["node_depth"]
    global_layers = [self.config["sizes"]["latent"]["global"]] * self.config["sizes"]["latent"]["global_depth"]

    return GraphCore(
        AggregatingEdgeBlock(
            torch.nn.Sequential(
                Flex(MLP)(
                    Flex.d(), *edge_layers, dropout=self.config["dropout"], layer_norm=True
                ),
            )
        ),
        AggregatingNodeBlock(
            torch.nn.Sequential(
                Flex(MLP)(
                    Flex.d(), *node_layers, dropout=self.config["dropout"], layer_norm=True
                ),
            ),
            Flex(MultiAggregator)(
                Flex.d(),
                self.config["node_block_aggregator"],
                activation=self.config["aggregator_activation"],
            ),
        ),
        AggregatingGlobalBlock(
            torch.nn.Sequential(
                Flex(MLP)(
                    Flex.d(), *global_layers, dropout=self.config["dropout"], layer_norm=True
                ),
            ),
            edge_aggregator=Flex(MultiAggregator)(
                Flex.d(),
                self.config["global_block_to_edge_aggregator"],
                activation=self.config["aggregator_activation"],
            ),
            node_aggregator=Flex(MultiAggregator)(
                Flex.d(),
                self.config["global_block_to_node_aggregator"],
                activation=self.config["aggregator_activation"],
            ),
        ),
        pass_global_to_edge=self.config["pass_global_to_edge"],
        pass_global_to_node=self.config["pass_global_to_node"],
    )
def create_graph_core_multi_agg(pass_global_to_edge: bool, pass_global_to_node: bool):
    agg = lambda: Flex(MultiAggregator)(Flex.d(), ["add", "mean", "max", "min"])
    return GraphCore(
        AggregatingEdgeBlock(
            torch.nn.Sequential(
                Flex(MLP)(
                    Flex.d(), 5, 5, layer_norm=True, activation=torch.nn.LeakyReLU
                ),
                Flex(torch.nn.Linear)(Flex.d(), 1),
            )
        ),
        AggregatingNodeBlock(
            torch.nn.Sequential(
                Flex(MLP)(
                    Flex.d(), 5, 5, layer_norm=True, activation=torch.nn.LeakyReLU
                ),
                Flex(torch.nn.Linear)(Flex.d(), 1),
            ),
            edge_aggregator=agg(),
        ),
        AggregatingGlobalBlock(
            torch.nn.Sequential(
                Flex(MLP)(
                    Flex.d(), 5, 5, layer_norm=True, activation=torch.nn.LeakyReLU
                ),
                Flex(torch.nn.Linear)(Flex.d(), 1),
            ),
            edge_aggregator=agg(),
            node_aggregator=agg(),
        ),
        pass_global_to_edge=pass_global_to_edge,
        pass_global_to_node=pass_global_to_node,
    )
def __init__(self, size: int, activation: Type[torch.nn.Module]):
    super().__init__()
    layers = [Flex(torch.nn.Linear)(Flex.d(), size)]
    if activation is not None:
        layers.append(activation())
    self.layer = torch.nn.Sequential(*layers)
def __init__(
    self,
    latent_sizes=(16, 16, 1),
    out_sizes=(1, 1, 1),
    latent_depths=(1, 1, 1),
    dropout: float = None,
    pass_global_to_edge: bool = True,
    pass_global_to_node: bool = True,
    activation=defaults.activation,
    out_activation=defaults.activation,
    edge_to_node_aggregators=tuple(["add", "max", "mean", "min"]),
    edge_to_global_aggregators=tuple(["add", "max", "mean", "min"]),
    node_to_global_aggregators=tuple(["add", "max", "mean", "min"]),
    aggregator_activation=defaults.activation,
):
    super().__init__()
    self.config = {
        "sizes": {
            "latent": {
                "edge": latent_sizes[0],
                "node": latent_sizes[1],
                "global": latent_sizes[2],
                "edge_depth": latent_depths[0],
                "node_depth": latent_depths[1],
                "global_depth": latent_depths[2],
            },
            "out": {
                "edge": out_sizes[0],
                "node": out_sizes[1],
                "global": out_sizes[2],
                "activation": out_activation,
            },
        },
        "activation": activation,
        "dropout": dropout,
        "node_block_aggregator": edge_to_node_aggregators,
        "global_block_to_node_aggregator": node_to_global_aggregators,
        "global_block_to_edge_aggregator": edge_to_global_aggregators,
        "aggregator_activation": aggregator_activation,
        "pass_global_to_edge": pass_global_to_edge,
        "pass_global_to_node": pass_global_to_node,
    }

    ###########################
    # encoder
    ###########################
    self.encoder = self._init_encoder()
    self.core = self._init_core()
    self.decoder = self._init_encoder()

    # edge and node outputs are passed through a sigmoid; the global output is left linear
    self.output_transform = GraphEncoder(
        EdgeBlock(
            torch.nn.Sequential(Flex(torch.nn.Linear)(Flex.d(), 1), torch.nn.Sigmoid())
        ),
        NodeBlock(
            torch.nn.Sequential(Flex(torch.nn.Linear)(Flex.d(), 1), torch.nn.Sigmoid())
        ),
        GlobalBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
    )
def __init__(self, layers, dropout: float, layer_norm: bool):
    super().__init__(
        Flex(MLP)(Flex.d(), *layers, dropout=dropout, layer_norm=layer_norm)
    )
def mlp(*layer_sizes):
    # builds a Flex-wrapped MLP; `layer_norm` and `dropout` are closed over
    # from the enclosing scope
    return Flex(MLP)(Flex.d(), *layer_sizes, layer_norm=layer_norm, dropout=dropout)
def __init__(self, size: int, dropout: float, activation: Type[torch.nn.Module]):
    super().__init__(
        Flex(MLP)(Flex.d(), size, dropout=dropout, activation=activation)
    )
def __init__(
    self,
    latent_sizes=(128, 128, 1),
    output_sizes=(1, 1, 1),
    depths=(1, 1, 1),
    layer_norm: bool = True,
    dropout: float = None,
    pass_global_to_edge: bool = True,
    pass_global_to_node: bool = True,
):
    super().__init__()
    self.config = {
        "latent_size": {
            "node": latent_sizes[1],
            "edge": latent_sizes[0],
            "global": latent_sizes[2],
            "core_node_block_depth": depths[0],
            "core_edge_block_depth": depths[1],
            "core_global_block_depth": depths[2],
        },
        "output_size": {
            "edge": output_sizes[0],
            "node": output_sizes[1],
            "global": output_sizes[2],
        },
        "node_block_aggregator": "add",
        "global_block_to_node_aggregator": "add",
        "global_block_to_edge_aggregator": "add",
        "pass_global_to_edge": pass_global_to_edge,
        "pass_global_to_node": pass_global_to_node,
    }

    def mlp(*layer_sizes):
        return Flex(MLP)(Flex.d(), *layer_sizes, layer_norm=layer_norm, dropout=dropout)

    self.encoder = GraphEncoder(
        EdgeBlock(mlp(latent_sizes[0])),
        NodeBlock(mlp(latent_sizes[1])),
        GlobalBlock(mlp(latent_sizes[2])),
    )

    edge_layers = [self.config["latent_size"]["edge"]] * self.config["latent_size"]["core_edge_block_depth"]
    node_layers = [self.config["latent_size"]["node"]] * self.config["latent_size"]["core_node_block_depth"]
    global_layers = [self.config["latent_size"]["global"]] * self.config["latent_size"]["core_global_block_depth"]

    self.core = GraphCore(
        AggregatingEdgeBlock(mlp(*edge_layers)),
        AggregatingNodeBlock(
            mlp(*node_layers), Aggregator(self.config["node_block_aggregator"])
        ),
        AggregatingGlobalBlock(
            mlp(*global_layers),
            edge_aggregator=Aggregator(self.config["global_block_to_edge_aggregator"]),
            node_aggregator=Aggregator(self.config["global_block_to_node_aggregator"]),
        ),
        pass_global_to_edge=self.config["pass_global_to_edge"],
        pass_global_to_node=self.config["pass_global_to_node"],
    )

    self.decoder = GraphEncoder(
        EdgeBlock(mlp(latent_sizes[0])),
        NodeBlock(mlp(latent_sizes[1])),
        GlobalBlock(mlp(latent_sizes[2])),
    )

    self.output_transform = GraphEncoder(
        EdgeBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[0])),
        NodeBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[1])),
        GlobalBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[2])),
    )
def __init__(
    self,
    latent_sizes=(128, 128, 1),
    depths=(1, 1, 1),
    dropout: float = None,
    pass_global_to_edge: bool = True,
    pass_global_to_node: bool = True,
    edge_to_node_aggregators=tuple(["add", "max", "mean", "min"]),
    edge_to_global_aggregators=tuple(["add", "max", "mean", "min"]),
    node_to_global_aggregators=tuple(["add", "max", "mean", "min"]),
    aggregator_activation=defaults.activation,
):
    super().__init__()
    self.config = {
        "latent_size": {
            "node": latent_sizes[1],
            "edge": latent_sizes[0],
            "global": latent_sizes[2],
            "core_node_block_depth": depths[0],
            "core_edge_block_depth": depths[1],
            "core_global_block_depth": depths[2],
        },
        "node_block_aggregator": edge_to_node_aggregators,
        "global_block_to_node_aggregator": node_to_global_aggregators,
        "global_block_to_edge_aggregator": edge_to_global_aggregators,
        "aggregator_activation": aggregator_activation,
        "pass_global_to_edge": pass_global_to_edge,
        "pass_global_to_node": pass_global_to_node,
    }

    self.encoder = GraphEncoder(
        EdgeBlock(Flex(MLP)(Flex.d(), latent_sizes[0], dropout=dropout)),
        NodeBlock(Flex(MLP)(Flex.d(), latent_sizes[1], dropout=dropout)),
        GlobalBlock(Flex(MLP)(Flex.d(), latent_sizes[2], dropout=dropout)),
    )

    edge_layers = [self.config["latent_size"]["edge"]] * self.config["latent_size"]["core_edge_block_depth"]
    node_layers = [self.config["latent_size"]["node"]] * self.config["latent_size"]["core_node_block_depth"]
    global_layers = [self.config["latent_size"]["global"]] * self.config["latent_size"]["core_global_block_depth"]

    self.core = GraphCore(
        AggregatingEdgeBlock(
            torch.nn.Sequential(
                Flex(MLP)(Flex.d(), *edge_layers, dropout=dropout, layer_norm=True),
                # Flex(torch.nn.Linear)(Flex.d(), edge_layers[-1])
            )
        ),
        AggregatingNodeBlock(
            torch.nn.Sequential(
                Flex(MLP)(Flex.d(), *node_layers, dropout=dropout, layer_norm=True),
                # Flex(torch.nn.Linear)(Flex.d(), node_layers[-1])
            ),
            Flex(MultiAggregator)(
                Flex.d(),
                self.config["node_block_aggregator"],
                activation=self.config["aggregator_activation"],
            ),
        ),
        AggregatingGlobalBlock(
            torch.nn.Sequential(
                Flex(MLP)(Flex.d(), *global_layers, dropout=dropout, layer_norm=True),
                # Flex(torch.nn.Linear)(Flex.d(), global_layers[-1])
            ),
            edge_aggregator=Flex(MultiAggregator)(
                Flex.d(),
                self.config["global_block_to_edge_aggregator"],
                activation=self.config["aggregator_activation"],
            ),
            node_aggregator=Flex(MultiAggregator)(
                Flex.d(),
                self.config["global_block_to_node_aggregator"],
                activation=self.config["aggregator_activation"],
            ),
        ),
        pass_global_to_edge=self.config["pass_global_to_edge"],
        pass_global_to_node=self.config["pass_global_to_node"],
    )

    self.decoder = GraphEncoder(
        EdgeBlock(Flex(MLP)(Flex.d(), latent_sizes[0], latent_sizes[0], dropout=dropout)),
        NodeBlock(Flex(MLP)(Flex.d(), latent_sizes[1], latent_sizes[1], dropout=dropout)),
        GlobalBlock(Flex(MLP)(Flex.d(), latent_sizes[2])),
    )

    self.output_transform = GraphEncoder(
        EdgeBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
        NodeBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
        GlobalBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
    )
class Networks:
    """Networks that will be used in the tests."""

    n = NamedNetwork

    linear_block = n(
        "linear",
        lambda: torch.nn.Sequential(
            torch.nn.Linear(5, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1)
        ),
    )

    mlp_block = n(
        "mlp",
        lambda: torch.nn.Sequential(
            Flex(MLP)(Flex.d(), 16), Flex(torch.nn.Linear)(Flex.d(), 1)
        ),
    )

    node_block = n(
        "node_block",
        lambda: torch.nn.Sequential(
            NodeBlock(Flex(MLP)(Flex.d(), 25, 25, layer_norm=False)),
            Flex(torch.nn.Linear)(Flex.d(), 1),
        ),
    )

    edge_block = n(
        "edge_block",
        lambda: torch.nn.Sequential(
            EdgeBlock(Flex(MLP)(Flex.d(), 25, 25, layer_norm=False)),
            Flex(torch.nn.Linear)(Flex.d(), 1),
        ),
    )

    global_block = n(
        "global_block",
        lambda: torch.nn.Sequential(
            GlobalBlock(Flex(MLP)(Flex.d(), 25, 25, layer_norm=False)),
            Flex(torch.nn.Linear)(Flex.d(), 1),
        ),
    )

    graph_encoder = n(
        "graph_encoder",
        lambda: GraphEncoder(
            EdgeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                )
            ),
            NodeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                )
            ),
            GlobalBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                )
            ),
        ),
    )

    def create_graph_core(pass_global_to_edge: bool, pass_global_to_node: bool):
        return GraphCore(
            AggregatingEdgeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                )
            ),
            AggregatingNodeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                ),
                edge_aggregator=Aggregator("add"),
            ),
            AggregatingGlobalBlock(
                torch.nn.Sequential(
                    Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                ),
                edge_aggregator=Aggregator("add"),
                node_aggregator=Aggregator("add"),
            ),
            pass_global_to_edge=pass_global_to_edge,
            pass_global_to_node=pass_global_to_node,
        )

    graph_core = n("graph_core", create_graph_core)

    def create_graph_core_multi_agg(pass_global_to_edge: bool, pass_global_to_node: bool):
        agg = lambda: Flex(MultiAggregator)(Flex.d(), ["add", "mean", "max", "min"])
        return GraphCore(
            AggregatingEdgeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(
                        Flex.d(), 5, 5, layer_norm=True, activation=torch.nn.LeakyReLU
                    ),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                )
            ),
            AggregatingNodeBlock(
                torch.nn.Sequential(
                    Flex(MLP)(
                        Flex.d(), 5, 5, layer_norm=True, activation=torch.nn.LeakyReLU
                    ),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                ),
                edge_aggregator=agg(),
            ),
            AggregatingGlobalBlock(
                torch.nn.Sequential(
                    Flex(MLP)(
                        Flex.d(), 5, 5, layer_norm=True, activation=torch.nn.LeakyReLU
                    ),
                    Flex(torch.nn.Linear)(Flex.d(), 1),
                ),
                edge_aggregator=agg(),
                node_aggregator=agg(),
            ),
            pass_global_to_edge=pass_global_to_edge,
            pass_global_to_node=pass_global_to_node,
        )

    graph_core_multi_agg = n("graph_core(multiagg)", create_graph_core_multi_agg)

    @staticmethod
    def reset(net: torch.nn.Module):
        def weight_reset(model):
            for layer in model.children():
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()

        net.apply(weight_reset)
def test_flexible_multiaggregator():
    net = Flex(MultiAggregator)(Flex.d(), aggregators=["add"])
    data = torch.randn((10, 5), dtype=torch.float)
    idx = torch.randint(0, 2, (10,), dtype=torch.long)
    net(data, idx, dim=0, dim_size=20)
    print(net)
def __init__(self):
    super().__init__()
    self.layers = Flex(torch.nn.Linear)(Flex.d(), 5)