Example #1
import torch

# `message_passing`, `numerical_features`, `sg_nn`, and `target` are assumed to be
# project-local modules provided by the surrounding package.


class RecurrentEmbeddingModelCore(torch.nn.Module):
    # Class declaration assumed (inferred from the `super()` call below); the
    # original snippet begins at `__init__`.
    def __init__(self, embedding_dim, feature_dims, depth=3):
        super(RecurrentEmbeddingModelCore, self).__init__()

        self.embedding_dim = embedding_dim

        # Node pre-embedding: a dense label embedding combined with per-target-type
        # encoders for numerical features.
        self.node_embedding = message_passing.DenseSparsePreEmbedding(
            target.TargetType, {
                k.name: torch.nn.Sequential(
                    numerical_features.NumericalFeatureEncoding(fd.values(), embedding_dim),
                    numerical_features.NumericalFeaturesEmbedding(embedding_dim)
                )
                for k, fd in feature_dims.items()
            },
            len(target.NODE_TYPES), embedding_dim)

        # Recurrent (GRU) embedding over the global entity sequence.
        self.global_entity_embedding = torch.nn.GRU(
            input_size=embedding_dim,
            hidden_size=embedding_dim,
            num_layers=3)

        # Dense embedding of edge labels.
        self.edge_embedding = torch.nn.Embedding(len(target.EDGE_TYPES), embedding_dim)

        # Message-passing core: GRU-cell node updates with a concatenate-linear
        # message function, unrolled for `depth` rounds.
        self.message_passing = sg_nn.MessagePassingNetwork(
            depth, torch.nn.GRUCell(embedding_dim, embedding_dim),
            sg_nn.ConcatenateLinear(embedding_dim, embedding_dim, embedding_dim))

        # Graph-level readout applied after message passing.
        self.graph_post_embedding = message_passing.GraphPostEmbedding(embedding_dim)

        # Merges per-node states with a global context vector of size 2 * embedding_dim.
        self.merge_global_embedding = sg_nn.ConcatenateLinear(embedding_dim, 2 * embedding_dim, embedding_dim)
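The core above builds its message passing from sg_nn.MessagePassingNetwork with a torch.nn.GRUCell update and a ConcatenateLinear message function. That class is project-specific, so the following is only a minimal, self-contained sketch of the general pattern it names (messages built from endpoint states, summed per node, then a GRU-cell state update); all names in the sketch are hypothetical.

import torch

def message_passing_round(node_states, edges, message_fn, update_cell):
    # Hypothetical sketch, not the project's MessagePassingNetwork.
    # node_states: (num_nodes, dim); edges: (num_edges, 2) pairs of node indices.
    src, dst = edges[:, 0], edges[:, 1]
    # Build one message per edge from the two endpoint states.
    messages = message_fn(torch.cat([node_states[src], node_states[dst]], dim=-1))
    # Sum incoming messages for each destination node.
    aggregated = torch.zeros_like(node_states).index_add_(0, dst, messages)
    # Update every node state with the GRU cell.
    return update_cell(aggregated, node_states)

dim = 16
node_states = torch.randn(5, dim)
edges = torch.tensor([[0, 1], [1, 2], [2, 0], [3, 4]])
message_fn = torch.nn.Linear(2 * dim, dim)   # stand-in for sg_nn.ConcatenateLinear
update_cell = torch.nn.GRUCell(dim, dim)
for _ in range(3):                           # `depth` rounds
    node_states = message_passing_round(node_states, edges, message_fn, update_cell)
print(node_states.shape)                     # torch.Size([5, 16])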
Example #2
import collections

import torch

# `message_passing`, `numerical_features`, `sg_nn`, `target`, `GraphModel`, and
# `EdgePartnerNetwork` are assumed to be defined in, or imported by, the surrounding module.


def make_graph_model(hidden_size,
                     feature_dimensions,
                     message_passing_rounds=3,
                     readout_edge_features=True,
                     readout_entity_features=True,
                     readin_edge_features=None,
                     readin_entity_features=None):
    """Create a graph model using the default inner architectural choices.

    The main graph model is assembled from a large number of specific sub-networks.
    This function chooses the architecture of all of these sub-networks from a small
    set of parameters.

    Parameters
    ----------
    hidden_size : int
        Width of the hidden layers in the network.
    feature_dimensions : Dict[TargetType, List[int]]
        A dictionary describing the size of attribute features for each target type.
    message_passing_rounds : int
        Number of message-passing rounds in the model core.
    readout_edge_features : bool
        If True, indicates that the model should predict and output logits for edge features.
        Otherwise, edge feature prediction tasks are ignored.
    readout_entity_features : bool
        If True, indicates that the model should predict and output logits for entity features.
        Otherwise, entity feature prediction tasks are ignored.
    readin_edge_features : bool, optional
        If True, indicates that the model should use edge features as inputs. Otherwise, ignores
        input edge features (and only considers edge labels). If None, set to the same value as
        `readout_edge_features`. Note that setting this to a different value than `readout_edge_features`
        will prevent meaningful generation from the model.
    readin_entity_features : bool, optional
        If True, indicates that the model should use entity features as inputs. Otherwise, ignores
        input entity features (and only considers entity labels). If None, set to the same value
        as `readout_entity_features`.

    Returns
    -------
    GraphModel
        A new `GraphModel` instance with the default architecture.
    """
    if readin_edge_features is None:
        readin_edge_features = readout_edge_features

    if readin_entity_features is None:
        readin_entity_features = readout_entity_features

    # Build encoders and decoders for edge features
    if readout_edge_features or readin_edge_features:
        edge_feature_dimensions = collections.OrderedDict(
            (t, feature_dimensions[t])
            for t in target.TargetType.numerical_edge_types())
        edge_embeddings, edge_readouts = numerical_features.make_embedding_and_readout(
            hidden_size, edge_feature_dimensions,
            numerical_features.edge_decoder_initial_input)
    else:
        edge_readouts = None

    if readin_edge_features:
        edge_embedding = message_passing.DenseSparsePreEmbedding(
            target.TargetType, edge_embeddings, len(target.EDGE_TYPES),
            hidden_size)
    else:
        edge_embedding = message_passing.DenseOnlyEmbedding(
            len(target.EDGE_TYPES), hidden_size)

    # Build encoders and decoders for node features
    if readout_entity_features or readin_entity_features:
        node_feature_dimensions = collections.OrderedDict(
            (t, feature_dimensions[t])
            for t in target.TargetType.numerical_node_types())
        node_embeddings, node_readouts = numerical_features.make_embedding_and_readout(
            hidden_size, node_feature_dimensions,
            numerical_features.entity_decoder_initial_input)
    else:
        node_readouts = None

    if readin_entity_features:
        node_embedding = message_passing.DenseSparsePreEmbedding(
            target.TargetType, node_embeddings, len(target.NODE_TYPES),
            hidden_size)
    else:
        node_embedding = message_passing.DenseOnlyEmbedding(
            len(target.NODE_TYPES), hidden_size)

    # Build main model core
    model_core = message_passing.GraphModelCore(
        sg_nn.MessagePassingNetwork(
            message_passing_rounds, torch.nn.GRUCell(hidden_size, hidden_size),
            sg_nn.ConcatenateLinear(hidden_size, hidden_size, hidden_size)),
        node_embedding,
        edge_embedding,
        message_passing.GraphPostEmbedding(hidden_size, hidden_size),
    )

    return GraphModel(
        model_core,
        entity_label=sg_nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size // 2), torch.nn.ReLU(),
            torch.nn.Linear(hidden_size // 2, hidden_size // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size // 2,
                            len(target.NODE_TYPES_PREDICTED))),
        entity_feature_readout=node_readouts,
        edge_post_embedding=sg_nn.ConcatenateLinear(hidden_size, hidden_size,
                                                    hidden_size),
        edge_label=sg_nn.Sequential(
            sg_nn.ConcatenateLinear(hidden_size, hidden_size, hidden_size),
            torch.nn.ReLU(), torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, len(target.EDGE_TYPES_PREDICTED))),
        edge_feature_readout=edge_readouts,
        edge_partner=EdgePartnerNetwork(
            torch.nn.Sequential(torch.nn.Linear(3 * hidden_size, hidden_size),
                                torch.nn.ReLU(),
                                torch.nn.Linear(hidden_size, 1))))
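The read-in branches in make_graph_model choose between DenseOnlyEmbedding (labels only) and DenseSparsePreEmbedding (labels plus numerical features). Those classes are project-specific; the sketch below only illustrates that distinction in plain PyTorch, with hypothetical class names.

import torch

class LabelOnlyEmbedding(torch.nn.Module):
    # Hypothetical stand-in for a label-only ("dense only") embedding.
    def __init__(self, num_labels, hidden_size):
        super().__init__()
        self.label_embedding = torch.nn.Embedding(num_labels, hidden_size)

    def forward(self, labels, features=None):
        # Numerical features, if any, are ignored.
        return self.label_embedding(labels)

class LabelAndFeatureEmbedding(torch.nn.Module):
    # Hypothetical stand-in for an embedding that also encodes numerical features.
    def __init__(self, num_labels, num_features, hidden_size):
        super().__init__()
        self.label_embedding = torch.nn.Embedding(num_labels, hidden_size)
        self.feature_encoder = torch.nn.Linear(num_features, hidden_size)

    def forward(self, labels, features):
        return self.label_embedding(labels) + self.feature_encoder(features)

labels = torch.tensor([0, 2, 1])
features = torch.randn(3, 4)
print(LabelOnlyEmbedding(5, 8)(labels).shape)                     # torch.Size([3, 8])
print(LabelAndFeatureEmbedding(5, 4, 8)(labels, features).shape)  # torch.Size([3, 8])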