def test_sparse_tensor_spspmm(dtype, device):
    # 10 x 16 matrix with orthonormal rows, so x @ x.t() is the identity.
    x = SparseTensor(
        row=torch.tensor(
            [0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
            device=device),
        col=torch.tensor(
            [0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15],
            device=device),
        value=torch.tensor([
            1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
            -2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
            -2**-0.5, 2**-0.5, -2**-0.5
        ], dtype=dtype, device=device),
    )

    expected = torch.eye(10, dtype=dtype, device=device)

    # Sparse @ dense path:
    out = x @ x.to_dense().t()
    assert torch.allclose(out, expected, atol=1e-7)

    # Sparse @ sparse (spspmm) path:
    out = x @ x.t()
    out = out.to_dense()
    assert torch.allclose(out, expected, atol=1e-7)
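
# Hypothetical test harness (assumed, not part of the original suite): the
# (dtype, device) signature suggests pytest parametrization along these lines.
import pytest
from itertools import product

dtypes = [torch.float, torch.double]  # assumed dtype list
devices = [torch.device('cpu')]       # assumed device list

@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_tensor_spspmm_parametrized(dtype, device):
    test_sparse_tensor_spspmm(dtype, device)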
def build_batch(data: torch_geometric.data.Batch,
                id2graphlet: Dict[int, Subgraph],
                common_file=None) -> "Batch":
    # Check whether graphlet_id == 0 occurs in x, because of how remap works.
    graphlet_id_zero = (data.x == 0).any().item()

    # Remap graphlet ids to the contiguous range (0..len(data.x.unique())).
    remapped_graphlet_ids, mapping = renumber(
        data.x.numpy(),
        start=int(graphlet_id_zero),
        in_place=False,
        preserve_zero=graphlet_id_zero)

    batch_graphlet_indices = torch.tensor(remapped_graphlet_ids.flatten(),
                                          dtype=torch.int64)

    # Keys of `mapping` (original graphlet ids) sorted by their new id.
    graphlet_ids_sorted_by_new_id = [
        key for key, _ in sorted(mapping.items(), key=lambda e: e[1])
    ]

    xs = []
    edge_indices = []
    # Build lists of xs and edge_indices such that xs[i] holds the features
    # of the graphlet that was mapped to new_id == i, with node indices
    # offset so the graphlets can be concatenated into a single graph.
    for i, graphlet_id in enumerate(graphlet_ids_sorted_by_new_id):
        graphlet = id2graphlet[graphlet_id]
        xs.append(graphlet.x)
        edge_indices.append(graphlet.edge_index + i * graphlet.x.size(0))

    # Zero out estimates of graphlets that do not appear in the common file.
    if common_file is not None:
        common = np.loadtxt(str(common_file), dtype=np.int64)
        common = torch.tensor(
            list(map(lambda x: mapping.get(x, -100), common)))
        common = (batch_graphlet_indices == common.reshape(-1, 1)).any(0)
        data.estimates[~common] = 0

    # Sparse matrix where each row represents a graph and each column a
    # graphlet: m[graph][graphlet] == count of graphlet in graph.
    graph_has_graphlet = SparseTensor(row=data.batch,
                                      col=batch_graphlet_indices,
                                      value=data.estimates)

    if graph_has_graphlet.density() > .75:  # FIXME: update parameter if necessary
        graph_has_graphlet = graph_has_graphlet.to_dense()

    return Batch(x=torch.cat(xs, dim=0),
                 edge_index=torch.cat(edge_indices, dim=1),
                 graph_has_graphlet=graph_has_graphlet,
                 graphlet_ids=graphlet_ids_sorted_by_new_id,
                 y=data.y)
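
# Minimal sketch of the `renumber` contract assumed above (hypothetical
# implementation; the real helper lives elsewhere and also supports
# `in_place` and `preserve_zero`): map arbitrary ids to a dense range
# starting at `start`, returning the remapped array and the
# old-id -> new-id mapping.
def renumber_sketch(ids, start=0):
    mapping = {}
    out = np.empty_like(ids)
    for idx, old_id in enumerate(ids.flatten()):
        # Assign the next free new id on first sight of `old_id`.
        out.flat[idx] = mapping.setdefault(old_id, start + len(mapping))
    return out, mapping

# e.g. renumber_sketch(np.array([7, 3, 7, 9]))
# -> (array([0, 1, 0, 2]), {7: 0, 3: 1, 9: 2})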
def __init__(self,
             num_nodes: int,
             adj_t: torch_sparse.SparseTensor,
             num_heads: int,
             embedding_dim: int,
             num_sentinals: int,
             num_samples: int,
             k_nearest: int = 8,
             sentinal_dist: float = 1.,
             distance: str = 'euclidian',
             sample_weights: str = 'softmin',
             num_weight_layers: int = 3,
             hidden_weight_dim: int = 32,
             thresh_weight: int = 1):
    '''
    distance: how to compute the weighting of different samples
    sample_weights: how the weight for each sample is computed
    num_weight_layers: number of layers in the weight NN
    hidden_weight_dim: size of the hidden layers in the weight NN
    '''
    super(MADEdgePredictor, self).__init__()
    self.num_heads = num_heads
    self.num_nodes = num_nodes
    self.num_sentinals = num_sentinals
    self.embedding_dim = embedding_dim
    self.num_samples = num_samples
    self.k_nearest = k_nearest
    self.distance = distance
    self.sentinal_dist = sentinal_dist
    self.sample_weights = sample_weights
    self.thresh_weight = thresh_weight

    # Linear layer applied to the adjacency labels, initialized to identity.
    self.label_nn = nn.Linear(1, 1, bias=False)
    self.adj = adj_t.to_dense() * 2 - 1  # Scale entries from {0, 1} to {-1, 1}.
    # nn.init.xavier_normal(self.label_nn.weight)
    self.label_nn.weight = nn.Parameter(
        torch.ones_like(self.label_nn.weight))

    assert self.distance in ['euclidian', 'inner',
                             'dot'], 'Distance metric invalid'
    assert self.sample_weights in ['softmin',
                                   'attention'], 'Sample weight method invalid'

    # Weighted inner product x^T W x.
    if self.distance == 'inner':
        self.W = nn.Linear(self.embedding_dim, self.embedding_dim)

    if self.sample_weights == 'attention':
        self.atn = MADAttention(embedding_dim, hidden_weight_dim, 1,
                                num_weight_layers)
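
# Usage sketch (hypothetical sizes; the rest of MADEdgePredictor is defined
# elsewhere): instantiate the predictor over a small random adjacency.
adj = torch_sparse.SparseTensor.from_dense(
    (torch.rand(100, 100) > 0.9).float())
predictor = MADEdgePredictor(num_nodes=100, adj_t=adj, num_heads=4,
                             embedding_dim=64, num_sentinals=8,
                             num_samples=16)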
def test_to_symmetric(device):
    row = torch.tensor([0, 0, 0, 1, 1], device=device)
    col = torch.tensor([0, 1, 2, 0, 2], device=device)
    value = torch.arange(1, 6, device=device)

    mat = SparseTensor(row=row, col=col, value=value)
    assert not mat.is_symmetric()

    # to_symmetric() sums each entry with its transposed counterpart.
    mat = mat.to_symmetric()
    assert mat.is_symmetric()
    assert mat.to_dense().tolist() == [
        [2, 6, 3],
        [6, 0, 5],
        [3, 5, 0],
    ]
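
# Equivalent dense-side check (a sketch, assuming the default 'sum'
# reduction): for this matrix, to_symmetric() matches A + A^T computed
# densely, since no pair of mirrored entries needs any other reduction.
def test_to_symmetric_dense_equivalence(device):
    row = torch.tensor([0, 0, 0, 1, 1], device=device)
    col = torch.tensor([0, 1, 2, 0, 2], device=device)
    value = torch.arange(1, 6, device=device)

    mat = SparseTensor(row=row, col=col, value=value)
    dense = mat.to_dense()
    assert torch.equal(mat.to_symmetric().to_dense(), dense + dense.t())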
def process(self):
    # Build the (dense) node feature matrix from rows of
    # "node_id <tab> comma-separated feature ids <tab> label".
    with open(self.raw_paths[0], 'r') as f:
        data = [x.split('\t') for x in f.read().split('\n')[1:-1]]

    rows, cols = [], []
    for n_id, col, _ in data:
        col = [int(x) for x in col.split(',')]
        rows += [int(n_id)] * len(col)
        cols += col
    x = SparseTensor(row=torch.tensor(rows), col=torch.tensor(cols))
    x = x.to_dense()

    y = torch.empty(len(data), dtype=torch.long)
    for n_id, _, label in data:
        y[int(n_id)] = int(label)

    # Read the edge list and coalesce duplicate edges.
    with open(self.raw_paths[1], 'r') as f:
        data = f.read().split('\n')[1:-1]
        data = [[int(v) for v in r.split('\t')] for r in data]
    edge_index = torch.tensor(data, dtype=torch.long).t().contiguous()
    edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0))

    # Stack the split masks from the remaining raw files, one column per split.
    train_masks, val_masks, test_masks = [], [], []
    for f in self.raw_paths[2:]:
        tmp = np.load(f)
        train_masks += [torch.from_numpy(tmp['train_mask']).to(torch.bool)]
        val_masks += [torch.from_numpy(tmp['val_mask']).to(torch.bool)]
        test_masks += [torch.from_numpy(tmp['test_mask']).to(torch.bool)]
    train_mask = torch.stack(train_masks, dim=1)
    val_mask = torch.stack(val_masks, dim=1)
    test_mask = torch.stack(test_masks, dim=1)

    data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
                val_mask=val_mask, test_mask=test_mask)
    data = data if self.pre_transform is None else self.pre_transform(data)
    torch.save(self.collate([data]), self.processed_paths[0])
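
# Assumed raw file layout, inferred from the parsing above (header line
# skipped, tab-separated fields, comma-separated feature ids):
#
#   raw_paths[0]:                      raw_paths[1]:
#     node_id  features  label           source  target
#     0        3,17,42   1               0       1
#     1        5,17      0               1       2
#
# raw_paths[2:] are assumed to be .npz files holding boolean
# 'train_mask', 'val_mask', and 'test_mask' arrays, one file per split.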
def sample(self, batch):
    batch = torch.tensor(batch)
    row, col, _ = self.adj_t.coo()

    # For each doc node in `batch`, we sample a positive example (another
    # doc reached via a shared word) and a negative example (a random word
    # node).
    weighted_adj = SparseTensor(row=row, col=col,
                                value=self.edge_weights.cpu())
    connections = weighted_adj.to_dense().numpy()

    # Quick non-optimised implementation, since this can be run in parallel
    # anyway: a doc -> word -> doc random walk, where each hop is drawn
    # with probability proportional to the edge weights.
    friends = []
    for doc in batch:
        doc_number = doc.item()
        word_conns = connections[doc_number][self.num_doc_nodes:]

        mask = np.where(word_conns > 0)[0]
        if len(mask) == 0:
            # The doc has no word connections (this should not happen
            # often): fall back to a random doc.
            friends.append(np.random.randint(0, self.num_doc_nodes))
            continue

        # First hop: pick a word, weighted by connection strength.
        p = word_conns[mask] / word_conns[mask].sum()
        random_word = np.random.choice(mask, p=p) + self.num_doc_nodes

        # Second hop: walk back to another doc connected to that word.
        rwords_conns_to_docs = connections[random_word][:self.num_doc_nodes]
        mask = np.where(rwords_conns_to_docs > 0)[0]
        mask = mask[mask != doc_number]
        if len(mask) == 0:
            # The word is exclusive to this doc (this should not happen
            # often): fall back to a random doc.
            random_doc_back = np.random.randint(0, self.num_doc_nodes)
        else:
            p = rwords_conns_to_docs[mask] / rwords_conns_to_docs[mask].sum()
            random_doc_back = np.random.choice(mask, p=p)
        friends.append(random_doc_back)

    pos_batch = torch.tensor(friends, dtype=torch.long)
    neg_batch = torch.randint(self.num_doc_nodes, self.adj_t.size(1),
                              (batch.numel(), ), dtype=torch.long)

    batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
    return super(PosNegNeighborSampler, self).sample(batch)
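
# Downstream usage sketch (an assumed training step, not part of the
# sampler): since `sample` concatenates [batch | pos_batch | neg_batch],
# the encoder output can be split into three equal chunks and fed to a
# standard contrastive loss.
def unsup_loss_sketch(out):
    # `out` holds embeddings for the anchor, positive, and negative nodes.
    z_u, z_pos, z_neg = out.split(out.size(0) // 3, dim=0)
    pos_score = (z_u * z_pos).sum(-1)
    neg_score = (z_u * z_neg).sum(-1)
    return -torch.log(torch.sigmoid(pos_score) + 1e-15).mean() \
           - torch.log(1 - torch.sigmoid(neg_score) + 1e-15).mean()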