Example 1
    def __init__(self, embedding_dim, feature_dims, depth=3):
        super(RecurrentEmbeddingModelCore, self).__init__()

        self.embedding_dim = embedding_dim

        # Per-node-type numerical feature encoders merged with a dense embedding over node labels.
        self.node_embedding = message_passing.DenseSparsePreEmbedding(
            target.TargetType, {
                k.name: torch.nn.Sequential(
                    numerical_features.NumericalFeatureEncoding(fd.values(), embedding_dim),
                    numerical_features.NumericalFeaturesEmbedding(embedding_dim)
                )
                for k, fd in feature_dims.items()
            },
            len(target.NODE_TYPES), embedding_dim)

        # Three-layer GRU producing the recurrent global-entity embedding.
        self.global_entity_embedding = torch.nn.GRU(
            input_size=embedding_dim,
            hidden_size=embedding_dim,
            num_layers=3)

        # Dense embedding over the edge label vocabulary.
        self.edge_embedding = torch.nn.Embedding(len(target.EDGE_TYPES), embedding_dim)

        # `depth` rounds of message passing; node states are updated with a GRU cell.
        self.message_passing = sg_nn.MessagePassingNetwork(
            depth, torch.nn.GRUCell(embedding_dim, embedding_dim),
            sg_nn.ConcatenateLinear(embedding_dim, embedding_dim, embedding_dim))

        self.graph_post_embedding = message_passing.GraphPostEmbedding(embedding_dim)

        self.merge_global_embedding = sg_nn.ConcatenateLinear(embedding_dim, 2 * embedding_dim, embedding_dim)
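All three examples rely on `sg_nn.ConcatenateLinear`, whose implementation is not shown here. As a point of reference, this is a minimal sketch of what such a module plausibly does, assuming it concatenates its two inputs along the last dimension and applies a single linear projection (the name `ConcatenateLinearSketch` is introduced here for illustration only):

import torch


class ConcatenateLinearSketch(torch.nn.Module):
    """Illustrative stand-in for `sg_nn.ConcatenateLinear` (assumed behaviour)."""

    def __init__(self, left_size, right_size, output_size):
        super().__init__()
        # Project the concatenation of both inputs down to the output size.
        self.linear = torch.nn.Linear(left_size + right_size, output_size)

    def forward(self, left, right):
        return self.linear(torch.cat([left, right], dim=-1))

Under that assumption, `ConcatenateLinear(embedding_dim, 2 * embedding_dim, embedding_dim)` above maps a pair of inputs of sizes `embedding_dim` and `2 * embedding_dim` to a single `embedding_dim`-sized vector.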
Example 2
    def __init__(self,
                 target_type,
                 feature_embeddings,
                 fixed_embedding_cardinality,
                 fixed_embedding_dim,
                 sparse_embedding_dim=None,
                 embedding_dim=None):
        """Initializes a new DenseSparsePreEmbedding module.

        Parameters
        ----------
        target_type : enum
            The underlying enumeration indicating target types.
        feature_embeddings : dict of modules
            A dictionary of embeddings for each of the sparse feature types.
        fixed_embedding_cardinality : int
            The number of classes in the fixed (dense) embedding layer.
        fixed_embedding_dim : int
            The dimension of the embedding for the fixed layer.
        sparse_embedding_dim : int, optional
            The dimension of the sparse embeddings. If None, assumed to be the same as the dense embedding dimension.
        embedding_dim : int, optional
            The output dimension of the embedding. If None, assumed to be the same as the dense embedding dimension.
        """
        super(DenseSparsePreEmbedding, self).__init__()
        sparse_embedding_dim = sparse_embedding_dim or fixed_embedding_dim
        embedding_dim = embedding_dim or fixed_embedding_dim

        self.target_type = target_type
        self.feature_embeddings = torch.nn.ModuleDict(feature_embeddings)
        self.sparse_embedding_dim = sparse_embedding_dim
        self.fixed_embedding_dim = fixed_embedding_dim
        self.fixed_embedding = torch.nn.Embedding(fixed_embedding_cardinality,
                                                  fixed_embedding_dim)
        self.dense_merge = sg_nn.ConcatenateLinear(fixed_embedding_dim,
                                                   sparse_embedding_dim,
                                                   embedding_dim)
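A hypothetical construction sketch for the module above. The enum and the placeholder feature encoders are assumptions made for illustration; the keys of `feature_embeddings` follow Example 1, where the enum member names are used:

import enum

import torch


class TargetType(enum.Enum):
    # Hypothetical stand-in for `target.TargetType`.
    NODE_A = 0
    NODE_B = 1


# One feature encoder per sparse target type; plain Linear layers as placeholders.
feature_embeddings = {
    member.name: torch.nn.Linear(16, 64) for member in TargetType
}

pre_embedding = DenseSparsePreEmbedding(   # the class defined above
    target_type=TargetType,
    feature_embeddings=feature_embeddings,
    fixed_embedding_cardinality=10,        # number of dense label classes
    fixed_embedding_dim=64)                # sparse and output dims default to 64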
Example 3
def make_graph_model(hidden_size,
                     feature_dimensions,
                     message_passing_rounds=3,
                     readout_edge_features=True,
                     readout_entity_features=True,
                     readin_edge_features=None,
                     readin_entity_features=None):
    """Create a graph model using the default inner architectural choices.

    The main graph model architecture is assembled from a large number of specific
    sub-networks. This function chooses the architecture of all of these networks
    from a small set of parameters.

    Parameters
    ----------
    hidden_size : int
        An integer parameter controlling the width of layers in the network.
    feature_dimensions : Dict[TargetType, List[int]]
        A dictionary describing the size of attribute features for each target type.
    message_passing_rounds : int
        An integer parameter controlling the number of message passing rounds in the model core.
    readout_edge_features : bool
        If True, indicates that the model should predict and output logits for edge features.
        Otherwise, edge feature prediction tasks are ignored.
    readout_entity_features : bool
        If True, indicates that the model should predict and output logits for entity features.
        Otherwise, entity feature prediction tasks are ignored.
    readin_edge_features : bool, optional
        If True, indicates that the model should use edge features as inputs. Otherwise, ignores
        input edge features (and only considers edge labels). If None, set to the same value as
        `readout_edge_features`. Note that setting this to a different value than `readout_edge_features`
        will prevent meaningful generation from the model.
    readin_entity_features : bool, optional
        If True, indicates that the model should use entity features as inputs. Otherwise, ignores
        input entity features (and only considers entity labels). If None, set to the same value
        as `readout_entity_features`.

    Returns
    -------
    GraphModel
        A new `GraphModel` instance with the default architecture.
    """
    if readin_edge_features is None:
        readin_edge_features = readout_edge_features

    if readin_entity_features is None:
        readin_entity_features = readout_entity_features

    # Build encoders and decoders for edge features
    if readout_edge_features or readin_edge_features:
        edge_feature_dimensions = collections.OrderedDict(
            (t, feature_dimensions[t])
            for t in target.TargetType.numerical_edge_types())
        edge_embeddings, edge_readouts = numerical_features.make_embedding_and_readout(
            hidden_size, edge_feature_dimensions,
            numerical_features.edge_decoder_initial_input)
    else:
        edge_readouts = None

    if readin_edge_features:
        edge_embedding = message_passing.DenseSparsePreEmbedding(
            target.TargetType, edge_embeddings, len(target.EDGE_TYPES),
            hidden_size)
    else:
        edge_embedding = message_passing.DenseOnlyEmbedding(
            len(target.EDGE_TYPES), hidden_size)

    # Build encoders and decoders for node features
    if readout_entity_features or readin_entity_features:
        node_feature_dimensions = collections.OrderedDict(
            (t, feature_dimensions[t])
            for t in target.TargetType.numerical_node_types())
        node_embeddings, node_readouts = numerical_features.make_embedding_and_readout(
            hidden_size, node_feature_dimensions,
            numerical_features.entity_decoder_initial_input)
    else:
        node_readouts = None

    if readin_entity_features:
        node_embedding = message_passing.DenseSparsePreEmbedding(
            target.TargetType, node_embeddings, len(target.NODE_TYPES),
            hidden_size)
    else:
        node_embedding = message_passing.DenseOnlyEmbedding(
            len(target.NODE_TYPES), hidden_size)

    # Build main model core
    model_core = message_passing.GraphModelCore(
        sg_nn.MessagePassingNetwork(
            message_passing_rounds, torch.nn.GRUCell(hidden_size, hidden_size),
            sg_nn.ConcatenateLinear(hidden_size, hidden_size, hidden_size)),
        node_embedding,
        edge_embedding,
        message_passing.GraphPostEmbedding(hidden_size, hidden_size),
    )

    return GraphModel(
        model_core,
        entity_label=sg_nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size // 2), torch.nn.ReLU(),
            torch.nn.Linear(hidden_size // 2, hidden_size // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size // 2,
                            len(target.NODE_TYPES_PREDICTED))),
        entity_feature_readout=node_readouts,
        edge_post_embedding=sg_nn.ConcatenateLinear(hidden_size, hidden_size,
                                                    hidden_size),
        edge_label=sg_nn.Sequential(
            sg_nn.ConcatenateLinear(hidden_size, hidden_size, hidden_size),
            torch.nn.ReLU(), torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, len(target.EDGE_TYPES_PREDICTED))),
        edge_feature_readout=edge_readouts,
        edge_partner=EdgePartnerNetwork(
            torch.nn.Sequential(torch.nn.Linear(3 * hidden_size, hidden_size),
                                torch.nn.ReLU(),
                                torch.nn.Linear(hidden_size, 1))))
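`message_passing.DenseOnlyEmbedding` is used in the two `else` branches above when features are not read in; its implementation is not shown here. A minimal sketch of what such a label-only embedding plausibly looks like, assuming it simply wraps a `torch.nn.Embedding` over the dense labels (the name `DenseOnlyEmbeddingSketch` is illustrative):

import torch


class DenseOnlyEmbeddingSketch(torch.nn.Module):
    """Illustrative stand-in for `message_passing.DenseOnlyEmbedding` (assumed behaviour)."""

    def __init__(self, cardinality, embedding_dim):
        super().__init__()
        # Only the integer label is embedded; numerical features are ignored.
        self.embedding = torch.nn.Embedding(cardinality, embedding_dim)

    def forward(self, labels):
        return self.embedding(labels)
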
def edge_decoder_initial_input(embedding_size):
    """Initial input function for edge readouts."""
    return sg_nn.Sequential(
        sg_nn.ConcatenateLinear(embedding_size, embedding_size,
                                embedding_size), torch.nn.ReLU(),
        torch.nn.Linear(embedding_size, embedding_size))
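A hedged usage sketch for `make_graph_model`. The feature sizes below are purely illustrative, and it is assumed that `target.TargetType.numerical_node_types()` and `target.TargetType.numerical_edge_types()` enumerate exactly the target types that need entries in `feature_dimensions`:

feature_dimensions = {
    t: [8, 4]  # illustrative attribute feature sizes for this target type
    for t in (list(target.TargetType.numerical_node_types())
              + list(target.TargetType.numerical_edge_types()))
}

model = make_graph_model(
    hidden_size=128,
    feature_dimensions=feature_dimensions,
    message_passing_rounds=3,
    readout_edge_features=True,
    readout_entity_features=True)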