def forward(self, x, edge_index, edge_attr, params, param_name_dict, size=None):
    """Run one message-passing step using externally supplied parameters.

    The ``att``, ``edge_update``, ``weight`` and (when ``self.use_bias``)
    ``bias`` tensors are looked up with ``get_param`` rather than read from
    module parameters, so the layer can be evaluated with arbitrary weights.

    Args:
        x: Node features, either a single tensor or a pair ``(x_src, x_dst)``
            whose entries may be ``None``.
        edge_index: Graph connectivity forwarded to ``propagate``.
        edge_attr: Per-edge attributes (one row per edge) or ``None``.
        params: Parameter container indexed through ``param_name_dict``.
        param_name_dict: Mapping used by ``get_param`` to locate parameters.
        size: Optional size forwarded to ``propagate``. When ``None`` and ``x``
            is a tensor, existing self loops are replaced by one fresh loop per
            node, and ``edge_attr`` is padded with zero rows to match.

    Returns:
        Whatever ``self.propagate`` returns.
    """
    self.att = get_param(params, param_name_dict, "att")
    self.edge_update = get_param(params, param_name_dict, "edge_update")
    self.bias = None
    if self.use_bias:
        self.bias = get_param(params, param_name_dict, "bias")

    if size is None and torch.is_tensor(x):
        # Remove existing self loops *together with* their attributes so that
        # edge_attr stays row-aligned with edge_index.
        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        if edge_attr is not None:
            # Blank attributes for the new self loops, appended after the
            # existing rows (assumes add_self_loops appends loop edges at the
            # end, as PyG does). new_zeros keeps edge_attr's dtype/device.
            self_loop_attr = edge_attr.new_zeros((x.size(0),) + tuple(edge_attr.shape[1:]))
            edge_attr = torch.cat([edge_attr, self_loop_attr], dim=0)

    weight = get_param(params, param_name_dict, "weight")
    if torch.is_tensor(x):
        x = torch.matmul(x, weight)
    else:
        x = (
            None if x[0] is None else torch.matmul(x[0], weight),
            None if x[1] is None else torch.matmul(x[1], weight),
        )
    # NOTE(review): x.size(0) fails when x is a (src, dst) tuple; confirm
    # whether the tuple path is used and what num_nodes should be there.
    return self.propagate(
        edge_index, size=size, x=x, num_nodes=x.size(0), edge_attr=edge_attr
    )
def forward(
    self,
    x,
    edge_index,
    edge_types,
    relation_weights,
    params,
    param_name_dict,
    edge_norm=None,
    size=None,
):
    """Propagate messages using parameters fetched from an external store.

    Stores the ``basis`` tensor, plus ``root`` when ``self.root_weight`` is
    set and ``bias`` when ``self.use_bias`` is set, on ``self`` before
    delegating to ``self.propagate``.

    Args:
        x: Node features.
        edge_index: Graph connectivity.
        edge_types: Per-edge type information forwarded to ``propagate``.
        relation_weights: Forwarded unchanged to ``propagate``.
        params: Parameter container indexed through ``param_name_dict``.
        param_name_dict: Mapping used by ``get_param`` to locate parameters.
        edge_norm: Optional per-edge normalisation forwarded to ``propagate``.
        size: Optional size forwarded to ``propagate``.

    Returns:
        Whatever ``self.propagate`` returns.
    """

    def fetch(name):
        return get_param(params, param_name_dict, name)

    self.basis = fetch("basis")
    if self.root_weight:
        self.root = fetch("root")
    if self.use_bias:
        self.bias = fetch("bias")

    propagate_kwargs = dict(
        size=size,
        x=x,
        edge_types=edge_types,
        edge_norm=edge_norm,
        relation_weights=relation_weights,
    )
    return self.propagate(edge_index, **propagate_kwargs)
def forward(self, x, edge_index, params, param_name_dict, edge_weight=None):
    """Apply a graph convolution using externally supplied parameters.

    ``weight`` (and ``bias`` when ``self.use_bias`` is set) are fetched with
    ``get_param`` and stored on ``self``. Features are projected with
    ``torch.matmul(x, weight)``, then propagated with an edge normalisation
    that may be cached between calls.

    Args:
        x: Node feature matrix.
        edge_index: Graph connectivity.
        params: Parameter container indexed through ``param_name_dict``.
        param_name_dict: Mapping used by ``get_param`` to locate parameters.
        edge_weight: Optional per-edge weights. Passed to ``self.norm`` when
            ``self.normalize`` is set, otherwise used directly as ``norm``.

    Returns:
        The result of ``self.propagate(edge_index, x=x, norm=norm)``.

    Raises:
        RuntimeError: If caching is enabled, a result is already cached, and
            the number of edges differs from the cached edge count.
    """
    self.weight = get_param(params, param_name_dict, "weight")
    if self.use_bias:
        self.bias = get_param(params, param_name_dict, "bias")
    x = torch.matmul(x, self.weight)

    reuse_cache = self.cached and self.cached_result is not None
    if reuse_cache:
        # The cached normalisation belongs to one specific graph; the edge
        # count serves as a cheap check that the graph has not changed.
        found_edges = edge_index.size(1)
        if found_edges != self.cached_num_edges:
            raise RuntimeError(
                "Cached {} number of edges, but found {}. Please "
                "disable the caching behavior of this layer by removing "
                "the `cached=True` argument in its constructor.".format(
                    self.cached_num_edges, found_edges))
    else:
        self.cached_num_edges = edge_index.size(1)
        if self.normalize:
            edge_index, norm = self.norm(
                edge_index, x.size(0), edge_weight, self.improved, x.dtype
            )
        else:
            norm = edge_weight
        self.cached_result = edge_index, norm

    edge_index, norm = self.cached_result
    return self.propagate(edge_index, x=x, norm=norm)
def forward(self, batch):
    """Return the product of two learned parameter tensors as edge embeddings.

    ``batch`` is accepted for interface compatibility but is not read.

    Returns:
        tuple: ``(edge_e, None)`` where ``edge_e`` is
        ``F.linear(learned_param_1, learned_param_2.t())``, i.e.
        ``learned_param_1 @ learned_param_2`` for 2-D tensors (shapes are
        assumed compatible; TODO confirm against how the params are created).
    """
    # Map each weight name to its position in self.weights (last one wins on
    # duplicate names, matching a dict comprehension).
    name_index = dict()
    for position, name in enumerate(self.weight_names):
        name_index[name] = position

    left = get_param(self.weights, name_index, "learned_param_1")
    right = get_param(self.weights, name_index, "learned_param_2")
    edge_e = F.linear(left, weight=right.t())
    return edge_e, None
def pyg_classify(self, nodes, query_edge, params=None, param_name_dict=None):
    """Classify each graph from pooled node features plus its query nodes' features.

    For every graph, the embeddings of its query nodes are gathered and
    flattened. They are concatenated after the mean of all its node
    embeddings, and the result is passed through
    ``self.config.model.classify_layers`` linear layers, with ReLU between
    them (none after the last).

    Args:
        nodes: Sequence of per-graph node-embedding tensors, indexed as
            ``(batch, node, dim)`` (assumed ``(1, n_i, dim)`` each; TODO
            confirm with callers).
        query_edge: Per-graph index tensors selecting query nodes; entry ``i``
            is paired with ``nodes[i]``.
        params: Parameter container passed to ``get_param``.
        param_name_dict: Mapping used by ``get_param`` to find
            ``classify_{layer}.weight`` / ``classify_{layer}.bias``.

    Returns:
        Output of the final linear layer (logits; no activation applied).
    """
    gathered = []
    for idx, graph_nodes in enumerate(nodes):
        # Expand the query indices over the feature dimension so gather picks
        # whole embedding rows: shape 1 x num_q x dim.
        index = query_edge[idx].unsqueeze(0).unsqueeze(2).repeat(
            1, 1, graph_nodes.size(2)
        )
        gathered.append(torch.gather(graph_nodes, 1, index))
    query_emb = torch.cat(gathered, dim=0)
    query = query_emb.view(query_emb.size(0), -1)  # B x (num_q * dim)

    # Mean-pool every graph's nodes: B x dim.
    node_avg = torch.cat([torch.mean(graph_nodes, 1) for graph_nodes in nodes], dim=0)

    hidden = torch.cat((node_avg, query), -1)  # B x (dim + num_q * dim)
    num_layers = self.config.model.classify_layers
    for layer in range(num_layers):
        layer_weight = get_param(
            params,
            param_name_dict,
            "classify_{}.weight".format(layer),
            ignore_classify=False,
        )
        layer_bias = get_param(
            params,
            param_name_dict,
            "classify_{}.bias".format(layer),
            ignore_classify=False,
        )
        hidden = F.linear(hidden, weight=layer_weight.t(), bias=layer_bias)
        is_last = layer == num_layers - 1
        if not is_last:
            hidden = F.relu(hidden)
    return hidden
def forward(self, x, edge_index, edge_attr, params, param_name_dict, size=None):
    """Run one gated message-passing step using externally supplied parameters.

    ``att``, ``weight``, the four GRU tensors (``gru_w_ih``, ``gru_w_hh``,
    ``gru_b_ih``, ``gru_b_hh``) and, when ``self.use_bias`` is set, ``bias``
    are fetched with ``get_param`` and stored on ``self``. The *unprojected*
    input ``x`` is kept as ``self.gru_hx``.

    Args:
        x: Node features, either a single tensor or a pair ``(x_src, x_dst)``
            whose entries may be ``None``.
        edge_index: Graph connectivity forwarded to ``propagate``.
        edge_attr: Per-edge attributes (one row per edge) or ``None``.
        params: Parameter container indexed through ``param_name_dict``.
        param_name_dict: Mapping used by ``get_param`` to locate parameters.
        size: Optional size forwarded to ``propagate``. When ``None`` and ``x``
            is a tensor, existing self loops are replaced by one fresh loop per
            node, and a non-``None`` ``edge_attr`` is padded with zero rows.

    Returns:
        Whatever ``self.propagate`` returns.
    """
    self.att = get_param(params, param_name_dict, "att")
    self.bias = None
    if self.use_bias:
        self.bias = get_param(params, param_name_dict, "bias")

    if size is None and torch.is_tensor(x):
        # Remove existing self loops together with their attributes so that
        # edge_attr stays row-aligned with edge_index.
        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        if edge_attr is not None:
            # Blank attributes for the new self loops, appended after existing
            # rows (assumes add_self_loops appends loop edges at the end).
            self_loop_attr = edge_attr.new_zeros((x.size(0),) + tuple(edge_attr.shape[1:]))
            edge_attr = torch.cat([edge_attr, self_loop_attr], dim=0)

    # GRU parameters used by the gated update.
    self.gru_weight_ih = get_param(params, param_name_dict, "gru_w_ih")
    self.gru_weight_hh = get_param(params, param_name_dict, "gru_w_hh")
    self.gru_bias_ih = get_param(params, param_name_dict, "gru_b_ih")
    self.gru_bias_hh = get_param(params, param_name_dict, "gru_b_hh")
    # Hidden state is the input *before* the linear projection below.
    self.gru_hx = x

    weight = get_param(params, param_name_dict, "weight")
    if torch.is_tensor(x):
        x = torch.matmul(x, weight)
    else:
        x = (
            None if x[0] is None else torch.matmul(x[0], weight),
            None if x[1] is None else torch.matmul(x[1], weight),
        )
    # NOTE(review): x.size(0) fails when x is a (src, dst) tuple; confirm
    # whether the tuple path is used.
    return self.propagate(
        edge_index, size=size, x=x, num_nodes=x.size(0), edge_attr=edge_attr
    )
def forward(self, batch):
    """Encode the world graph and pool one embedding per relation type.

    Node features come from embedding ``data.edge_indicator`` with the
    ``common_emb`` parameter. Per the data convention, the indicator is 0 for
    ordinary nodes and > 0 for relation nodes. The features then pass through
    ``signature_gat.num_layers`` graph convolutions (dropout / conv / ELU /
    dropout between layers, no activation after the last).

    Args:
        batch: Object exposing ``world_graphs`` (with ``x``, ``edge_index``
            and ``edge_indicator``) and ``num_edge_nodes`` (split sizes).

    Returns:
        tuple: ``(edge_emb, world_emb)``. ``edge_emb`` has shape
        ``(num_classes, dim)``; row ``k - 1`` is the mean embedding of nodes
        whose indicator equals ``k``, and rows for absent indicators stay zero.
        ``world_emb`` is the first chunk of the final node embeddings.

    Raises:
        ValueError: If ``data.edge_index`` is not 2-dimensional.
    """
    data = batch.world_graphs
    param_name_to_idx = {name: idx for idx, name in enumerate(self.weight_names)}
    gat_cfg = self.config.model.signature_gat

    assert data.x.size(0) == data.edge_indicator.size(0), (
        "world graph must have one edge_indicator entry per node"
    )
    x = F.embedding(
        data.edge_indicator,
        get_param(self.weights, param_name_to_idx, "common_emb"),
    )
    # No edge attributes: edge types are not learned here.
    edge_attr = None
    if data.edge_index.dim() != 2:
        raise ValueError(
            "expected data.edge_index to be 2-dimensional, got {} dims".format(
                data.edge_index.dim()
            )
        )

    num_layers = gat_cfg.num_layers
    for nr in range(num_layers - 1):
        param_name_dict = self.prepare_param_idx(nr)
        x = F.dropout(x, p=gat_cfg.dropout, training=self.training)
        x = self.edgeConvs[nr](
            x, data.edge_index, edge_attr, self.weights, param_name_dict
        )
        x = F.elu(x)
        x = F.dropout(x, p=gat_cfg.dropout, training=self.training)
    if num_layers > 0:
        # Only look up the last layer's params when that layer exists.
        param_name_dict = self.prepare_param_idx(num_layers - 1)
        x = self.edgeConvs[num_layers - 1](
            x, data.edge_index, edge_attr, self.weights, param_name_dict
        )

    # There is a single world graph, so only the first chunk is used.
    chunks = torch.split(x, batch.num_edge_nodes, dim=0)
    world_emb = chunks[0]

    num_class = self.config.model.num_classes
    edge_emb = torch.zeros(
        (num_class, world_emb.size(-1)), device=self.config.general.device
    )
    for ei_t in data.edge_indicator.unique():
        ei = ei_t.item()
        if ei == 0:
            # Ordinary node, not a relation node.
            continue
        # Indicator values are shifted by one relative to class indices.
        edge_emb[ei - 1] = world_emb[data.edge_indicator == ei].mean(dim=0)
    return edge_emb, world_emb
def forward(self, batch, rel_emb=None):
    """Encode a batch of graphs and classify each against its query.

    Node embeddings are freshly sampled (Xavier-uniform, gain 1.414) on every
    call and looked up by ``data.x``. Edge features are looked up from
    ``rel_emb`` when given, otherwise from the ``relation_embedding``
    parameter. The features run through ``gat.num_layers`` graph convolutions
    (dropout / conv / ELU / dropout between layers, no activation after the
    last). Per-graph node chunks are then passed to ``pyg_classify``.

    Args:
        batch: Object exposing ``graphs`` (with ``x``, ``edge_attr``,
            ``edge_index``), ``num_nodes`` (split sizes) and ``queries``.
        rel_emb: Optional relation-embedding matrix overriding the learned
            ``relation_embedding`` parameter.

    Returns:
        The output of ``self.pyg_classify``.
    """
    data = batch.graphs
    param_name_to_idx = {name: idx for idx, name in enumerate(self.weight_names)}
    gat_cfg = self.config.model.gat

    # Random (non-learned) node embeddings, allocated directly on the target
    # device instead of on the CPU followed by a copy.
    node_emb = torch.empty(
        (self.config.model.num_nodes, self.config.model.relation_embedding_dim),
        device=self.config.general.device,
    )
    torch.nn.init.xavier_uniform_(node_emb, gain=1.414)
    x = F.embedding(data.x, node_emb).squeeze(1)

    if rel_emb is None:
        rel_emb = get_param(self.weights, param_name_to_idx, "relation_embedding")
    edge_attr = F.embedding(data.edge_attr, rel_emb).squeeze(1)

    num_layers = gat_cfg.num_layers
    for nr in range(num_layers - 1):
        param_name_dict = self.prepare_param_idx(nr)
        x = F.dropout(x, p=gat_cfg.dropout, training=self.training)
        x = self.edgeConvs[nr](
            x, data.edge_index, edge_attr, self.weights, param_name_dict
        )
        x = F.elu(x)
        x = F.dropout(x, p=gat_cfg.dropout, training=self.training)
    if num_layers > 0:
        # Only look up the last layer's params when that layer exists.
        param_name_dict = self.prepare_param_idx(num_layers - 1)
        x = self.edgeConvs[num_layers - 1](
            x, data.edge_index, edge_attr, self.weights, param_name_dict
        )

    # Restore per-graph chunks, each with a leading batch dim of 1.
    chunks = [chunk.unsqueeze(0) for chunk in torch.split(x, batch.num_nodes, dim=0)]
    return self.pyg_classify(chunks, batch.queries, self.weights, param_name_to_idx)
def forward(self, batch):
    """Return the ``learned_param`` tensor as the edge embedding.

    ``batch`` is accepted for interface compatibility but is not read.

    Returns:
        tuple: ``(learned_param, None)``.
    """
    # Name -> position in self.weights; later duplicates overwrite earlier ones.
    lookup = {}
    for idx, name in enumerate(self.weight_names):
        lookup[name] = idx
    return get_param(self.weights, lookup, "learned_param"), None