def create_graph_core(pass_global_to_edge: bool, pass_global_to_node: bool):
    """Build a ``GraphCore`` from three aggregating blocks that share one layout.

    Each of the edge, node and global blocks wraps its own
    ``torch.nn.Sequential`` made of ``Flex(MLP)(Flex.d(), 5, 5, layer_norm=False)``
    followed by ``Flex(torch.nn.Linear)(Flex.d(), 1)``.  All aggregators use
    the ``"add"`` method.

    Args:
        pass_global_to_edge: Forwarded unchanged to ``GraphCore``.
        pass_global_to_node: Forwarded unchanged to ``GraphCore``.

    Returns:
        The assembled ``GraphCore`` instance.
    """

    def _new_head():
        # Called once per block so every block owns separate module instances.
        hidden = Flex(MLP)(Flex.d(), 5, 5, layer_norm=False)
        readout = Flex(torch.nn.Linear)(Flex.d(), 1)
        return torch.nn.Sequential(hidden, readout)

    # Blocks are created in edge -> node -> global order.
    edge_block = AggregatingEdgeBlock(_new_head())
    node_block = AggregatingNodeBlock(
        _new_head(),
        edge_aggregator=Aggregator("add"),
    )
    global_block = AggregatingGlobalBlock(
        _new_head(),
        edge_aggregator=Aggregator("add"),
        node_aggregator=Aggregator("add"),
    )

    return GraphCore(
        edge_block,
        node_block,
        global_block,
        pass_global_to_edge=pass_global_to_edge,
        pass_global_to_node=pass_global_to_node,
    )
def test_init_agg_global_block_requires_grad():
    """Run an ``AggregatingGlobalBlock`` once and check its parameters require grad.

    The block wraps ``MLP(8, 16, 10)`` with two ``"mean"`` aggregators.  After
    a single forward call on random inputs, every parameter must report
    ``requires_grad``; the parameter list is then printed.
    """
    # Random inputs, drawn in a fixed order: global, edge, edges, node, indices.
    g_attr = torch.randn(10, 3)
    e_attr = torch.randn(20, 3)
    edge_index = torch.randint(0, 40, torch.Size([2, 20]))
    n_attr = torch.randn(40, 2)
    n_idx = torch.randint(0, 3, torch.Size([40]))
    e_idx = torch.randint(0, 3, torch.Size([20]))

    global_model = AggregatingGlobalBlock(
        MLP(8, 16, 10),
        Aggregator("mean"),
        Aggregator("mean"),
    )

    # The forward result itself is not inspected; the call only has to succeed.
    global_model(
        global_attr=g_attr,
        node_attr=n_attr,
        edge_attr=e_attr,
        edges=edge_index,
        node_idx=n_idx,
        edge_idx=e_idx,
    )

    assert all(param.requires_grad for param in global_model.parameters())
    print(list(global_model.parameters()))
def __init__(
    self,
    latent_sizes=(128, 128, 1),
    output_sizes=(1, 1, 1),
    depths=(1, 1, 1),
    layer_norm: bool = True,
    dropout: float = None,
    pass_global_to_edge: bool = True,
    pass_global_to_node: bool = True,
):
    """Assemble the encoder, core, decoder and output-transform stages.

    Args:
        latent_sizes: Indexed as ``(edge, node, global)``; sizes handed to
            the encoder, core and decoder MLPs.
        output_sizes: Indexed as ``(edge, node, global)``; sizes handed to
            the final ``torch.nn.Linear`` layers.
        depths: Indexed as ``(node, edge, global)``; how many times the
            matching latent size is repeated in each core block's MLP
            layer list.
        layer_norm: Forwarded to every MLP as ``layer_norm``.
        dropout: Forwarded to every MLP as ``dropout``.
        pass_global_to_edge: Forwarded to ``GraphCore``.
        pass_global_to_node: Forwarded to ``GraphCore``.
    """
    super().__init__()

    edge_latent = latent_sizes[0]
    node_latent = latent_sizes[1]
    global_latent = latent_sizes[2]

    # NOTE(review): depths is read as (node, edge, global) although the two
    # other tuples are (edge, node, global) - confirm this ordering is intended.
    node_depth = depths[0]
    edge_depth = depths[1]
    global_depth = depths[2]

    latent_cfg = {
        "node": node_latent,
        "edge": edge_latent,
        "global": global_latent,
        "core_node_block_depth": node_depth,
        "core_edge_block_depth": edge_depth,
        "core_global_block_depth": global_depth,
    }
    output_cfg = {
        "edge": output_sizes[0],
        "node": output_sizes[1],
        "global": output_sizes[2],
    }
    self.config = {
        "latent_size": latent_cfg,
        "output_size": output_cfg,
        "node_block_aggregator": "add",
        "global_block_to_node_aggregator": "add",
        "global_block_to_edge_aggregator": "add",
        "pass_global_to_edge": pass_global_to_edge,
        "pass_global_to_node": pass_global_to_node,
    }
    cfg = self.config

    def build_mlp(*sizes):
        # Every MLP in this network shares the same norm/dropout settings.
        return Flex(MLP)(Flex.d(), *sizes, layer_norm=layer_norm, dropout=dropout)

    def latent_stage():
        # Encoder and decoder have the same layout: one single-layer MLP per
        # graph component, applied through non-aggregating blocks.
        return GraphEncoder(
            EdgeBlock(build_mlp(edge_latent)),
            NodeBlock(build_mlp(node_latent)),
            GlobalBlock(build_mlp(global_latent)),
        )

    self.encoder = latent_stage()

    # Core MLP layer lists: the latent size repeated ``depth`` times.
    core_edge_sizes = [edge_latent] * edge_depth
    core_node_sizes = [node_latent] * node_depth
    core_global_sizes = [global_latent] * global_depth

    core_edge = AggregatingEdgeBlock(build_mlp(*core_edge_sizes))
    core_node = AggregatingNodeBlock(
        build_mlp(*core_node_sizes),
        Aggregator(cfg["node_block_aggregator"]),
    )
    core_global = AggregatingGlobalBlock(
        build_mlp(*core_global_sizes),
        edge_aggregator=Aggregator(cfg["global_block_to_edge_aggregator"]),
        node_aggregator=Aggregator(cfg["global_block_to_node_aggregator"]),
    )
    self.core = GraphCore(
        core_edge,
        core_node,
        core_global,
        pass_global_to_edge=cfg["pass_global_to_edge"],
        pass_global_to_node=cfg["pass_global_to_node"],
    )

    self.decoder = latent_stage()

    def linear_to(size):
        return Flex(torch.nn.Linear)(Flex.d(), size)

    self.output_transform = GraphEncoder(
        EdgeBlock(linear_to(output_sizes[0])),
        NodeBlock(linear_to(output_sizes[1])),
        GlobalBlock(linear_to(output_sizes[2])),
    )
def test_aggregators(method):
    """Smoke-test ``Aggregator(method)`` on a 1-D integer tensor.

    Seven values are assigned to two index groups (three in group 0, four in
    group 1) and aggregated along ``dim=0``.  The result is only printed, not
    asserted.
    """
    aggregator = Aggregator(method)
    group_of = torch.tensor([0, 0, 0, 1, 1, 1, 1])
    values = torch.tensor([0, 1, 10, 3, 4, 55, 6])
    print(aggregator(values, group_of, dim=0))
def test_invalid_method():
    """Constructing an ``Aggregator`` with an unknown method must raise ``ValueError``."""
    with pytest.raises(ValueError):
        # Only the constructor call matters; the instance is never used.
        Aggregator("not a method")
def test_aggregators_2d(method):
    """Smoke-test ``Aggregator(method)`` on a random ``(7, 3)`` tensor.

    Rows are assigned to two index groups (three in group 0, four in group 1)
    and aggregated along ``dim=0``.  The result is only printed, not asserted.
    """
    aggregator = Aggregator(method)
    group_of = torch.tensor([0, 0, 0, 1, 1, 1, 1])
    features = torch.randn((7, 3))
    print(aggregator(features, group_of, dim=0))