def create_graph_core(pass_global_to_edge: bool, pass_global_to_node: bool):
    """Build a GraphCore whose edge, node, and global blocks share the same head shape.

    Each block gets a fresh two-stage transform (a shape-inferring MLP followed by
    a linear projection down to 1 feature); node and global blocks aggregate
    incident attributes with a plain "add" aggregator.

    :param pass_global_to_edge: forward the global attribute into the edge block.
    :param pass_global_to_node: forward the global attribute into the node block.
    :return: a configured ``GraphCore``.
    """

    def head():
        # A new Sequential per call so each block owns independent parameters.
        return torch.nn.Sequential(
            Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
            Flex(torch.nn.Linear)(Flex.d(), 1),
        )

    return GraphCore(
        AggregatingEdgeBlock(head()),
        AggregatingNodeBlock(head(), edge_aggregator=Aggregator("add")),
        AggregatingGlobalBlock(
            head(),
            edge_aggregator=Aggregator("add"),
            node_aggregator=Aggregator("add"),
        ),
        pass_global_to_edge=pass_global_to_edge,
        pass_global_to_node=pass_global_to_node,
    )
def test_init_agg_global_block_requires_grad():
    """AggregatingGlobalBlock built from a plain MLP must expose trainable parameters.

    Fix: the original left a debug ``print(list(global_model.parameters()))`` at the
    end of the test and never checked the forward result; the print is removed and
    the output is now asserted to exist.
    """
    global_attr = torch.randn(10, 3)
    edge_attr = torch.randn(20, 3)
    edges = torch.randint(0, 40, torch.Size([2, 20]))
    node_attr = torch.randn(40, 2)
    # node/edge -> graph assignment indices (only graphs 0..2 are referenced).
    node_idx = torch.randint(0, 3, torch.Size([40]))
    edge_idx = torch.randint(0, 3, torch.Size([20]))

    global_model = AggregatingGlobalBlock(
        MLP(8, 16, 10), Aggregator("mean"), Aggregator("mean")
    )
    out = global_model(
        global_attr=global_attr,
        node_attr=node_attr,
        edge_attr=edge_attr,
        edges=edges,
        node_idx=node_idx,
        edge_idx=edge_idx,
    )
    # The forward pass must complete and produce a result.
    assert out is not None
    # Every parameter of a freshly initialized block should be trainable.
    for p in global_model.parameters():
        assert p.requires_grad
def _init_core(self):
    """Assemble the latent processing ``GraphCore`` from ``self.config``.

    Layer widths and depths come from ``config["sizes"]["latent"]``; the node and
    global blocks aggregate with configurable multi-head aggregators.
    """
    sizes = self.config["sizes"]["latent"]

    def layer_stack(kind):
        # Repeat the latent width `depth` times to form the MLP layer sizes.
        return [sizes[kind]] * sizes[kind + "_depth"]

    def latent_mlp(kind):
        # Shape-inferring MLP wrapped in a Sequential, with dropout + layer norm.
        return torch.nn.Sequential(
            Flex(MLP)(
                Flex.d(),
                *layer_stack(kind),
                dropout=self.config["dropout"],
                layer_norm=True,
            )
        )

    def multi_agg(config_key):
        # Multi-head aggregator whose heads and activation come from the config.
        return Flex(MultiAggregator)(
            Flex.d(),
            self.config[config_key],
            activation=self.config["aggregator_activation"],
        )

    return GraphCore(
        AggregatingEdgeBlock(latent_mlp("edge")),
        AggregatingNodeBlock(
            latent_mlp("node"),
            multi_agg("node_block_aggregator"),
        ),
        AggregatingGlobalBlock(
            latent_mlp("global"),
            edge_aggregator=multi_agg("global_block_to_edge_aggregator"),
            node_aggregator=multi_agg("global_block_to_node_aggregator"),
        ),
        pass_global_to_edge=self.config["pass_global_to_edge"],
        pass_global_to_node=self.config["pass_global_to_node"],
    )
def create_graph_core_multi_agg(pass_global_to_edge: bool, pass_global_to_node: bool):
    """Build a GraphCore that aggregates with add/mean/max/min multi-aggregators.

    Fix: the original bound a lambda to a name (``agg = lambda: ...``), which
    violates PEP 8 (E731); it is now a proper ``def``. The triplicated Sequential
    head is also factored into a local factory.

    :param pass_global_to_edge: forward the global attribute into the edge block.
    :param pass_global_to_node: forward the global attribute into the node block.
    :return: a configured ``GraphCore``.
    """

    def agg():
        # Fresh multi-head aggregator per call (each block owns its own).
        return Flex(MultiAggregator)(Flex.d(), ["add", "mean", "max", "min"])

    def head():
        # Fresh LeakyReLU MLP + linear projection per call.
        return torch.nn.Sequential(
            Flex(MLP)(Flex.d(), 5, 5, layer_norm=True, activation=torch.nn.LeakyReLU),
            Flex(torch.nn.Linear)(Flex.d(), 1),
        )

    return GraphCore(
        AggregatingEdgeBlock(head()),
        AggregatingNodeBlock(head(), edge_aggregator=agg()),
        AggregatingGlobalBlock(
            head(),
            edge_aggregator=agg(),
            node_aggregator=agg(),
        ),
        pass_global_to_edge=pass_global_to_edge,
        pass_global_to_node=pass_global_to_node,
    )
def __init__(
    self,
    latent_sizes=(128, 128, 1),
    output_sizes=(1, 1, 1),
    depths=(1, 1, 1),
    layer_norm: bool = True,
    dropout: float = None,
    pass_global_to_edge: bool = True,
    pass_global_to_node: bool = True,
):
    """Encode-process-decode graph network with plain "add" aggregators.

    Builds four sub-modules: an encoder, a latent ``GraphCore``, a decoder
    (same shape as the encoder), and a final linear output transform.

    :param latent_sizes: latent widths as (edge, node, global).
    :param output_sizes: output widths as (edge, node, global).
    :param depths: core block depths as (node, edge, global).
    :param layer_norm: apply layer normalization inside every MLP.
    :param dropout: optional dropout rate for every MLP (None disables it).
    :param pass_global_to_edge: forward global attributes into the edge block.
    :param pass_global_to_node: forward global attributes into the node block.
    """
    super().__init__()
    self.config = {
        "latent_size": {
            "node": latent_sizes[1],
            "edge": latent_sizes[0],
            "global": latent_sizes[2],
            "core_node_block_depth": depths[0],
            "core_edge_block_depth": depths[1],
            "core_global_block_depth": depths[2],
        },
        "output_size": {
            "edge": output_sizes[0],
            "node": output_sizes[1],
            "global": output_sizes[2],
        },
        "node_block_aggregator": "add",
        "global_block_to_node_aggregator": "add",
        "global_block_to_edge_aggregator": "add",
        "pass_global_to_edge": pass_global_to_edge,
        "pass_global_to_node": pass_global_to_node,
    }

    def mlp(*layer_sizes):
        # Shape-inferring MLP sharing this model's layer-norm/dropout settings.
        return Flex(MLP)(Flex.d(), *layer_sizes, layer_norm=layer_norm, dropout=dropout)

    def make_codec():
        # Encoder and decoder share the same single-layer-per-attribute shape.
        return GraphEncoder(
            EdgeBlock(mlp(latent_sizes[0])),
            NodeBlock(mlp(latent_sizes[1])),
            GlobalBlock(mlp(latent_sizes[2])),
        )

    self.encoder = make_codec()

    latent = self.config["latent_size"]
    edge_layers = [latent["edge"]] * latent["core_edge_block_depth"]
    node_layers = [latent["node"]] * latent["core_node_block_depth"]
    global_layers = [latent["global"]] * latent["core_global_block_depth"]

    self.core = GraphCore(
        AggregatingEdgeBlock(mlp(*edge_layers)),
        AggregatingNodeBlock(
            mlp(*node_layers),
            Aggregator(self.config["node_block_aggregator"]),
        ),
        AggregatingGlobalBlock(
            mlp(*global_layers),
            edge_aggregator=Aggregator(self.config["global_block_to_edge_aggregator"]),
            node_aggregator=Aggregator(self.config["global_block_to_node_aggregator"]),
        ),
        pass_global_to_edge=self.config["pass_global_to_edge"],
        pass_global_to_node=self.config["pass_global_to_node"],
    )

    self.decoder = make_codec()

    # Final per-attribute linear projections to the requested output sizes.
    self.output_transform = GraphEncoder(
        EdgeBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[0])),
        NodeBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[1])),
        GlobalBlock(Flex(torch.nn.Linear)(Flex.d(), output_sizes[2])),
    )
def __init__(
    self,
    latent_sizes=(128, 128, 1),
    depths=(1, 1, 1),
    dropout: float = None,
    pass_global_to_edge: bool = True,
    pass_global_to_node: bool = True,
    edge_to_node_aggregators=tuple(["add", "max", "mean", "min"]),
    edge_to_global_aggregators=tuple(["add", "max", "mean", "min"]),
    node_to_global_aggregators=tuple(["add", "max", "mean", "min"]),
    aggregator_activation=defaults.activation,
):
    """Encode-process-decode graph network with multi-head aggregators.

    Fix: removed the commented-out ``Flex(torch.nn.Linear)`` dead code left inside
    the three core Sequentials, and factored the thrice-repeated
    ``Flex(MultiAggregator)`` construction into a local factory.

    :param latent_sizes: latent widths as (edge, node, global).
    :param depths: core block depths as (node, edge, global).
    :param dropout: optional dropout rate for the MLPs (None disables it).
    :param pass_global_to_edge: forward global attributes into the edge block.
    :param pass_global_to_node: forward global attributes into the node block.
    :param edge_to_node_aggregators: aggregator heads for edge->node pooling.
    :param edge_to_global_aggregators: aggregator heads for edge->global pooling.
    :param node_to_global_aggregators: aggregator heads for node->global pooling.
    :param aggregator_activation: activation used inside the multi-aggregators.
    """
    super().__init__()
    self.config = {
        "latent_size": {
            "node": latent_sizes[1],
            "edge": latent_sizes[0],
            "global": latent_sizes[2],
            "core_node_block_depth": depths[0],
            "core_edge_block_depth": depths[1],
            "core_global_block_depth": depths[2],
        },
        "node_block_aggregator": edge_to_node_aggregators,
        "global_block_to_node_aggregator": node_to_global_aggregators,
        "global_block_to_edge_aggregator": edge_to_global_aggregators,
        "aggregator_activation": aggregator_activation,
        "pass_global_to_edge": pass_global_to_edge,
        "pass_global_to_node": pass_global_to_node,
    }

    self.encoder = GraphEncoder(
        EdgeBlock(Flex(MLP)(Flex.d(), latent_sizes[0], dropout=dropout)),
        NodeBlock(Flex(MLP)(Flex.d(), latent_sizes[1], dropout=dropout)),
        GlobalBlock(Flex(MLP)(Flex.d(), latent_sizes[2], dropout=dropout)),
    )

    latent = self.config["latent_size"]
    edge_layers = [latent["edge"]] * latent["core_edge_block_depth"]
    node_layers = [latent["node"]] * latent["core_node_block_depth"]
    global_layers = [latent["global"]] * latent["core_global_block_depth"]

    def multi_agg(config_key):
        # Shared factory: a fresh multi-head aggregator configured from self.config.
        return Flex(MultiAggregator)(
            Flex.d(),
            self.config[config_key],
            activation=self.config["aggregator_activation"],
        )

    self.core = GraphCore(
        AggregatingEdgeBlock(
            torch.nn.Sequential(
                Flex(MLP)(Flex.d(), *edge_layers, dropout=dropout, layer_norm=True)
            )
        ),
        AggregatingNodeBlock(
            torch.nn.Sequential(
                Flex(MLP)(Flex.d(), *node_layers, dropout=dropout, layer_norm=True)
            ),
            multi_agg("node_block_aggregator"),
        ),
        AggregatingGlobalBlock(
            torch.nn.Sequential(
                Flex(MLP)(Flex.d(), *global_layers, dropout=dropout, layer_norm=True)
            ),
            edge_aggregator=multi_agg("global_block_to_edge_aggregator"),
            node_aggregator=multi_agg("global_block_to_node_aggregator"),
        ),
        pass_global_to_edge=self.config["pass_global_to_edge"],
        pass_global_to_node=self.config["pass_global_to_node"],
    )

    self.decoder = GraphEncoder(
        EdgeBlock(
            Flex(MLP)(Flex.d(), latent_sizes[0], latent_sizes[0], dropout=dropout)
        ),
        NodeBlock(
            Flex(MLP)(Flex.d(), latent_sizes[1], latent_sizes[1], dropout=dropout)
        ),
        # NOTE(review): the decoder's GlobalBlock does not receive `dropout`,
        # unlike its edge/node siblings — preserved as-is; confirm intent.
        GlobalBlock(Flex(MLP)(Flex.d(), latent_sizes[2])),
    )

    # Final per-attribute linear projections, each to a single feature.
    self.output_transform = GraphEncoder(
        EdgeBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
        NodeBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
        GlobalBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
    )