def __init__(self, embedding_dim, feature_dims, depth=3):
    super(RecurrentEmbeddingModelCore, self).__init__()
    self.embedding_dim = embedding_dim

    # Per-node-type numerical feature encoders, merged with a dense node-label embedding.
    self.node_embedding = message_passing.DenseSparsePreEmbedding(
        target.TargetType,
        {
            k.name: torch.nn.Sequential(
                numerical_features.NumericalFeatureEncoding(fd.values(), embedding_dim),
                numerical_features.NumericalFeaturesEmbedding(embedding_dim))
            for k, fd in feature_dims.items()
        },
        len(target.NODE_TYPES),
        embedding_dim)

    # 3-layer GRU computing a global embedding from the sequence of entity embeddings.
    self.global_entity_embedding = torch.nn.GRU(
        input_size=embedding_dim, hidden_size=embedding_dim, num_layers=3)

    self.edge_embedding = torch.nn.Embedding(len(target.EDGE_TYPES), embedding_dim)

    # `depth` rounds of GRU-based message passing over the graph.
    self.message_passing = sg_nn.MessagePassingNetwork(
        depth,
        torch.nn.GRUCell(embedding_dim, embedding_dim),
        sg_nn.ConcatenateLinear(embedding_dim, embedding_dim, embedding_dim))

    self.graph_post_embedding = message_passing.GraphPostEmbedding(embedding_dim)
    self.merge_global_embedding = sg_nn.ConcatenateLinear(
        embedding_dim, 2 * embedding_dim, embedding_dim)
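# Illustrative construction sketch (not part of the original source). It assumes, as the
# dict comprehension above does, that `feature_dims` maps each entity TargetType to an
# OrderedDict of feature name -> feature cardinality; the embedding width here is arbitrary.
#
#     core = RecurrentEmbeddingModelCore(embedding_dim=256, feature_dims=feature_dims, depth=3)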
def __init__(self, target_type, feature_embeddings, fixed_embedding_cardinality, fixed_embedding_dim,
             sparse_embedding_dim=None, embedding_dim=None):
    """Initializes a new DenseSparsePreEmbedding module.

    Parameters
    ----------
    target_type : enum
        The underlying enumeration indicating target types.
    feature_embeddings : dict of modules
        A dictionary of embeddings for each of the sparse feature types.
    fixed_embedding_cardinality : int
        The number of classes in the fixed (dense) embedding layer.
    fixed_embedding_dim : int
        The dimension of the embedding for the fixed layer.
    sparse_embedding_dim : int, optional
        The dimension of the sparse embeddings. If None, assumed to be the same as the dense embedding.
    embedding_dim : int, optional
        The output dimension of the embedding. If None, assumed to be the same as the dense embedding.
    """
    super(DenseSparsePreEmbedding, self).__init__()
    sparse_embedding_dim = sparse_embedding_dim or fixed_embedding_dim
    embedding_dim = embedding_dim or fixed_embedding_dim

    self.target_type = target_type
    self.feature_embeddings = torch.nn.ModuleDict(feature_embeddings)
    self.sparse_embedding_dim = sparse_embedding_dim
    self.fixed_embedding_dim = fixed_embedding_dim
    self.fixed_embedding = torch.nn.Embedding(fixed_embedding_cardinality, fixed_embedding_dim)
    self.dense_merge = sg_nn.ConcatenateLinear(fixed_embedding_dim, sparse_embedding_dim, embedding_dim)
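# Construction sketch (hypothetical placeholder modules, shown only to make the signature
# concrete). In the real model the values of `feature_embeddings` are the numerical feature
# encoders built in `RecurrentEmbeddingModelCore` and `make_graph_model`; plain linear layers
# stand in for them here, and the key names are illustrative.
#
#     pre_embedding = DenseSparsePreEmbedding(
#         target.TargetType,
#         {'Line': torch.nn.Linear(4, 64), 'Circle': torch.nn.Linear(3, 64)},
#         fixed_embedding_cardinality=len(target.NODE_TYPES),
#         fixed_embedding_dim=64)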
def make_graph_model(hidden_size, feature_dimensions, message_passing_rounds=3,
                     readout_edge_features=True, readout_entity_features=True,
                     readin_edge_features=None, readin_entity_features=None):
    """Create a graph model using the default inner architectural choices.

    The main graph model architecture is described as a large number of specific networks
    assembled together. This function chooses the architecture of all these networks from
    a small number of parameters.

    Parameters
    ----------
    hidden_size : int
        An integer parameter controlling the width of layers in the network.
    feature_dimensions : Dict[TargetType, List[int]]
        A dictionary describing the size of attribute features for each target type.
    message_passing_rounds : int
        An integer parameter controlling the number of message passing rounds in the model core.
    readout_edge_features : bool
        If True, indicates that the model should predict and output logits for edge features.
        Otherwise, edge feature prediction tasks are ignored.
    readout_entity_features : bool
        If True, indicates that the model should predict and output logits for entity features.
        Otherwise, entity feature prediction tasks are ignored.
    readin_edge_features : bool, optional
        If True, indicates that the model should use edge features as inputs. Otherwise, ignores
        input edge features (and only considers edge labels). If None, set to the same value as
        `readout_edge_features`. Note that setting this to a different value than
        `readout_edge_features` will prevent meaningful generation from the model.
    readin_entity_features : bool, optional
        If True, indicates that the model should use entity features as inputs. Otherwise, ignores
        input entity features (and only considers entity labels). If None, set to the same value
        as `readout_entity_features`.

    Returns
    -------
    GraphModel
        A new `GraphModel` instance with the default architecture.
    """
    if readin_edge_features is None:
        readin_edge_features = readout_edge_features
    if readin_entity_features is None:
        readin_entity_features = readout_entity_features

    # Build encoders and decoders for edge features.
    if readout_edge_features or readin_edge_features:
        edge_feature_dimensions = collections.OrderedDict(
            (t, feature_dimensions[t]) for t in target.TargetType.numerical_edge_types())
        edge_embeddings, edge_readouts = numerical_features.make_embedding_and_readout(
            hidden_size, edge_feature_dimensions, numerical_features.edge_decoder_initial_input)
    else:
        edge_readouts = None

    if readin_edge_features:
        edge_embedding = message_passing.DenseSparsePreEmbedding(
            target.TargetType, edge_embeddings, len(target.EDGE_TYPES), hidden_size)
    else:
        edge_embedding = message_passing.DenseOnlyEmbedding(len(target.EDGE_TYPES), hidden_size)

    # Build encoders and decoders for node features.
    if readout_entity_features or readin_entity_features:
        node_feature_dimensions = collections.OrderedDict(
            (t, feature_dimensions[t]) for t in target.TargetType.numerical_node_types())
        node_embeddings, node_readouts = numerical_features.make_embedding_and_readout(
            hidden_size, node_feature_dimensions, numerical_features.entity_decoder_initial_input)
    else:
        node_readouts = None

    if readin_entity_features:
        node_embedding = message_passing.DenseSparsePreEmbedding(
            target.TargetType, node_embeddings, len(target.NODE_TYPES), hidden_size)
    else:
        node_embedding = message_passing.DenseOnlyEmbedding(len(target.NODE_TYPES), hidden_size)

    # Build main model core.
    model_core = message_passing.GraphModelCore(
        sg_nn.MessagePassingNetwork(
            message_passing_rounds,
            torch.nn.GRUCell(hidden_size, hidden_size),
            sg_nn.ConcatenateLinear(hidden_size, hidden_size, hidden_size)),
        node_embedding,
        edge_embedding,
        message_passing.GraphPostEmbedding(hidden_size, hidden_size))

    return GraphModel(
        model_core,
        entity_label=sg_nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size // 2), torch.nn.ReLU(),
            torch.nn.Linear(hidden_size // 2, hidden_size // 2), torch.nn.ReLU(),
            torch.nn.Linear(hidden_size // 2, len(target.NODE_TYPES_PREDICTED))),
        entity_feature_readout=node_readouts,
        edge_post_embedding=sg_nn.ConcatenateLinear(hidden_size, hidden_size, hidden_size),
        edge_label=sg_nn.Sequential(
            sg_nn.ConcatenateLinear(hidden_size, hidden_size, hidden_size), torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size), torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, len(target.EDGE_TYPES_PREDICTED))),
        edge_feature_readout=edge_readouts,
        edge_partner=EdgePartnerNetwork(
            torch.nn.Sequential(
                torch.nn.Linear(3 * hidden_size, hidden_size),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_size, 1))))
def edge_decoder_initial_input(embedding_size):
    """Initial input function for edge readouts."""
    return sg_nn.Sequential(
        sg_nn.ConcatenateLinear(embedding_size, embedding_size, embedding_size),
        torch.nn.ReLU(),
        torch.nn.Linear(embedding_size, embedding_size))
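# Note: this is the factory passed to `numerical_features.make_embedding_and_readout` in
# `make_graph_model` above. A rough shape sketch, assuming `sg_nn.Sequential` forwards both
# positional inputs into the leading `ConcatenateLinear` (its exact forward signature lives
# in sg_nn and is not shown here):
#
#     initial_input = edge_decoder_initial_input(64)
#     out = initial_input(torch.zeros(8, 64), torch.zeros(8, 64))  # expected shape: (8, 64)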