def __init__(self, embedding_dim, feature_dims, depth=3):
    """Construct the recurrent embedding core.

    Parameters
    ----------
    embedding_dim : int
        Width shared by every embedding, recurrent cell and linear layer
        created here.
    feature_dims : dict
        Mapping describing the numerical features of each target type. Each
        key must expose a ``.name`` attribute (used as the key of the sparse
        embedding table) and each value must provide ``.values()``, which is
        forwarded to ``NumericalFeatureEncoding``.
    depth : int
        Number of message passing rounds given to ``MessagePassingNetwork``.
    """
    super(RecurrentEmbeddingModelCore, self).__init__()
    self.embedding_dim = embedding_dim

    # One "encode numerical features, then embed them" pipeline per target
    # type, keyed by the type's name. Built in iteration order of
    # `feature_dims`, encoding before embedding, so the modules are created
    # in the same order as before.
    numerical_pipelines = {}
    for target_type, type_dims in feature_dims.items():
        numerical_pipelines[target_type.name] = torch.nn.Sequential(
            numerical_features.NumericalFeatureEncoding(type_dims.values(), embedding_dim),
            numerical_features.NumericalFeaturesEmbedding(embedding_dim),
        )
    self.node_embedding = message_passing.DenseSparsePreEmbedding(
        target.TargetType,
        numerical_pipelines,
        len(target.NODE_TYPES),
        embedding_dim,
    )

    # The number of GRU layers is fixed at three and does not follow `depth`.
    self.global_entity_embedding = torch.nn.GRU(
        input_size=embedding_dim,
        hidden_size=embedding_dim,
        num_layers=3,
    )
    self.edge_embedding = torch.nn.Embedding(len(target.EDGE_TYPES), embedding_dim)

    update_cell = torch.nn.GRUCell(embedding_dim, embedding_dim)
    message_fn = sg_nn.ConcatenateLinear(embedding_dim, embedding_dim, embedding_dim)
    self.message_passing = sg_nn.MessagePassingNetwork(depth, update_cell, message_fn)

    # NOTE(review): `make_graph_model` builds GraphPostEmbedding with two
    # arguments; here only one is given -- confirm the second has a default.
    self.graph_post_embedding = message_passing.GraphPostEmbedding(embedding_dim)
    # The second input of this merge layer is 2 * embedding_dim wide.
    self.merge_global_embedding = sg_nn.ConcatenateLinear(
        embedding_dim, 2 * embedding_dim, embedding_dim)
def _build_feature_embedding_and_readout(hidden_size, feature_dimensions, numerical_types,
                                         decoder_initial_input, num_label_types,
                                         readout, readin):
    """Build the input embedding and optional feature readouts for one target kind.

    Shared by the edge and entity branches of `make_graph_model`.

    Parameters
    ----------
    hidden_size : int
        Width of the embeddings.
    feature_dimensions : dict
        The mapping given to `make_graph_model`, indexed by target type.
    numerical_types : callable
        Zero-argument callable returning the target types that carry numerical
        features (e.g. ``target.TargetType.numerical_edge_types``).
    decoder_initial_input : callable
        Forwarded to `numerical_features.make_embedding_and_readout`.
    num_label_types : int
        Number of discrete labels handled by the dense part of the embedding.
    readout : bool
        Whether feature readouts are returned.
    readin : bool
        Whether numerical features are embedded as inputs.

    Returns
    -------
    embedding
        A `DenseSparsePreEmbedding` over the numerical feature embeddings when
        `readin` is true, otherwise a `DenseOnlyEmbedding` over labels only.
    readouts
        The readouts from `make_embedding_and_readout` when `readout` is true,
        otherwise None.
    """
    feature_embeddings = None
    readouts = None
    if readout or readin:
        type_dimensions = collections.OrderedDict(
            (t, feature_dimensions[t]) for t in numerical_types())
        feature_embeddings, readouts = numerical_features.make_embedding_and_readout(
            hidden_size, type_dimensions, decoder_initial_input)
        if not readout:
            # The readouts were only built as a by-product of the input
            # embeddings. With readout disabled, feature prediction must be
            # ignored, so they are not handed to the model.
            readouts = None

    if readin:
        embedding = message_passing.DenseSparsePreEmbedding(
            target.TargetType, feature_embeddings, num_label_types, hidden_size)
    else:
        embedding = message_passing.DenseOnlyEmbedding(num_label_types, hidden_size)
    return embedding, readouts


def make_graph_model(hidden_size, feature_dimensions, message_passing_rounds=3,
                     readout_edge_features=True, readout_entity_features=True,
                     readin_edge_features=None, readin_entity_features=None):
    """Create a graph model using the default inner architectural choices.

    The main graph model architecture is described as a large number of specific
    networks assembled together. This function chooses the architecture of all these
    networks from a couple of parameters to be specified.

    Parameters
    ----------
    hidden_size : int
        An integer parameter controlling the width of layers in the network.
    feature_dimensions : Dict[TargetType, List[int]]
        A dictionary describing the size of attribute features for each target type.
    message_passing_rounds : int
        An integer parameter controlling the number of message passing rounds in the model core.
    readout_edge_features : bool
        If true, indicates that the model should predict and output logits for edge features.
        Otherwise, edge feature prediction tasks are ignored and no edge feature readout is
        given to the model, even when `readin_edge_features` is true.
    readout_entity_features : bool
        If true, indicates that the model should predict and output logits for entity features.
        Otherwise, entity feature prediction tasks are ignored and no entity feature readout is
        given to the model, even when `readin_entity_features` is true.
    readin_edge_features : bool, optional
        If True, indicates that the model should use edge features as inputs. Otherwise,
        ignores input edge features (and only considers edge labels). If None, set to
        the same value as `readout_edge_features`. Note that setting this to a different
        value than `readout_edge_features` will prevent meaningful generation from the model.
    readin_entity_features : bool, optional
        If True, indicates that the model should use entity features as inputs. Otherwise,
        ignores input entity features (and only considers entity labels). If None, set to
        the same value as `readout_entity_features`.

    Returns
    -------
    GraphModel
        A new `GraphModel` instance with the default architecture.
    """
    if readin_edge_features is None:
        readin_edge_features = readout_edge_features
    if readin_entity_features is None:
        readin_entity_features = readout_entity_features

    # Edge modules are built before entity modules, and the core and heads
    # follow in the order below. This keeps module construction order (and
    # hence seeded random initialisation) unchanged.
    edge_embedding, edge_readouts = _build_feature_embedding_and_readout(
        hidden_size,
        feature_dimensions,
        target.TargetType.numerical_edge_types,
        numerical_features.edge_decoder_initial_input,
        len(target.EDGE_TYPES),
        readout=readout_edge_features,
        readin=readin_edge_features,
    )
    node_embedding, node_readouts = _build_feature_embedding_and_readout(
        hidden_size,
        feature_dimensions,
        target.TargetType.numerical_node_types,
        numerical_features.entity_decoder_initial_input,
        len(target.NODE_TYPES),
        readout=readout_entity_features,
        readin=readin_entity_features,
    )

    # Build main model core
    model_core = message_passing.GraphModelCore(
        sg_nn.MessagePassingNetwork(
            message_passing_rounds,
            torch.nn.GRUCell(hidden_size, hidden_size),
            sg_nn.ConcatenateLinear(hidden_size, hidden_size, hidden_size)),
        node_embedding,
        edge_embedding,
        message_passing.GraphPostEmbedding(hidden_size, hidden_size),
    )

    return GraphModel(
        model_core,
        entity_label=sg_nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size // 2, hidden_size // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size // 2, len(target.NODE_TYPES_PREDICTED))),
        entity_feature_readout=node_readouts,
        edge_post_embedding=sg_nn.ConcatenateLinear(hidden_size, hidden_size, hidden_size),
        edge_label=sg_nn.Sequential(
            sg_nn.ConcatenateLinear(hidden_size, hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, len(target.EDGE_TYPES_PREDICTED))),
        edge_feature_readout=edge_readouts,
        edge_partner=EdgePartnerNetwork(
            torch.nn.Sequential(
                torch.nn.Linear(3 * hidden_size, hidden_size),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_size, 1))))