    def loss(self, data):
        # Compute node embeddings for the whole graph.
        x = self.forward(data.x, data.edge_index)
        device = x.device

        # Positive context: one random walk per node; drop column 0, which is
        # the start node itself.
        self.walk_res = random_walk(data.edge_index[0],
                                    data.edge_index[1],
                                    start=torch.arange(0,
                                                       x.shape[0]).to(device),
                                    walk_length=self.walk_length)[:, 1:]

        if not self.num_nodes:
            self.num_nodes = int(torch.max(data.edge_index)) + 1

        # Negative context: uniformly sampled random nodes.
        self.negative_samples = torch.from_numpy(
            np.random.choice(
                self.num_nodes,
                (self.num_nodes, self.num_negative_samples))).to(device)

        # Skip-gram style objective: pull nodes towards their walk context and
        # push them away from the random negatives.
        pos_loss = -torch.log(
            torch.sigmoid(
                torch.sum(x.unsqueeze(1).repeat(1, self.walk_length, 1) *
                          x[self.walk_res],
                          dim=-1))).mean()
        neg_loss = -torch.log(
            torch.sigmoid(-torch.sum(x.unsqueeze(1).repeat(
                1, self.num_negative_samples, 1) * x[self.negative_samples],
                                     dim=-1))).mean()
        return (pos_loss + neg_loss) / 2
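
# Usage sketch: these examples rely on torch_cluster.random_walk(row, col,
# start, walk_length), which returns a [num_starts, walk_length + 1] tensor of
# node indices whose first column equals `start`. A minimal, self-contained
# illustration on a toy graph (assumes torch_cluster is installed):
import torch
from torch_cluster import random_walk

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])  # undirected path 0 - 1 - 2
start = torch.arange(3)
walks = random_walk(edge_index[0], edge_index[1], start, walk_length=4)
print(walks.shape)  # torch.Size([3, 5]); walks[:, 0] equals start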
Example #2
def get_walks(training, batch, output, GRAPH_SIZES, SOURCE_NODES, SINK_NODES):
    if training:
        return random_walk(batch.edge_index[0],
                           batch.edge_index[1],
                           SINK_NODES,
                           walk_length=get_hyperparameters()["walk_length"],
                           coalesced=True).long(), None
    path_matrix, _, mask = get_walks_from_output(output, GRAPH_SIZES,
                                                 SOURCE_NODES, SINK_NODES)
    return path_matrix, mask
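
# Note: `get_hyperparameters` and `get_walks_from_output` are project-specific
# helpers that are not shown here. At training time the walks are sampled with
# torch_cluster's random_walk starting from the sink nodes; at evaluation time
# the path matrix and mask are reconstructed from the model output instead.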
    def __random_walk__(self, edge_index, subset=None):
        if subset is None:
            subset = torch.arange(self.num_nodes, device=edge_index.device)
        subset = subset.repeat(self.walks_per_node)

        rw = random_walk(edge_index[0], edge_index[1], subset,
                         self.walk_length, self.p, self.q, self.num_nodes)

        walks = []
        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size
        for j in range(num_walks_per_rw):
            walks.append(rw[:, j:j + self.context_size])
        return torch.cat(walks, dim=0)
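
# Worked example of the window count used above: a walk with `walk_length`
# steps contains walk_length + 1 nodes, so a sliding window of `context_size`
# nodes yields (walk_length + 1) - context_size + 1 windows, which is exactly
# `1 + self.walk_length + 1 - self.context_size`. A small self-contained check
# with illustrative values:
import torch

walk_length, context_size, walks_per_node = 10, 5, 8
rw = torch.zeros(walks_per_node, walk_length + 1, dtype=torch.long)  # dummy walks
windows = [rw[:, j:j + context_size]
           for j in range(walk_length + 2 - context_size)]
assert torch.cat(windows, dim=0).shape == (walks_per_node * 7, context_size)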
Example #4
def test_rw(device):
    row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
    col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
    start = tensor([0, 1, 2, 3, 4], torch.long, device)
    walk_length = 10

    out = random_walk(row, col, start, walk_length)
    assert out[:, 0].tolist() == start.tolist()

    for n in range(start.size(0)):
        cur = start[n].item()
        for i in range(1, walk_length):
            assert out[n, i].item() in col[row == cur].tolist()
            cur = out[n, i].item()

    row = tensor([0, 1], torch.long, device)
    col = tensor([1, 0], torch.long, device)
    start = tensor([0, 1, 2], torch.long, device)
    walk_length = 4

    out = random_walk(row, col, start, walk_length, num_nodes=3)
    assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]
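
# test_rw calls a small `tensor(data, dtype, device)` helper from the
# torch_cluster test suite instead of torch.tensor directly. A minimal
# stand-in (an assumption, kept only to make the tests runnable as shown):
import torch

def tensor(data, dtype, device):
    return torch.tensor(data, dtype=dtype, device=device)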
    def sample(self, batch):
        batch = torch.tensor(batch)
        row, col, _ = self.adj_t.coo()

        # For each node in `batch`, we sample a direct neighbor (as positive
        # example) and a random node (as negative example):
        pos_batch = random_walk(row, col, batch, walk_length=1,
                                coalesced=False)[:, 1]

        neg_batch = torch.randint(0, self.adj_t.size(1), (batch.numel(), ),
                                  dtype=torch.long)

        batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
        return super(NeighborSampler, self).sample(batch)
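
# Design note: each anchor node in `batch` gets one positive example (the
# first step of a length-1 random walk, i.e. a uniformly chosen neighbour) and
# one negative example (a uniformly random node). Concatenating
# [batch, pos_batch, neg_batch] lets the parent NeighborSampler fetch
# neighbourhoods for all three groups in one pass, so the caller can split the
# resulting embeddings into three equal chunks afterwards.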
Example #6
def test_rw(device):
    row = tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
    col = tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], torch.long, device)
    start = tensor([0, 1, 2, 3, 4], torch.long, device)
    walk_length = 10

    out = random_walk(row, col, start, walk_length, coalesced=True)
    assert out[:, 0].tolist() == start.tolist()

    for n in range(start.size(0)):
        cur = start[n].item()
        for l in range(1, walk_length):
            assert out[n, l].item() in col[row == cur].tolist()
            cur = out[n, l].item()
Example #7
    def sample(self, batch):
        new_batch = torch.tensor(batch)
        row, col, _ = self.adj_t.coo()

        pos_batch = random_walk(row,
                                col,
                                new_batch,
                                walk_length=1,
                                coalesced=False)[:, 1]

        neg_batch = torch.randint(0,
                                  self.adj_t.size(1), (new_batch.numel(), ),
                                  dtype=torch.long)

        new_batch = torch.cat([new_batch, pos_batch, neg_batch], dim=0)

        return super().sample(new_batch)
    def extract(
        self,
        data: Data,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        from torch_cluster import random_walk

        start = torch.arange(data.num_nodes, device=data.edge_index.device)
        start = start.view(-1, 1).repeat(1, self.repeat).view(-1)
        walk = random_walk(data.edge_index[0],
                           data.edge_index[1],
                           start,
                           self.walk_length,
                           num_nodes=data.num_nodes)

        n_mask = torch.zeros((data.num_nodes, data.num_nodes),
                             dtype=torch.bool,
                             device=walk.device)
        start = start.view(-1, 1).repeat(1, (self.walk_length + 1)).view(-1)
        n_mask[start, walk.view(-1)] = True

        return self.map(data, n_mask)
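
# Toy illustration of the mask construction above (assumes torch_cluster is
# installed): row i of n_mask ends up True at every node visited by a walk
# that started from node i.
import torch
from torch_cluster import random_walk

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
num_nodes, repeat, walk_length = 3, 2, 3
start = torch.arange(num_nodes).view(-1, 1).repeat(1, repeat).view(-1)
walk = random_walk(edge_index[0], edge_index[1], start, walk_length,
                   num_nodes=num_nodes)
n_mask = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
root = start.view(-1, 1).repeat(1, walk_length + 1).view(-1)
n_mask[root, walk.view(-1)] = True
print(n_mask)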
def train(dataset, device):
    for graph in dataset:
        x, edge_index, y = graph.x, graph.edge_index, graph.y
        edge_index = edge_index.to(device)
        x = x.to(device)
        y = y.to(device)
        subset = torch.arange(x.size(0), device=edge_index.device)
        walks = random_walk(edge_index[0], edge_index[1], subset, walk_length,
                            p, q, x.size(0))

        nan_check_and_break(walks, "walks nan check")

    return
def validate(model, dataset, device):
    model.eval()
    batch_acc = 0

    for graph in dataset:
        x, edge_index, y = graph.x, graph.edge_index, graph.y
        edge_index = edge_index.to(device)
        x = x.to(device)
        y = y.to(device)
        subset = torch.arange(x.size(0), device=edge_index.device)
        walks = random_walk(edge_index[0], edge_index[1], subset, walk_length,
                            p, q, x.size(0))
        # model forward
        y = y.repeat(walks.size(0), )
        y_pred = model(walks)

        pred = y_pred.max(dim=1)[1]
        correct = pred.eq(y).sum().item()
        correct /= y.size(0)
        batch_acc += (correct * 100)

    return batch_acc / len(dataset)
def train(model, dataset, opt, sch, loss_func, device):
    model.train()
    batch_loss = 0
    for graph in dataset:
        x, edge_index, y = graph.x, graph.edge_index, graph.y
        edge_index = edge_index.to(device)
        x = x.to(device)
        y = y.to(device)
        subset = torch.arange(x.size(0), device=edge_index.device)

        walks = random_walk(edge_index[0], edge_index[1], subset, walk_length,
                            p, q, x.size(0))
        if (walks.size(0) == 0):
            print('zero sized walks')

        nan_check_and_break(walks, "walks nan check")

        # model forward
        y = y.repeat(walks.size(0), )

        y_pred = model(walks)

        nan_check_and_break(y_pred, "y_pred")

        loss = loss_func(y_pred, y)

        print(loss.item())

        batch_loss += loss.item() / walks.size(0)

        opt.zero_grad()
        loss.backward()

        opt.step()
        sch.step()

    return batch_loss / len(dataset)
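
# The train/validate helpers above reference `walk_length`, `p`, `q` and
# `nan_check_and_break` as module-level names that are not shown here. A
# hedged sketch of that setup, with purely illustrative values:
import torch

walk_length, p, q = 20, 1.0, 1.0  # node2vec-style return / in-out parameters

def nan_check_and_break(t, name):
    # Illustrative stand-in: fail fast if the tensor contains NaNs.
    if torch.isnan(t.float()).any():
        raise ValueError(f"NaN found in {name}")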
Example #12
def train(epoch):
    model.train()

    pbar = tqdm(total=data.num_nodes)
    pbar.set_description(f'Epoch {epoch:02d}')

    total_loss = 0

    node_idx = torch.randperm(data.num_nodes)  # shuffle all nodes
    train_loader = NeighborSampler(
        data.edge_index, node_idx=node_idx,
        sizes=num_samples, num_nodes=data.num_nodes,
        batch_size=batch_size, shuffle=False,
        num_workers=0)
    rw = random_walk(
        data.edge_index[0], data.edge_index[1],
        node_idx, walk_length=walk_length,
        num_nodes=data.num_nodes,
        coalesced=True)
    rw_idx = rw[:, 1:].flatten()
    # hack to bring out of bounds values in bounds. There is something wrong
    # with random_walk. No clue what effect this has on the model.
    rw_idx = torch.where(rw_idx > data.num_nodes - 1, rw_idx % data.num_nodes,
                         rw_idx)
    pos_loader = NeighborSampler(
        data.edge_index, node_idx=rw_idx,
        sizes=num_samples, num_nodes=data.num_nodes,
        batch_size=batch_size * walk_length, shuffle=False,
        num_workers=0)

    neg_loader = NeighborSampler(
        data.edge_index, node_idx=node_idx,
        sizes=num_samples, num_nodes=data.num_nodes,
        batch_size=batch_size, shuffle=True,
        num_workers=0)

    for (batch_size_, u_id, adjs_u), (_, v_id, adjs_v), (_, vn_id, adjs_vn) in\
            zip(train_loader, pos_loader, neg_loader):
        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
        adjs_u = [adj.to(device) for adj in adjs_u]
        z_u = model(x[u_id], adjs_u)

        adjs_v = [adj.to(device) for adj in adjs_v]
        z_v = model(x[v_id], adjs_v)

        adjs_vn = [adj.to(device) for adj in adjs_vn]
        z_vn = model(x[vn_id], adjs_vn)

        optimizer.zero_grad()
        pos_loss = -F.logsigmoid(
            (z_u.repeat_interleave(walk_length, dim=0)*z_v).sum(dim=1)).mean()
        neg_loss = -F.logsigmoid(
            -(z_u*z_vn).sum(dim=1)).mean()
        loss = pos_loss + neg_loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.update(batch_size_)

    pbar.close()

    loss = total_loss / len(train_loader)

    return loss
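
# Design note: the three loaders are zipped so that an anchor batch of B nodes
# from `train_loader` lines up with the B * walk_length positive nodes from
# `pos_loader` (which is why pos_loader uses batch_size * walk_length) and
# with B negative nodes from `neg_loader`; repeat_interleave(walk_length,
# dim=0) expands the anchor embeddings so the element-wise product with the
# positives is taken row by row.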
    def subsample_walk(self, sample_args=None):
        r"""Takes a subsample of the graph by performing random walks at
        various randomly selected vertices of particular lengths.

        Args:
            sample_args (dict) : Dictionary with the following keys:
                start_num (int) : Number of randomly selected nodes at which
                    to start random walks.
                walk_length (int) : Length of each random walk.

        Returns:
            subsample_dict (dict) : A dictionary with the following three keys:
                'node_ind' : The indices of the nodes which are part of the
                    subsampled graph.
                'edge_ind' : The indices of the edges which are part of the
                    subsampled graph, with respect to the ordering given in
                    self.edge_index and self.edge_weight.
                'edge_list' : The edges of the subsampled graph, stored as
                    node-index pairs in a Tensor of shape
                    [2, num of subsampled edges].
        """
        # Extract values from sample_dict
        start_num = sample_args['start_num']
        walk_length = sample_args['walk_length']

        # Choose random starting locations
        start_loc = torch.as_tensor(np.random.choice(
            self.nodes_pos_degree_mask.numpy(), size=start_num, replace=False),
                                    dtype=torch.long)

        # Perform random walks; note that walks has shape
        # [start_num, walk_length + 1]
        walks = random_walk(row=self.edge_index[0, self.edges_pos_degree_mask],
                            col=self.edge_index[1, self.edges_pos_degree_mask],
                            start=start_loc,
                            num_nodes=self.num_nodes,
                            walk_length=walk_length,
                            coalesced=True)

        # Convert into edge list
        edge_list = torch.zeros(size=(2, start_num * walk_length),
                                dtype=torch.long)

        for i in range(start_num):
            ind = i * walk_length + np.arange(walk_length)
            ind = list(ind)
            edge_list[0, ind] = walks[i, :-1]
            edge_list[1, ind] = walks[i, 1:]

        # Get rid of duplicates in edge list
        edge_list = torch.unique(edge_list, dim=1)

        # Extract nodes and the indices of edge_list with respect to
        # self.edge_index
        node_ind = torch.unique(edge_list)
        edge_ind = torch.zeros(size=[edge_list.shape[1]], dtype=torch.long)

        for i in range(edge_list.shape[1]):
            edge_ind[i] = torch.where(
                (self.edge_index[0, :] == edge_list[0, i])
                & (self.edge_index[1, :] == edge_list[1, i]))[0][0]

        subsample_dict = {
            'node_ind': node_ind,
            'edge_ind': edge_ind,
            'edge_list': edge_list
        }

        return subsample_dict
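
# Usage sketch (`graph` is a hypothetical instance of the class this method
# belongs to):
# sub = graph.subsample_walk(sample_args={'start_num': 32, 'walk_length': 8})
# sub['node_ind']    # node indices appearing in the subsample
# sub['edge_list']   # Tensor of shape [2, num of subsampled edges]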
Example #14
    def sample(self, batch):
        batch = torch.tensor(batch)
        row, col, _ = self.adj_t.coo()

        # Walk-based positive examples (one neighbour per node); note that
        # `pos_batch` is overwritten below by the document-word sampling.
        pos_batch = random_walk(row, col, batch, walk_length=2,
                                coalesced=False)[:, 1]

        # Dense weighted adjacency, used to look up document-word connections.
        weighted_adj = SparseTensor(row=row, col=col,
                                    value=self.edge_weights.cpu())

        connections = weighted_adj.to_dense().numpy()


        # Quick, non-optimised implementation; this loop could run in parallel.
        friends = []
        for doc in batch:
            doc_number = doc.numpy()

            word_conns = connections[doc_number][self.num_doc_nodes:]

            # print("doc #:", doc_number)
            # print("connections:", word_conns, "(%i in total)" % sum(word_conns))

            # these where's can now be removed now we use p I know ok
            mask = np.where(word_conns > 0)[0]
            if len(mask) == 0:
                # The document has no word connections; fall back to a random
                # document. This should rarely happen.
                random_doc_back = np.random.randint(0, self.num_doc_nodes)
                friends.append(random_doc_back)
                continue

            # Sample a connected word, weighted by edge weight.
            p = word_conns[mask] / sum(word_conns[mask])
            random_word = np.random.choice(mask, p=p) + self.num_doc_nodes

            rwords_conns = connections[random_word]

            rwords_conns_to_docs = rwords_conns[:self.num_doc_nodes]

            # Documents (other than the current one) connected to that word.
            mask = np.where(rwords_conns_to_docs > 0)[0]
            mask = mask[mask != doc_number]

            if len(mask) == 0:
                # The word is exclusive to this document; fall back to a
                # random document. This should rarely happen.
                random_doc_back = np.random.randint(0, self.num_doc_nodes)
            else:
                p = rwords_conns_to_docs[mask] / sum(rwords_conns_to_docs[mask])
                random_doc_back = np.random.choice(mask, p=p)

            friends.append(random_word)

        pos_batch = torch.tensor(friends, dtype=torch.long)

        neg_batch = torch.randint(self.num_doc_nodes, self.adj_t.size(1),
                                  (batch.numel(), ), dtype=torch.long)

        batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
        return super(PosNegNeighborSampler, self).sample(batch)
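
# Design note: for each document node, the positive example appended to
# `friends` is a word node sampled from the document's word connections,
# weighted by edge weight (with a random-document fallback when a document has
# no word connections); negatives are drawn uniformly from the word-node index
# range [num_doc_nodes, adj_t.size(1)).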