def loss(self, data):
    x = self.forward(data.x, data.edge_index)
    device = x.device

    # Lazily sample and cache the positive contexts via random walks
    # (the first column is the start node itself, so it is dropped):
    if self.walk_res is None:
        self.walk_res = random_walk(
            data.edge_index[0], data.edge_index[1],
            start=torch.arange(0, x.shape[0]).to(device),
            walk_length=self.walk_length)[:, 1:]

    if not self.num_nodes:
        self.num_nodes = int(torch.max(data.edge_index)) + 1

    # Lazily draw and cache the negative samples:
    if self.negative_samples is None:
        self.negative_samples = torch.from_numpy(
            np.random.choice(self.num_nodes,
                             (self.num_nodes, self.num_negative_samples))
        ).to(device)

    pos_loss = -torch.log(
        torch.sigmoid(
            torch.sum(x.unsqueeze(1).repeat(1, self.walk_length, 1)
                      * x[self.walk_res], dim=-1))).mean()

    neg_loss = -torch.log(
        torch.sigmoid(
            -torch.sum(x.unsqueeze(1).repeat(1, self.num_negative_samples, 1)
                       * x[self.negative_samples], dim=-1))).mean()

    return (pos_loss + neg_loss) / 2
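# --- Hedged shape sketch (standalone toy values, not from the original
# class): the positive term in `loss` above pairs each node with every node
# on its cached walk. With N nodes and embedding size d, both factors
# broadcast to [N, walk_length, d], so the inner sum yields one logit per
# (node, walk step) pair.
import torch

N, d, walk_length = 5, 8, 3
x = torch.randn(N, d)
walk_res = torch.randint(0, N, (N, walk_length))  # stand-in for cached walks
logits = torch.sum(x.unsqueeze(1).repeat(1, walk_length, 1) * x[walk_res],
                   dim=-1)
assert logits.shape == (N, walk_length)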
def get_walks(training, batch, output, GRAPH_SIZES, SOURCE_NODES, SINK_NODES):
    if training:
        return random_walk(
            batch.edge_index[0], batch.edge_index[1], SINK_NODES,
            walk_length=get_hyperparameters()["walk_length"],
            coalesced=True).long(), None
    path_matrix, _, mask = get_walks_from_output(output, GRAPH_SIZES,
                                                 SOURCE_NODES, SINK_NODES)
    return path_matrix, mask
def __random_walk__(self, edge_index, subset=None):
    if subset is None:
        subset = torch.arange(self.num_nodes, device=edge_index.device)
    subset = subset.repeat(self.walks_per_node)

    rw = random_walk(edge_index[0], edge_index[1], subset, self.walk_length,
                     self.p, self.q, self.num_nodes)

    walks = []
    num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
    for j in range(num_walks_per_rw):
        walks.append(rw[:, j:j + self.context_size])
    return torch.cat(walks, dim=0)
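# --- Hedged toy sketch (assumed values, not the original class) of the
# sliding-window slicing in `__random_walk__` above: each walk of length
# `walk_length` yields `1 + walk_length + 1 - context_size` overlapping
# windows of `context_size` nodes.
import torch

walk_length, context_size = 4, 3
rw = torch.arange(2 * (walk_length + 1)).view(2, walk_length + 1)  # two fake walks
num_walks_per_rw = 1 + walk_length + 1 - context_size
windows = torch.cat([rw[:, j:j + context_size]
                     for j in range(num_walks_per_rw)], dim=0)
assert windows.shape == (2 * num_walks_per_rw, context_size)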
def test_rw(device):
    row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
    col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
    start = tensor([0, 1, 2, 3, 4], torch.long, device)
    walk_length = 10

    out = random_walk(row, col, start, walk_length)
    assert out[:, 0].tolist() == start.tolist()

    for n in range(start.size(0)):
        cur = start[n].item()
        for i in range(1, walk_length):
            assert out[n, i].item() in col[row == cur].tolist()
            cur = out[n, i].item()

    row = tensor([0, 1], torch.long, device)
    col = tensor([1, 0], torch.long, device)
    start = tensor([0, 1, 2], torch.long, device)
    walk_length = 4

    out = random_walk(row, col, start, walk_length, num_nodes=3)
    assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]
def sample(self, batch):
    batch = torch.tensor(batch)
    row, col, _ = self.adj_t.coo()

    # For each node in `batch`, we sample a direct neighbor (as positive
    # example) and a random node (as negative example):
    pos_batch = random_walk(row, col, batch, walk_length=1,
                            coalesced=False)[:, 1]
    neg_batch = torch.randint(0, self.adj_t.size(1), (batch.numel(), ),
                              dtype=torch.long)

    batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
    return super(NeighborSampler, self).sample(batch)
def test_rw(device):
    row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
    col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
    start = tensor([0, 1, 2, 3, 4], torch.long, device)
    walk_length = 10

    out = random_walk(row, col, start, walk_length, coalesced=True)
    assert out[:, 0].tolist() == start.tolist()

    for n in range(start.size(0)):
        cur = start[n].item()
        for step in range(1, walk_length):
            assert out[n, step].item() in col[row == cur].tolist()
            cur = out[n, step].item()
def sample(self, batch):
    new_batch = torch.tensor(batch)
    row, col, _ = self.adj_t.coo()

    pos_batch = random_walk(row, col, new_batch, walk_length=1,
                            coalesced=False)[:, 1]
    neg_batch = torch.randint(0, self.adj_t.size(1), (new_batch.numel(), ),
                              dtype=torch.long)

    new_batch = torch.cat([new_batch, pos_batch, neg_batch], dim=0)
    return super().sample(new_batch)
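# --- Hedged usage sketch for the two `sample` overrides above, assuming they
# live on a subclass of torch_geometric.data.NeighborSampler (names below are
# placeholders, not from the original code):
#
#   loader = PosNegNeighborSampler(data.edge_index, sizes=[10, 10],
#                                  batch_size=256, shuffle=True,
#                                  num_nodes=data.num_nodes)
#   for batch_size, n_id, adjs in loader:
#       # Seed nodes arrive as [anchors | positives | negatives], so the
#       # model output can be split into three equal chunks for the loss.
#       ...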
def extract(
    self,
    data: Data,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    from torch_cluster import random_walk

    start = torch.arange(data.num_nodes, device=data.edge_index.device)
    start = start.view(-1, 1).repeat(1, self.repeat).view(-1)
    walk = random_walk(data.edge_index[0], data.edge_index[1], start,
                       self.walk_length, num_nodes=data.num_nodes)

    n_mask = torch.zeros((data.num_nodes, data.num_nodes),
                         dtype=torch.bool, device=walk.device)
    start = start.view(-1, 1).repeat(1, (self.walk_length + 1)).view(-1)
    n_mask[start, walk.view(-1)] = True

    return self.map(data, n_mask)
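# --- Hedged toy sketch (fake deterministic walks instead of a real
# random_walk call) of the mask construction in `extract` above: after
# repeating each start node (walk_length + 1) times, `start` and
# `walk.view(-1)` align elementwise, so n_mask[i, j] is True iff some walk
# started at i visited j.
import torch

num_nodes, repeat, walk_length = 4, 2, 2
start = torch.arange(num_nodes).view(-1, 1).repeat(1, repeat).view(-1)
walk = torch.stack([start, (start + 1) % num_nodes,
                    (start + 2) % num_nodes], dim=1)  # fake walks
n_mask = torch.zeros((num_nodes, num_nodes), dtype=torch.bool)
start = start.view(-1, 1).repeat(1, walk_length + 1).view(-1)
n_mask[start, walk.view(-1)] = True
assert n_mask[0].tolist() == [True, True, True, False]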
def train(dataset, device):
    for graph in dataset:
        x, edge_index, y = graph.x, graph.edge_index, graph.y
        edge_index = edge_index.to(device)
        x = x.to(device)
        y = y.to(device)

        subset = torch.arange(x.size(0), device=edge_index.device)
        # subset = subset[:np.min((50, x.size(0)))]  # optional cap on walks
        walks = random_walk(edge_index[0], edge_index[1], subset,
                            walk_length, p, q, x.size(0))
        nan_check_and_break(walks, "walks nan check")
    return
def validate(model, dataset, device):
    model.eval()
    batch_acc = 0
    for graph in dataset:
        x, edge_index, y = graph.x, graph.edge_index, graph.y
        edge_index = edge_index.to(device)
        x = x.to(device)
        y = y.to(device)

        subset = torch.arange(x.size(0), device=edge_index.device)
        walks = random_walk(edge_index[0], edge_index[1], subset,
                            walk_length, p, q, x.size(0))

        # model forward
        y = y.repeat(walks.size(0))
        y_pred = model(walks)
        pred = y_pred.max(dim=1)[1]
        correct = pred.eq(y).sum().item()
        correct /= y.size(0)
        batch_acc += correct * 100

    return batch_acc / len(dataset)
def train(model, dataset, opt, sch, loss_func, device):
    model.train()
    batch_loss = 0
    for graph in dataset:
        x, edge_index, y = graph.x, graph.edge_index, graph.y
        edge_index = edge_index.to(device)
        x = x.to(device)
        y = y.to(device)

        subset = torch.arange(x.size(0), device=edge_index.device)
        walks = random_walk(edge_index[0], edge_index[1], subset,
                            walk_length, p, q, x.size(0))
        if walks.size(0) == 0:
            print('zero sized walks')
        nan_check_and_break(walks, "walks nan check")

        # model forward
        y = y.repeat(walks.size(0))
        y_pred = model(walks)
        nan_check_and_break(y_pred, "y_pred")

        loss = loss_func(y_pred, y)
        print(loss.item())
        batch_loss += loss.item() / walks.size(0)

        opt.zero_grad()
        loss.backward()
        opt.step()
        sch.step()

    return batch_loss / len(dataset)
def train(epoch):
    model.train()

    pbar = tqdm(total=data.num_nodes)
    pbar.set_description(f'Epoch {epoch:02d}')

    total_loss = 0
    node_idx = torch.randperm(data.num_nodes)  # shuffle all nodes
    train_loader = NeighborSampler(
        data.edge_index, node_idx=node_idx, sizes=num_samples,
        num_nodes=data.num_nodes, batch_size=batch_size, shuffle=False,
        num_workers=0)
    rw = random_walk(
        data.edge_index[0], data.edge_index[1], node_idx,
        walk_length=walk_length, num_nodes=data.num_nodes, coalesced=True)
    rw_idx = rw[:, 1:].flatten()
    # Hack to bring out-of-bounds values back in bounds. There is something
    # wrong with random_walk here; it is unclear what effect this has on the
    # model.
    rw_idx = torch.where(rw_idx > data.num_nodes - 1,
                         rw_idx % data.num_nodes, rw_idx)
    pos_loader = NeighborSampler(
        data.edge_index, node_idx=rw_idx, sizes=num_samples,
        num_nodes=data.num_nodes, batch_size=batch_size * walk_length,
        shuffle=False, num_workers=0)
    neg_loader = NeighborSampler(
        data.edge_index, node_idx=node_idx, sizes=num_samples,
        num_nodes=data.num_nodes, batch_size=batch_size, shuffle=True,
        num_workers=0)

    for (batch_size_, u_id, adjs_u), (_, v_id, adjs_v), (_, vn_id, adjs_vn) in \
            zip(train_loader, pos_loader, neg_loader):
        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
        adjs_u = [adj.to(device) for adj in adjs_u]
        z_u = model(x[u_id], adjs_u)
        adjs_v = [adj.to(device) for adj in adjs_v]
        z_v = model(x[v_id], adjs_v)
        adjs_vn = [adj.to(device) for adj in adjs_vn]
        z_vn = model(x[vn_id], adjs_vn)

        optimizer.zero_grad()
        pos_loss = -F.logsigmoid(
            (z_u.repeat_interleave(walk_length, dim=0) * z_v).sum(dim=1)).mean()
        neg_loss = -F.logsigmoid(-(z_u * z_vn).sum(dim=1)).mean()
        loss = pos_loss + neg_loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.update(batch_size_)

    pbar.close()

    loss = total_loss / len(train_loader)
    return loss
def subsample_walk(self, sample_args=None):
    r"""Takes a subsample of the graph by performing random walks of a
    particular length, started at randomly selected vertices.

    Args:
        sample_args (dict): Dictionary with the following keys/items:
            start_num (int): The number of randomly selected nodes at
                which to begin random walks.
            walk_length (int): The length of each random walk to perform.

    Returns:
        subsample_dict (dict): A dictionary with the following three keys:
            'node_ind': The indices of the nodes which are part of the
                subsampled graph.
            'edge_ind': The indices of the edges which are part of the
                subsampled graph, with respect to the ordering given in
                self.edge_index and self.edge_weight.
            'edge_list': Pairs of edges which are part of the subsampled
                graph, stored as a Tensor of shape
                [2, num of subsampled edges].
    """
    # Extract values from sample_args
    start_num = sample_args['start_num']
    walk_length = sample_args['walk_length']

    # Choose random starting locations
    start_loc = torch.as_tensor(
        np.random.choice(self.nodes_pos_degree_mask.numpy(),
                         size=start_num, replace=False),
        dtype=torch.long)

    # Perform random walks; note that `walks` has shape
    # [start_num, walk_length + 1]
    walks = random_walk(row=self.edge_index[0, self.edges_pos_degree_mask],
                        col=self.edge_index[1, self.edges_pos_degree_mask],
                        start=start_loc,
                        num_nodes=self.num_nodes,
                        walk_length=walk_length,
                        coalesced=True)

    # Convert walks into an edge list
    edge_list = torch.zeros(size=(2, start_num * walk_length),
                            dtype=torch.long)
    for i in range(start_num):
        ind = list(i * walk_length + np.arange(walk_length))
        edge_list[0, ind] = walks[i, :-1]
        edge_list[1, ind] = walks[i, 1:]

    # Get rid of duplicates in the edge list
    edge_list = torch.unique(edge_list, dim=1)

    # Extract nodes and the indices of edge_list with respect to
    # self.edge_index
    node_ind = torch.unique(edge_list)
    edge_ind = torch.zeros(size=[edge_list.shape[1]], dtype=torch.long)
    for i in range(edge_list.shape[1]):
        edge_ind[i] = torch.where(
            (self.edge_index[0, :] == edge_list[0, i])
            & (self.edge_index[1, :] == edge_list[1, i]))[0][0]

    subsample_dict = {
        'node_ind': node_ind,
        'edge_ind': edge_ind,
        'edge_list': edge_list,
    }
    return subsample_dict
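# --- Standalone hedged sketch of the walks -> edge-list conversion used in
# `subsample_walk` above (toy `walks` tensor instead of a real random_walk
# result):
import torch

walks = torch.tensor([[0, 1, 2], [3, 1, 0]])  # [start_num, walk_length + 1]
start_num, walk_length = walks.size(0), walks.size(1) - 1
edge_list = torch.zeros((2, start_num * walk_length), dtype=torch.long)
for i in range(start_num):
    ind = list(range(i * walk_length, (i + 1) * walk_length))
    edge_list[0, ind] = walks[i, :-1]
    edge_list[1, ind] = walks[i, 1:]
edge_list = torch.unique(edge_list, dim=1)
# Consecutive walk nodes become directed edges:
# (0, 1), (1, 2) from the first walk and (3, 1), (1, 0) from the second.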
def sample(self, batch):
    batch = torch.tensor(batch)
    row, col, _ = self.adj_t.coo()

    # For each node in `batch`, we sample a direct neighbor (as positive
    # example) and a random node (as negative example). Note that this
    # random-walk positive is overwritten below by the weighted
    # doc -> word -> doc sampling:
    pos_batch = random_walk(row, col, batch, walk_length=2,
                            coalesced=False)[:, 1]

    weighted_adj = SparseTensor(row=row, col=col,
                                value=self.edge_weights.cpu())
    connections = weighted_adj.to_dense().numpy()

    # Quick non-optimised implementation, since this can be run in
    # parallel anyway:
    friends = []
    for doc in batch:
        doc_number = doc.numpy()
        word_conns = connections[doc_number][self.num_doc_nodes:]
        mask = np.where(word_conns > 0)[0]
        if len(mask) == 0:
            # The doc has no word connections: fall back to a random doc.
            # This should not happen often.
            friends.append(np.random.randint(0, self.num_doc_nodes))
            continue
        p = word_conns[mask] / sum(word_conns[mask])
        random_word = np.random.choice(mask, p=p) + self.num_doc_nodes

        rwords_conns = connections[random_word]
        rwords_conns_to_docs = rwords_conns[:self.num_doc_nodes]
        mask = np.where(rwords_conns_to_docs > 0)[0]
        mask = mask[mask != doc_number]
        if len(mask) == 0:
            # The word is exclusive to this doc: fall back to a random doc.
            # This should not happen often.
            random_doc_back = np.random.randint(0, self.num_doc_nodes)
        else:
            p = rwords_conns_to_docs[mask] / sum(rwords_conns_to_docs[mask])
            random_doc_back = np.random.choice(mask, p=p)
        # Append the chosen doc (not the intermediate word) as the positive:
        friends.append(random_doc_back)

    pos_batch = torch.tensor(friends, dtype=torch.long)
    neg_batch = torch.randint(self.num_doc_nodes, self.adj_t.size(1),
                              (batch.numel(), ), dtype=torch.long)

    batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
    return super(PosNegNeighborSampler, self).sample(batch)