def init_core(self):
    """Build the ``GraphCore`` from the attribute-style ``self.config``.

    The edge, node and global core blocks take ``layers``, ``dropout`` and
    ``layer_norm`` from ``self.config.core.edge`` / ``.node`` / ``.glob``.
    The node block additionally receives ``edge_block_to_node_aggregators``.
    The global block receives ``global_block_to_edge_aggregators`` and
    ``global_block_to_node_aggregators``. Both get an aggregator activation
    resolved via ``self.config.get_activation``.

    Returns:
        The assembled ``GraphCore``, with ``pass_global_to_node`` and
        ``pass_global_to_edge`` taken from the config.
    """
    cfg = self.config
    edge_cfg = cfg.core.edge
    node_cfg = cfg.core.node
    glob_cfg = cfg.core.glob

    edge_block = EdgeBlockCore(
        layers=edge_cfg.layers,
        dropout=edge_cfg.dropout,
        layer_norm=edge_cfg.layer_norm,
    )

    # The activation is resolved separately for each block rather than shared:
    # get_activation may build a new object per call (unverified), so keep
    # one call per consumer.
    node_block = NodeBlockCore(
        layers=node_cfg.layers,
        dropout=node_cfg.dropout,
        aggregator=cfg.edge_block_to_node_aggregators,
        aggregator_activation=cfg.get_activation(cfg.aggregator_activation),
        layer_norm=node_cfg.layer_norm,
    )

    global_block = GlobalBlockCore(
        layers=glob_cfg.layers,
        dropout=glob_cfg.dropout,
        edge_aggregator=cfg.global_block_to_edge_aggregators,
        node_aggregator=cfg.global_block_to_node_aggregators,
        layer_norm=glob_cfg.layer_norm,
        aggregator_activation=cfg.get_activation(cfg.aggregator_activation),
    )

    return GraphCore(
        edge_block=edge_block,
        node_block=node_block,
        global_block=global_block,
        pass_global_to_node=cfg.pass_global_to_node,
        pass_global_to_edge=cfg.pass_global_to_edge,
    )
def create_graph_core(pass_global_to_edge: bool, pass_global_to_node: bool):
    """Build a small ``GraphCore`` whose three blocks share one head layout.

    Each block's update module is ``Sequential(MLP(d, 5, 5, layer_norm=False),
    Linear(d, 1))``, with ``Flex.d()`` placeholders for the input size. The
    node block uses an ``Aggregator("add")`` for incoming edges. The global
    block uses an ``Aggregator("add")`` for both edges and nodes.

    Args:
        pass_global_to_edge: Forwarded unchanged to ``GraphCore``.
        pass_global_to_node: Forwarded unchanged to ``GraphCore``.

    Returns:
        The assembled ``GraphCore``.
    """

    def make_head():
        # A fresh module per call, so no parameters are shared between blocks.
        return torch.nn.Sequential(
            Flex(MLP)(Flex.d(), 5, 5, layer_norm=False),
            Flex(torch.nn.Linear)(Flex.d(), 1),
        )

    edge_block = AggregatingEdgeBlock(make_head())
    node_block = AggregatingNodeBlock(
        make_head(),
        edge_aggregator=Aggregator("add"),
    )
    global_block = AggregatingGlobalBlock(
        make_head(),
        edge_aggregator=Aggregator("add"),
        node_aggregator=Aggregator("add"),
    )

    return GraphCore(
        edge_block,
        node_block,
        global_block,
        pass_global_to_edge=pass_global_to_edge,
        pass_global_to_node=pass_global_to_node,
    )
def _init_core(self):
    """Build the ``GraphCore`` from the nested-dict ``self.config``.

    Layer sizes come from ``self.config["sizes"]["latent"]``. For each of
    ``edge``, ``node`` and ``global``, the MLP is given ``<kind>_depth``
    copies of the ``<kind>`` width. Every MLP uses ``self.config["dropout"]``
    and ``layer_norm=True``.

    The node block's aggregator and the global block's edge/node aggregators
    are ``MultiAggregator`` instances. They are configured from the
    corresponding ``*_aggregator`` config entries and
    ``self.config["aggregator_activation"]``.

    Returns:
        The assembled ``GraphCore``.
    """
    cfg = self.config
    latent = cfg["sizes"]["latent"]
    dropout = cfg["dropout"]
    activation = cfg["aggregator_activation"]

    edge_widths = [latent["edge"]] * latent["edge_depth"]
    node_widths = [latent["node"]] * latent["node_depth"]
    global_widths = [latent["global"]] * latent["global_depth"]

    def mlp(widths):
        # One layer-normalised MLP wrapped in a Sequential container.
        return torch.nn.Sequential(
            Flex(MLP)(Flex.d(), *widths, dropout=dropout, layer_norm=True),
        )

    def multi_agg(config_key):
        # MultiAggregator over the aggregator names stored under config_key.
        return Flex(MultiAggregator)(
            Flex.d(),
            cfg[config_key],
            activation=activation,
        )

    return GraphCore(
        AggregatingEdgeBlock(mlp(edge_widths)),
        AggregatingNodeBlock(
            mlp(node_widths),
            multi_agg("node_block_aggregator"),
        ),
        AggregatingGlobalBlock(
            mlp(global_widths),
            edge_aggregator=multi_agg("global_block_to_edge_aggregator"),
            node_aggregator=multi_agg("global_block_to_node_aggregator"),
        ),
        pass_global_to_edge=cfg["pass_global_to_edge"],
        pass_global_to_node=cfg["pass_global_to_node"],
    )
def create_graph_core_multi_agg(pass_global_to_edge: bool, pass_global_to_node: bool):
    """Build a small ``GraphCore`` that aggregates with several reducers at once.

    Each block's update module is ``Sequential(MLP(d, 5, 5, layer_norm=True,
    activation=LeakyReLU), Linear(d, 1))``. Every aggregator slot gets its own
    ``MultiAggregator`` over ``["add", "mean", "max", "min"]``.

    Args:
        pass_global_to_edge: Forwarded unchanged to ``GraphCore``.
        pass_global_to_node: Forwarded unchanged to ``GraphCore``.

    Returns:
        The assembled ``GraphCore``.
    """

    def new_aggregator():
        # A fresh instance (and fresh reducer list) for every call site.
        return Flex(MultiAggregator)(Flex.d(), ["add", "mean", "max", "min"])

    def new_head():
        return torch.nn.Sequential(
            Flex(MLP)(Flex.d(), 5, 5, layer_norm=True, activation=torch.nn.LeakyReLU),
            Flex(torch.nn.Linear)(Flex.d(), 1),
        )

    edge_block = AggregatingEdgeBlock(new_head())
    node_block = AggregatingNodeBlock(
        new_head(),
        edge_aggregator=new_aggregator(),
    )
    global_block = AggregatingGlobalBlock(
        new_head(),
        edge_aggregator=new_aggregator(),
        node_aggregator=new_aggregator(),
    )

    return GraphCore(
        edge_block,
        node_block,
        global_block,
        pass_global_to_edge=pass_global_to_edge,
        pass_global_to_node=pass_global_to_node,
    )
def __init__(
    self,
    latent_sizes=(128, 128, 1),
    depths=(1, 1, 1),
    dropout: float = None,
    pass_global_to_edge: bool = True,
    pass_global_to_node: bool = True,
    edge_to_node_aggregators=("add", "max", "mean", "min"),
    edge_to_global_aggregators=("add", "max", "mean", "min"),
    node_to_global_aggregators=("add", "max", "mean", "min"),
    aggregator_activation=defaults.activation,
):
    """Assemble an encoder -> core -> decoder -> output-transform graph network.

    Args:
        latent_sizes: ``(edge, node, global)`` widths. Used by the encoder,
            the core MLPs and the decoder.
        depths: ``(edge, node, global)`` repeat counts for the core MLP layer
            widths. Same ordering as ``latent_sizes``.
        dropout: Passed to the encoder MLPs, all core MLPs and the edge/node
            decoder MLPs. ``None`` is forwarded as-is; its meaning depends on
            ``MLP`` (presumably "no dropout").
        pass_global_to_edge: Forwarded to ``GraphCore``.
        pass_global_to_node: Forwarded to ``GraphCore``.
        edge_to_node_aggregators: Aggregator names for the core node block.
        edge_to_global_aggregators: Aggregator names for the global block's
            edge aggregator.
        node_to_global_aggregators: Aggregator names for the global block's
            node aggregator.
        aggregator_activation: ``activation`` for every core ``MultiAggregator``.
    """
    super().__init__()
    self.config = {
        "latent_size": {
            "node": latent_sizes[1],
            "edge": latent_sizes[0],
            "global": latent_sizes[2],
            # `depths` follows the (edge, node, global) order of `latent_sizes`.
            # Previously edge and node depths were swapped here.
            "core_node_block_depth": depths[1],
            "core_edge_block_depth": depths[0],
            "core_global_block_depth": depths[2],
        },
        "node_block_aggregator": edge_to_node_aggregators,
        "global_block_to_node_aggregator": node_to_global_aggregators,
        "global_block_to_edge_aggregator": edge_to_global_aggregators,
        "aggregator_activation": aggregator_activation,
        "pass_global_to_edge": pass_global_to_edge,
        "pass_global_to_node": pass_global_to_node,
    }

    # Encoder: one MLP per element type, projecting to its latent width.
    self.encoder = GraphEncoder(
        EdgeBlock(Flex(MLP)(Flex.d(), latent_sizes[0], dropout=dropout)),
        NodeBlock(Flex(MLP)(Flex.d(), latent_sizes[1], dropout=dropout)),
        GlobalBlock(Flex(MLP)(Flex.d(), latent_sizes[2], dropout=dropout)),
    )

    latent = self.config["latent_size"]
    edge_layers = [latent["edge"]] * latent["core_edge_block_depth"]
    node_layers = [latent["node"]] * latent["core_node_block_depth"]
    global_layers = [latent["global"]] * latent["core_global_block_depth"]

    self.core = GraphCore(
        AggregatingEdgeBlock(
            torch.nn.Sequential(
                Flex(MLP)(Flex.d(), *edge_layers, dropout=dropout, layer_norm=True),
            )
        ),
        AggregatingNodeBlock(
            torch.nn.Sequential(
                Flex(MLP)(Flex.d(), *node_layers, dropout=dropout, layer_norm=True),
            ),
            Flex(MultiAggregator)(
                Flex.d(),
                self.config["node_block_aggregator"],
                activation=self.config["aggregator_activation"],
            ),
        ),
        AggregatingGlobalBlock(
            torch.nn.Sequential(
                Flex(MLP)(Flex.d(), *global_layers, dropout=dropout, layer_norm=True),
            ),
            edge_aggregator=Flex(MultiAggregator)(
                Flex.d(),
                self.config["global_block_to_edge_aggregator"],
                activation=self.config["aggregator_activation"],
            ),
            node_aggregator=Flex(MultiAggregator)(
                Flex.d(),
                self.config["global_block_to_node_aggregator"],
                activation=self.config["aggregator_activation"],
            ),
        ),
        pass_global_to_edge=self.config["pass_global_to_edge"],
        pass_global_to_node=self.config["pass_global_to_node"],
    )

    # Decoder (built with GraphEncoder as a per-element container).
    # NOTE(review): the global decoder MLP gets no `dropout`, unlike the
    # edge/node ones. Confirm this is intentional.
    self.decoder = GraphEncoder(
        EdgeBlock(
            Flex(MLP)(Flex.d(), latent_sizes[0], latent_sizes[0], dropout=dropout)
        ),
        NodeBlock(
            Flex(MLP)(Flex.d(), latent_sizes[1], latent_sizes[1], dropout=dropout)
        ),
        GlobalBlock(Flex(MLP)(Flex.d(), latent_sizes[2])),
    )

    # Final per-element linear projection to a single output feature.
    self.output_transform = GraphEncoder(
        EdgeBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
        NodeBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
        GlobalBlock(Flex(torch.nn.Linear)(Flex.d(), 1)),
    )