Example #1
    def forward(self, x, edge_index, edge_attr, params, param_name_dict, size=None):
        self.att = get_param(params, param_name_dict, "att")
        self.edge_update = get_param(params, param_name_dict, "edge_update")
        self.bias = None
        if self.use_bias:
            self.bias = get_param(params, param_name_dict, "bias")
        if size is None and torch.is_tensor(x):
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # add blank (all-zero) edge attributes for the self loops added above
        self_loop_edges = torch.zeros(
            x.size(0), edge_attr.size(1), device=edge_index.device
        )
        edge_attr = torch.cat([edge_attr, self_loop_edges], dim=0)
        weight = get_param(params, param_name_dict, "weight")
        if torch.is_tensor(x):
            x = torch.matmul(x, weight)
        else:
            x = (
                None if x[0] is None else torch.matmul(x[0], weight),
                None if x[1] is None else torch.matmul(x[1], weight),
            )
        # x may be a (source, target) tuple in the bipartite case
        num_nodes = x.size(0) if torch.is_tensor(x) else x[0].size(0)
        return self.propagate(
            edge_index, size=size, x=x, num_nodes=num_nodes, edge_attr=edge_attr
        )
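Every example on this page fetches named tensors from an externally supplied parameter list through the same `get_param` helper. Its definition is not shown here; a minimal sketch consistent with the call sites (and with the commented-out `params[self.get_param_id(...)]` form in Example #6) could look like the following, where the catch-all keyword arguments (e.g. `ignore_classify` in Example #5) are an assumption:

def get_param(params, param_name_dict, name, **kwargs):
    # Hypothetical reconstruction: map the parameter name to its index in a
    # flat list/ParameterList and return the corresponding tensor.
    return params[param_name_dict[name]]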
Example #2
    def forward(
        self,
        x,
        edge_index,
        edge_types,
        relation_weights,
        params,
        param_name_dict,
        edge_norm=None,
        size=None,
    ):
        self.basis = get_param(params, param_name_dict, "basis")
        if self.root_weight:
            self.root = get_param(params, param_name_dict, "root")
        if self.use_bias:
            self.bias = get_param(params, param_name_dict, "bias")

        return self.propagate(
            edge_index,
            size=size,
            x=x,
            edge_types=edge_types,
            edge_norm=edge_norm,
            relation_weights=relation_weights,
        )
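Like the other layers here, this forward pass reads its weights from an external parameter list rather than from module attributes, a pattern common in meta-learning setups. A hypothetical wiring, with names, shapes, and ordering purely illustrative:

import torch
import torch.nn as nn

# Hypothetical wiring; shapes and name order come from the real model config.
weights = nn.ParameterList([
    nn.Parameter(torch.randn(4, 16, 32)),  # "basis"
    nn.Parameter(torch.randn(16, 32)),     # "root"
    nn.Parameter(torch.zeros(32)),         # "bias"
])
name_to_idx = {"basis": 0, "root": 1, "bias": 2}
# out = conv(x, edge_index, edge_types, relation_weights, weights, name_to_idx)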
Example #3
    def forward(self,
                x,
                edge_index,
                params,
                param_name_dict,
                edge_weight=None):
        """"""
        self.weight = get_param(params, param_name_dict, "weight")
        if self.use_bias:
            self.bias = get_param(params, param_name_dict, "bias")
        x = torch.matmul(x, self.weight)

        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    "Cached {} number of edges, but found {}. Please "
                    "disable the caching behavior of this layer by removing "
                    "the `cached=True` argument in its constructor.".format(
                        self.cached_num_edges, edge_index.size(1)))

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            if self.normalize:
                edge_index, norm = self.norm(edge_index, x.size(0),
                                             edge_weight, self.improved,
                                             x.dtype)
            else:
                norm = edge_weight
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result

        return self.propagate(edge_index, x=x, norm=norm)
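`self.norm` is not shown on this page. In GCN-style layers it conventionally returns a per-edge symmetric normalization coefficient deg^{-1/2}[row] * deg^{-1/2}[col]. The sketch below follows that convention under stated assumptions (self loops already present in edge_index, `improved` only changing the self-loop weight); it is not the source's exact implementation:

import torch

def gcn_norm(edge_index, num_nodes, edge_weight=None, improved=False,
             dtype=torch.float):
    # Assumed convention: symmetric normalization D^{-1/2} A D^{-1/2}.
    # For brevity this sketch assumes edge_index already contains self loops
    # (with weight 2 instead of 1 when improved=True).
    row, col = edge_index[0], edge_index[1]
    if edge_weight is None:
        edge_weight = torch.ones(row.size(0), dtype=dtype)
    deg = torch.zeros(num_nodes, dtype=dtype).scatter_add_(0, row, edge_weight)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0.0
    norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    return edge_index, norm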
Example #4
    def forward(self, batch):
        param_name_to_idx = {k: v for v, k in enumerate(self.weight_names)}

        edge_e = F.linear(
            get_param(self.weights, param_name_to_idx, "learned_param_1"),
            weight=get_param(self.weights, param_name_to_idx,
                             "learned_param_2").t(),
        )
        return edge_e, None
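Since `F.linear(input, weight)` computes `input @ weight.t()`, passing the transposed second parameter reduces the call above to a plain matrix product of the two learned tensors:

# Equivalent to the F.linear call above, since F.linear(a, weight=b.t())
# computes a @ (b.t()).t() == a @ b:
edge_e = torch.matmul(
    get_param(self.weights, param_name_to_idx, "learned_param_1"),
    get_param(self.weights, param_name_to_idx, "learned_param_2"),
)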
Example #5
    def pyg_classify(self,
                     nodes,
                     query_edge,
                     params=None,
                     param_name_dict=None):
        """
        Run classification using an MLP.
        :param nodes: list of per-graph node embeddings, each 1 x num_nodes x dim
        :param query_edge: query node indices for each graph, each of length num_q
        :param params: flat list of learned parameters
        :param param_name_dict: mapping from parameter name to index in params
        :return: MLP output (logits), one row per graph
        """
        query_emb = []
        for i in range(len(nodes)):
            query = (query_edge[i].unsqueeze(0).unsqueeze(2).repeat(
                1, 1, nodes[i].size(2)))  # 1 x num_q x dim
            query_emb.append(torch.gather(nodes[i], 1, query))
        query_emb = torch.cat(query_emb, dim=0)  # B x num_q x dim
        query = query_emb.view(query_emb.size(0), -1)  # B x (num_q x dim)
        # pool the nodes (mean pooling)
        node_avg = torch.cat(
            [torch.mean(nodes[i], 1) for i in range(len(nodes))],
            dim=0)  # B x dim
        # concatenate the query embeddings onto the pooled node embeddings
        edges = torch.cat((node_avg, query), -1)  # B x (dim + dim x num_q)
        for layer in range(self.config.model.classify_layers):
            edges = F.linear(
                edges,
                weight=get_param(
                    params,
                    param_name_dict,
                    "classify_{}.weight".format(layer),
                    ignore_classify=False,
                ).t(),
                bias=get_param(
                    params,
                    param_name_dict,
                    "classify_{}.bias".format(layer),
                    ignore_classify=False,
                ),
            )
            if layer < self.config.model.classify_layers - 1:
                edges = F.relu(edges)
        return edges
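The index juggling before `torch.gather` is the core trick here; a toy check of it (shapes illustrative):

import torch

# Pick rows 0 and 3 of a 1 x 5 x 3 node tensor, mirroring the loop above.
nodes_i = torch.arange(15.0).view(1, 5, 3)   # 1 x num_nodes x dim
query_i = torch.tensor([0, 3])               # num_q query node ids
idx = query_i.unsqueeze(0).unsqueeze(2).repeat(1, 1, nodes_i.size(2))
picked = torch.gather(nodes_i, 1, idx)       # 1 x num_q x dim
assert picked.shape == (1, 2, 3)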
Example #6
    def forward(self, x, edge_index, edge_attr, params, param_name_dict, size=None):
        self.att = get_param(params, param_name_dict, "att")
        # self.edge_update = params[self.get_param_id(param_name_dict, 'edge_update')]
        self.bias = None
        if self.use_bias:
            self.bias = get_param(params, param_name_dict, "bias")
        if size is None and torch.is_tensor(x):
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # get gru params
        self.gru_weight_ih = get_param(params, param_name_dict, "gru_w_ih")
        self.gru_weight_hh = get_param(params, param_name_dict, "gru_w_hh")
        self.gru_bias_ih = get_param(params, param_name_dict, "gru_b_ih")
        self.gru_bias_hh = get_param(params, param_name_dict, "gru_b_hh")
        self.gru_hx = x

        # Note: unlike Example #1, blank edge attributes for the self loops
        # are not appended here
        weight = get_param(params, param_name_dict, "weight")
        if torch.is_tensor(x):
            x = torch.matmul(x, weight)
        else:
            x = (
                None if x[0] is None else torch.matmul(x[0], weight),
                None if x[1] is None else torch.matmul(x[1], weight),
            )
        # x may be a (source, target) tuple in the bipartite case
        num_nodes = x.size(0) if torch.is_tensor(x) else x[0].size(0)
        return self.propagate(
            edge_index, size=size, x=x, num_nodes=num_nodes, edge_attr=edge_attr
        )
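The GRU weights fetched above are presumably applied later (e.g. in the layer's update step) as a functional GRU cell over the aggregated messages and the stored `gru_hx`. A manual sketch of that computation, assuming PyTorch's standard GRUCell gate layout (reset, update, new stacked along dim 0):

import torch

def functional_gru_cell(inp, hx, w_ih, w_hh, b_ih, b_hh):
    # Standard GRUCell math with gate weights stacked as [reset; update; new].
    gi = torch.mm(inp, w_ih.t()) + b_ih
    gh = torch.mm(hx, w_hh.t()) + b_hh
    i_r, i_z, i_n = gi.chunk(3, dim=1)
    h_r, h_z, h_n = gh.chunk(3, dim=1)
    r = torch.sigmoid(i_r + h_r)
    z = torch.sigmoid(i_z + h_z)
    n = torch.tanh(i_n + r * h_n)
    return (1 - z) * n + z * hx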
Example #7
    def forward(self, batch):
        data = batch.world_graphs
        param_name_to_idx = {k: v for v, k in enumerate(self.weight_names)}
        assert data.x.size(0) == data.edge_indicator.size(0)
        # extract node embeddings
        # data.edge_indicator contains 0 for ordinary nodes and a distinct
        # value > 0 for each unique relation
        x = F.embedding(
            data.edge_indicator,
            get_param(self.weights, param_name_to_idx, "common_emb"),
        )
        # edge attribute is None because we are not learning edge types here
        edge_attr = None
        if data.edge_index.dim() != 2:
            # debugging guard: drop into ipdb if the edge index is malformed
            import ipdb
            ipdb.set_trace()
        for nr in range(self.config.model.signature_gat.num_layers - 1):
            param_name_dict = self.prepare_param_idx(nr)
            x = F.dropout(
                x, p=self.config.model.signature_gat.dropout, training=self.training
            )
            x = self.edgeConvs[nr](
                x, data.edge_index, edge_attr, self.weights, param_name_dict
            )
            x = F.elu(x)
        x = F.dropout(
            x, p=self.config.model.signature_gat.dropout, training=self.training
        )
        param_name_dict = self.prepare_param_idx(
            self.config.model.signature_gat.num_layers - 1
        )
        if self.config.model.signature_gat.num_layers > 0:
            x = self.edgeConvs[self.config.model.signature_gat.num_layers - 1](
                x, data.edge_index, edge_attr, self.weights, param_name_dict
            )
        # restore x into B x num_node x dim
        chunks = torch.split(x, batch.num_edge_nodes, dim=0)
        batches = [p.unsqueeze(0) for p in chunks]
        # we only have one batch for the world graph; note this rebinds the
        # `batch` argument to a num_node x dim tensor
        batch = batches[0][0]
        # sum over edge type nodes
        num_class = self.config.model.num_classes
        edge_emb = torch.zeros((num_class, batch.size(-1)))
        edge_emb = edge_emb.to(self.config.general.device)
        for ei_t in data.edge_indicator.unique():
            ei = ei_t.item()
            if ei == 0:
                # node of type "node", skip
                continue
            # node of type "edge", take
            # we subtract 1 here to re-align the vectors (L399 of data.py)
            edge_emb[ei - 1] = batch[data.edge_indicator == ei].mean(dim=0)
        return edge_emb, batch
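A toy version of the per-type averaging at the end (values illustrative): rows whose indicator equals a nonzero type id are averaged into slot id - 1.

import torch

h = torch.randn(6, 4)                          # toy node embeddings
indicator = torch.tensor([0, 1, 1, 2, 0, 2])   # 0 = plain node, >0 = relation id
out = torch.zeros(2, 4)
for t in indicator.unique():
    if t.item() == 0:
        continue  # plain nodes are skipped
    out[t.item() - 1] = h[indicator == t].mean(dim=0)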
Example #8
    def forward(self, batch, rel_emb=None):
        data = batch.graphs
        param_name_to_idx = {k: v for v, k in enumerate(self.weight_names)}
        # initialize random node embeddings; note these are re-created and
        # re-initialized on every forward call, so they carry no learned state
        # (a learned 'node_embedding' parameter could be used here instead)
        node_emb = torch.empty(
            self.config.model.num_nodes,
            self.config.model.relation_embedding_dim,
            device=self.config.general.device)
        torch.nn.init.xavier_uniform_(node_emb, gain=1.414)
        x = F.embedding(data.x, node_emb).squeeze(1)  # N x node_dim
        if rel_emb is not None:
            edge_attr = F.embedding(data.edge_attr, rel_emb)
        else:
            edge_attr = F.embedding(
                data.edge_attr,
                get_param(self.weights, param_name_to_idx,
                          "relation_embedding"),
            )
        edge_attr = edge_attr.squeeze(1)
        # edge_attr = self.edge_embedding(data.edge_attr).squeeze(1) # E x edge_dim
        for nr in range(self.config.model.gat.num_layers - 1):
            param_name_dict = self.prepare_param_idx(nr)
            x = F.dropout(x,
                          p=self.config.model.gat.dropout,
                          training=self.training)
            x = self.edgeConvs[nr](x, data.edge_index, edge_attr, self.weights,
                                   param_name_dict)
            x = F.elu(x)
        x = F.dropout(x,
                      p=self.config.model.gat.dropout,
                      training=self.training)
        param_name_dict = self.prepare_param_idx(
            self.config.model.gat.num_layers - 1)
        if self.config.model.gat.num_layers > 0:
            x = self.edgeConvs[self.config.model.gat.num_layers - 1](
                x, data.edge_index, edge_attr, self.weights, param_name_dict)
        # restore x into B x num_node x dim
        chunks = torch.split(x, batch.num_nodes, dim=0)
        chunks = [p.unsqueeze(0) for p in chunks]
        return self.pyg_classify(chunks, batch.queries, self.weights,
                                 param_name_to_idx)
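A toy version of the re-batching step (per-graph node counts illustrative): a flat node matrix is split back into per-graph chunks, each given a leading batch dimension.

import torch

x = torch.randn(7, 4)                        # 7 nodes across 2 graphs, dim 4
chunks = torch.split(x, [3, 4], dim=0)       # per-graph node counts
chunks = [c.unsqueeze(0) for c in chunks]    # each 1 x num_nodes x dim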
Example #9
    def forward(self, batch):
        param_name_to_idx = {k: v for v, k in enumerate(self.weight_names)}
        edge_e = get_param(self.weights, param_name_to_idx, "learned_param")
        return edge_e, None