Beispiel #1
0
    def __init__(self, dim_node_features, dim_edge_features, dim_target,
                 predictor_class, config):
        """
        Initializes the model.

        :param dim_node_features: arbitrary object holding node feature information;
            used here as the input size of the first SAGEConv layer
        :param dim_edge_features: arbitrary object holding edge feature information;
            only forwarded to ``predictor_class``
        :param dim_target: arbitrary object holding target information;
            only forwarded to ``predictor_class``
        :param predictor_class: the class of the predictor that will classify
            node/graph embeddings produced by this DGN
        :param config: the configuration dictionary to extract further hyper-parameters.
            Keys read here: ``num_layers``, ``dim_embedding`` and ``aggregation``
            (forwarded to every SAGEConv layer, expected to be 'mean' or 'max').
            The whole dict is also passed on to ``predictor_class``.
        """
        super().__init__()

        num_layers = config['num_layers']
        dim_embedding = config['dim_embedding']
        self.aggregation = config['aggregation']  # can be mean or max

        if self.aggregation == 'max':
            # Extra linear layer that is only allocated for max aggregation;
            # how it is applied is defined outside this method.
            self.fc_max = nn.Linear(dim_embedding, dim_embedding)

        # The predictor's input size is dim_embedding * num_layers, i.e. one
        # dim_embedding-sized chunk per SAGEConv layer (presumably the layer
        # outputs are concatenated in forward -- confirm).
        self.predictor = predictor_class(dim_node_features=dim_embedding *
                                         num_layers,
                                         dim_edge_features=dim_edge_features,
                                         dim_target=dim_target,
                                         config=config)

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            dim_input = dim_node_features if i == 0 else dim_embedding

            # Pass the aggregation to the constructor. In PyG >= 2.1 the
            # aggregation is resolved into ``conv.aggr_module`` at construction
            # time, so assigning ``conv.aggr`` afterwards would silently keep
            # the default 'mean' aggregation.
            conv = SAGEConv(dim_input, dim_embedding, aggr=self.aggregation)

            self.layers.append(conv)
Beispiel #2
0
    def __init__(self, dim_features, dim_target, config):
        """
        Build a GraphSAGE graph classifier.

        :param dim_features: input size of the first SAGEConv layer
        :param dim_target: output size of the final linear layer ``fc2``
        :param config: configuration dict; keys read here are ``num_layers``,
            ``dim_embedding`` and ``aggregation`` (forwarded to every SAGEConv
            layer, expected to be 'mean' or 'max')
        """
        super().__init__()

        num_layers = config['num_layers']
        dim_embedding = config['dim_embedding']
        self.aggregation = config['aggregation']  # can be mean or max

        if self.aggregation == 'max':
            # Only allocated for max aggregation; applied outside this method.
            self.fc_max = nn.Linear(dim_embedding, dim_embedding)

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            dim_input = dim_features if i == 0 else dim_embedding

            # Pass the aggregation to the constructor. In PyG >= 2.1 the
            # aggregation is resolved into ``conv.aggr_module`` at construction
            # time, so assigning ``conv.aggr`` afterwards would silently keep
            # the default 'mean' aggregation.
            conv = SAGEConv(dim_input, dim_embedding, aggr=self.aggregation)

            self.layers.append(conv)

        # For graph classification: fc1 takes num_layers * dim_embedding
        # inputs, i.e. one dim_embedding-sized chunk per SAGEConv layer.
        self.fc1 = nn.Linear(num_layers * dim_embedding, dim_embedding)
        self.fc2 = nn.Linear(dim_embedding, dim_target)

    def __init__(self,
                 n_feat,
                 n_class,
                 n_layer,
                 agg_hidden,
                 fc_hidden,
                 dropout,
                 readout,
                 device,
                 aggregation='mean'):
        """
        Build a GraphSAGE graph classifier.

        :param n_feat: input size of the first SAGEConv layer
        :param n_class: output size of the final linear layer ``fc2``
        :param n_layer: number of SAGEConv layers
        :param agg_hidden: output size of every SAGEConv layer
        :param fc_hidden: output size of ``fc1``
        :param dropout: stored as ``self.dropout`` (used outside this method)
        :param readout: stored as ``self.readout`` (used outside this method)
        :param device: device every SAGEConv layer is moved to
        :param aggregation: neighbourhood aggregation forwarded to every
            SAGEConv layer (default 'mean'); 'max' additionally allocates
            ``self.fc_max``
        """
        super(GraphSAGE, self).__init__()

        self.n_layer = n_layer
        self.dropout = dropout
        self.readout = readout
        self.aggregation = aggregation
        self.device = device
        # fc1's input size: one agg_hidden-sized chunk per SAGEConv layer.
        self.readout_dim = agg_hidden * n_layer

        # Graph sage layers. nn.ModuleList (not a plain list) registers them as
        # sub-modules, so their parameters appear in self.parameters() (and are
        # therefore trained), in state_dict(), and follow model.to(...).
        # NOTE: checkpoints saved with the old plain-list version lack these
        # keys and will not load with strict=True.
        self.graph_sage_layers = nn.ModuleList()
        for i in range(n_layer):
            in_channels = n_feat if i == 0 else agg_hidden
            # Pass the aggregation to the constructor. In PyG >= 2.1 assigning
            # ``sage.aggr`` after construction does not change the aggregation.
            sage = SAGEConv(in_channels, agg_hidden,
                            aggr=self.aggregation).to(device)
            self.graph_sage_layers.append(sage)

        if self.aggregation == 'max':
            # Only allocated for max aggregation; applied outside this method.
            self.fc_max = nn.Linear(agg_hidden, agg_hidden)

        # Fully-connected layer
        self.fc1 = nn.Linear(self.readout_dim, fc_hidden)
        self.fc2 = nn.Linear(fc_hidden, n_class)