Example #1
File: attconv.py Project: xnhp/GraphGym
    def norm(edge_index,
             num_nodes,
             edge_weight=None,
             improved=False,
             dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)

        fill_value = 1 if not improved else 2
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
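A minimal usage sketch of the helper above (hypothetical toy graph; it assumes the same imports the snippet relies on, torch, torch_scatter.scatter_add and torch_geometric.utils.add_remaining_self_loops, and treats norm as a standalone function rather than a method):

import torch

# path graph 0-1-2 in COO form
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
ei, w = norm(edge_index, num_nodes=3)
# ei now carries three appended self-loops; w holds the symmetric GCN
# normalization D^-1/2 (A + I) D^-1/2 for every (loop-augmented) edge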
Example #2
File: DGCN.py Project: matthew-hirn/magnet
def gcn_norm(edge_index,
             edge_weight=None,
             num_nodes=None,
             improved=False,
             add_self_loops=True,
             dtype=None):

    fill_value = 2. if improved else 1.

    if isinstance(edge_index, SparseTensor):
        adj_t = edge_index
        if not adj_t.has_value():
            adj_t = adj_t.fill_value(1., dtype=dtype)
        if add_self_loops:
            adj_t = fill_diag(adj_t, fill_value)
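        # `sum` here must be torch_sparse.sum (sparse row reduction); the
        # Python builtin sum() has no `dim` argument and would fail here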
        deg = sum(adj_t, dim=1)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
        adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
        adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
        return adj_t

    else:
        num_nodes = maybe_num_nodes(edge_index, num_nodes)

        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)

        if add_self_loops:
            edge_index, tmp_edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value, num_nodes)
            assert tmp_edge_weight is not None
            edge_weight = tmp_edge_weight

        row, col = edge_index[0], edge_index[1]
        deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
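A hedged usage sketch of gcn_norm above: the plain-tensor path returns an (edge_index, edge_weight) pair, while the SparseTensor path returns the normalized adjacency itself (assuming torch plus the torch_sparse helpers the snippet already imports):

import torch
from torch_sparse import SparseTensor

edge_index = torch.tensor([[0, 1], [1, 0]])
ei, ew = gcn_norm(edge_index, num_nodes=2)            # tensor path: (indices, weights)
adj_t = SparseTensor.from_edge_index(edge_index, sparse_sizes=(2, 2))
adj_norm = gcn_norm(adj_t)                            # sparse path: SparseTensor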
Example #3
def compute_identity(edge_index, n, k):
    edge_weight = torch.ones((edge_index.size(1), ),
                             dtype=torch.float,
                             device=edge_index.device)
    edge_index, edge_weight = pyg_utils.add_remaining_self_loops(
        edge_index, edge_weight, 1, n)
    adj_sparse = torch.sparse.FloatTensor(edge_index, edge_weight,
                                          torch.Size([n, n]))
    adj = adj_sparse.to_dense()

    deg = torch.diag(torch.sum(adj, -1))
    deg_inv_sqrt = deg.pow(-0.5)
    # zero the off-diagonal entries, which 0 ** -0.5 turned into inf, so the
    # matmuls below stay finite (the diagonal is positive thanks to the loops)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    adj = deg_inv_sqrt @ adj @ deg_inv_sqrt

    diag_all = [torch.diag(adj)]
    adj_power = adj
    for i in range(1, k):
        adj_power = adj_power @ adj
        diag_all.append(torch.diag(adj_power))
    diag_all = torch.stack(diag_all, dim=1)
    return diag_all
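A toy call of compute_identity (hypothetical graph; pyg_utils is assumed to be torch_geometric.utils as in the snippet):

import torch

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
feats = compute_identity(edge_index, n=3, k=4)
print(feats.shape)  # torch.Size([3, 4]): diag of A_norm^1 ... A_norm^4 per node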
Example #4
    def forward(self, data):
        x, edge_index, y, batch = data.x, data.edge_index, data.y, data.batch
        edge_index, _ = add_remaining_self_loops(edge_index,
                                                 num_nodes=x.shape[0])

        edge_list = []
        perm_list = []
        shape_list = []
        edge_weight = None

        f, e, b = x, edge_index, batch
        for i in range(self.depth):
            if i < self.depth:
                edge_list.append(e)
            f, attn = self.down_list[i](f, e, self.direction)
            shape_list.append(f.shape)
            f = F.leaky_relu(f)
            f, e, _, b, perm, _ = self.pool_list[i](f, e, edge_weight, b, attn)
            if i < self.depth - 1:
                e, _ = self.augment_adj(e, None, f.shape[0])
            perm_list.append(perm)
        latent_x, latent_edge = f, e

        z = f
        for i in range(self.depth):
            index = self.depth - i - 1
            shape = shape_list[index]
            up = torch.zeros(shape).to(self.device)
            p = perm_list[index]
            up[p] = z
            z = self.up_list[i](up, edge_list[index])
            if i < self.depth - 1:
                z = torch.relu(z)

        edge_list.clear()
        perm_list.clear()
        shape_list.clear()

        return z, latent_x, latent_edge, b
Example #5
    def forward(self, x, edge_index):
        edge_index, _ = add_remaining_self_loops(edge_index)

        row, col = edge_index
        deg = degree(row)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        if self.norm == 'neighbornorm':
            x_j = self.normlayer(x, edge_index)
        else:
            x_j = x[col]

        x_j = norm.view(-1, 1) * x_j
        out = scatter_add(src=x_j, index=row, dim=0, dim_size=x.size(0))

        if self.norm == 'neighbornorm':
            out = F.relu(self.linear(out))
        else:
            out = self.normlayer(F.relu(self.linear(out)))

        return out
Example #6
    def forward(self, x, edge_index, edge_attr, edge_weight=None):
        x = torch.matmul(x, self.weight)

        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=x.dtype,
                                     device=edge_index.device)

        fill_value = 1 if not self.improved else 2
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, x.size(0))

        # add_remaining_self_loops strips existing loops and appends exactly
        # x.size(0) loop edges at the end, so appending one zero attribute row
        # per node keeps edge_attr aligned with the augmented edge_index
        self_loop_edges = torch.zeros(x.size(0),
                                      edge_attr.size(1)).to(edge_index.device)
        edge_attr = torch.cat([edge_attr, self_loop_edges], dim=0)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=x.size(0))
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
        return self.propagate(edge_index, x=x, edge_attr=edge_attr, norm=norm)
Example #7
    def forward(self,
                x,
                edge_index,
                edge_weight=None,
                size=None,
                res_n_id=None):
        """
        Args:
            res_n_id (Tensor, optional): Residual node indices coming from
                :obj:`DataFlow` generated by :obj:`NeighborSampler` are used to
                select central node features in :obj:`x`.
                Required if operating in a bipartite graph and :obj:`concat` is
                :obj:`True`. (default: :obj:`None`)
        """
        if not self.concat and torch.is_tensor(x):
            edge_index, edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, 1, x.size(self.node_dim))

        return self.propagate(edge_index,
                              size=size,
                              x=x,
                              edge_weight=edge_weight,
                              res_n_id=res_n_id)
Example #8
    def forward(self, x, edge_index, edge_weight=None, pseudo=None, size=None):
        """"""
        edge_weight = edge_weight.squeeze()
        if size is None and torch.is_tensor(x):
            edge_index, edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, 1, x.size(0))

        weight = self.nn(pseudo).view(-1, self.in_channels, self.out_channels)
        if torch.is_tensor(x):
            x = torch.matmul(x.unsqueeze(1), weight).squeeze(1)
        else:
            x = (None if x[0] is None else torch.matmul(x[0].unsqueeze(1), weight).squeeze(1),
                 None if x[1] is None else torch.matmul(x[1].unsqueeze(1), weight).squeeze(1))

        # weight = self.nn(pseudo).view(-1, self.out_channels,self.in_channels)
        # if torch.is_tensor(x):
        #     x = torch.matmul(x.unsqueeze(1), weight.permute(0,2,1)).squeeze(1)
        # else:
        #     x = (None if x[0] is None else torch.matmul(x[0].unsqueeze(1), weight).squeeze(1),
        #          None if x[1] is None else torch.matmul(x[1].unsqueeze(1), weight).squeeze(1))

        return self.propagate(edge_index, size=size, x=x,
                              edge_weight=edge_weight)
Example #9
def graph_connectivity(device, perm, edge_index, edge_weight, score, ratio,
                       batch, N):
    r"""graph_connectivity: is a function which internally calls StAS func to maintain graph connectivity"""

    kN = perm.size(0)
    perm2 = perm.view(-1, 1)

    # mask contains bool mask of edges which originate from perm (selected) nodes
    mask = (edge_index[0] == perm2).sum(0, dtype=torch.bool)

    # create the S
    S0 = edge_index[1][mask].view(1, -1)
    S1 = edge_index[0][mask].view(1, -1)
    index_S = torch.cat([S0, S1], dim=0)
    value_S = score[mask].detach().squeeze()

    # relabel for pooling ie: make S [N x kN]
    n_idx = torch.zeros(N, dtype=torch.long)
    n_idx[perm] = torch.arange(perm.size(0))
    index_S[1] = n_idx[index_S[1]]

    # create A
    index_A = edge_index.clone()
    if edge_weight is None:
        value_A = value_S.new_ones(edge_index[0].size(0))
    else:
        value_A = edge_weight.clone()

    fill_value = 1
    index_E, value_E = StAS(index_A, value_A, index_S, value_S, device, N, kN)
    index_E, value_E = remove_self_loops(edge_index=index_E, edge_attr=value_E)
    index_E, value_E = add_remaining_self_loops(edge_index=index_E,
                                                edge_weight=value_E,
                                                fill_value=fill_value,
                                                num_nodes=kN)

    return index_E, value_E
Example #10
def load_cora():
    edges = pd.read_csv(CORA + 'cora_cites.csv')
    data = pd.read_csv(CORA + 'cora_content.csv')

    id_to_node = dict([(row['paper_id'], idx) for idx, row in data.iterrows()])
    class_to_int = dict([(c, i) for i, c in enumerate(set(data['label']))])

    # COO matrix of edges converted to node ids to match the
    # feature tensor
    citing = [id_to_node[e] for e in edges['citing_paper_id']]
    cited = [id_to_node[e] for e in edges['cited_paper_id']]

    # Undirected since there are so many orphans otherwise
    ei = torch.tensor([
        citing,  # + cited,
        cited,  # + citing
    ])

    ei = add_remaining_self_loops(ei)[0]

    # Don't need paper id's or class in node attr vectors
    X = torch.tensor(data.iloc[:, 1:-1].values, dtype=torch.float)

    y = torch.zeros(X.size()[0], len(class_to_int))
    i = 0
    for c in data['label']:
        y[i][class_to_int[c]] = 1
        i += 1

    weights = y.sum(dim=0)
    weights = weights.max() / weights

    return Data(x=X,
                edge_index=ei,
                y=y,
                weights=weights,
                num_nodes=X.size()[0])
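A usage sketch (CORA is the project's data-directory prefix; cora_cites.csv and cora_content.csv are assumed to live there):

data = load_cora()
print(data.num_nodes, data.edge_index.shape)  # N and [2, E + N] after self-loops
# data.weights holds inverse class frequencies, scaled so the most common
# class gets weight 1, suitable for a weighted loss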
Example #11
    def norm(edge_index,
             num_nodes,
             edge_weight=None,
             improved=False,
             dtype=None):

        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)
        # edge_index has shape [2, num_edges]: row indices, then column indices
        # edge_weight has one entry per edge (num_edges,)

        fill_value = 1 if not improved else 2
        # add self-loops with weight fill_value
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)
        row, col = edge_index
        # degree sum of edge weights
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        # result still in COO form
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
Example #12
    def diag_enhance_norm(edge_index,
                          num_nodes,
                          edge_weight=None,
                          improved=False,
                          dtype=None,
                          diag_lambda=1.0):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)

        fill_value = 1 if not improved else 2
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm_edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
        diag_edge_weight = norm_edge_weight.clone()
        diag_edge_weight[edge_index[0] != edge_index[1]] = 0

        return (edge_index, norm_edge_weight + diag_lambda * diag_edge_weight)
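A hedged reading of diag_enhance_norm: it returns the usual GCN-normalized weights plus diag_lambda times their diagonal entries, i.e. A_norm + lambda * diag(A_norm), the diagonal-enhancement trick used e.g. in Cluster-GCN. A toy check, treating the method as a standalone function:

import torch

edge_index = torch.tensor([[0, 1], [1, 0]])
ei, w = diag_enhance_norm(edge_index, num_nodes=2, diag_lambda=1.0)
# the two appended self-loop weights are doubled (0.5 -> 1.0);
# the off-diagonal weights stay at 0.5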
Example #13
    def My_norms(self,
                 x_norm,
                 edge_index,
                 num_nodes,
                 edge_weight=None,
                 improved=False,
                 dtype=None):

        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)

        fill_value = 1 if not improved else 2
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

        edge_index_j = edge_index[0]
        edge_index_i = edge_index[1]
        x_norm_j = x_norm[edge_index_j]
        x_norm_i = x_norm[edge_index_i]
        alpha = self.beta * (x_norm_i * x_norm_j).sum(dim=-1)
        alpha = softmax(alpha, edge_index_i, num_nodes)
        return edge_index, alpha
Example #14
    def __dropout_adj__(self, sparse_adj: SparseTensor,
                        dropout_adj_prob: float):
        # number of nodes
        N = sparse_adj.size(0)
        # sparse adj matrix to dense adj matrix
        row, col, edge_attr = sparse_adj.coo()
        edge_index = torch.stack([row, col], dim=0)
        # dropout adjacency matrix -> generalization
        edge_index, edge_attr = dropout_adj(edge_index,
                                            edge_attr=edge_attr,
                                            p=dropout_adj_prob,
                                            force_undirected=True,
                                            training=self.training)
        # because dropout removes self-loops (due to force_undirected=True), make sure to add them back again
        edge_index, edge_attr = add_remaining_self_loops(edge_index,
                                                         edge_weight=edge_attr,
                                                         fill_value=0.00,
                                                         num_nodes=N)
        # dense adj matrix to sparse adj matrix
        sparse_adj = SparseTensor.from_edge_index(edge_index,
                                                  edge_attr=edge_attr,
                                                  sparse_sizes=(N, N))

        return sparse_adj
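A note on the fill_value=0.00 choice above: the loops are re-added purely to keep the sparsity pattern stable for downstream SparseTensor code; with zero weight they contribute nothing to aggregation. A call sketch, where model is a hypothetical instance of the owning class:

import torch
from torch_sparse import SparseTensor

adj = SparseTensor.from_edge_index(torch.tensor([[0, 1], [1, 0]]),
                                   sparse_sizes=(2, 2))
adj_dropped = model.__dropout_adj__(adj, dropout_adj_prob=0.5)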
Example #15
File: layers.py Project: hujilin1229/GMI
    def norm(edge_index, num_nodes, edge_weight, dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, 0, num_nodes)

        row, col = edge_index
        expand_deg = torch.zeros((edge_weight.size(0), ),
                                 dtype=dtype,
                                 device=edge_index.device)
        expand_deg[-num_nodes:] = torch.ones((num_nodes, ),
                                             dtype=dtype,
                                             device=edge_index.device)

        return edge_index, (expand_deg -
                            deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col])
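Unlike the other norm variants here, this one adds the self-loops with fill value 0 after computing degrees, and expand_deg is 1 exactly on those appended loop edges (add_remaining_self_loops places them at the end), so the returned weights encode I - D^-1/2 A D^-1/2, the symmetric normalized Laplacian. A toy check, treating norm as a standalone function:

import torch

edge_index = torch.tensor([[0, 1], [1, 0]])
ei, w = norm(edge_index, 2, None, dtype=torch.float)
# ei: [[0, 1, 0, 1], [1, 0, 0, 1]]  (loops appended at the end)
# w:  [-1., -1., 1., 1.]  ==  I - D^-1/2 A D^-1/2 for the single edge 0-1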
Example #16
    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
        """"""
        symnorm_weight: OptTensor = None
        if "symnorm" in self.aggregators:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, symnorm_weight = gcn_norm(  # yapf: disable
                        edge_index,
                        None,
                        num_nodes=x.size(self.node_dim),
                        improved=False,
                        add_self_loops=self.add_self_loops)
                    if self.cached:
                        self._cached_edge_index = (edge_index, symnorm_weight)
                else:
                    edge_index, symnorm_weight = cache

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index,
                        None,
                        num_nodes=x.size(self.node_dim),
                        improved=False,
                        add_self_loops=self.add_self_loops)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        elif self.add_self_loops:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if self.cached and cache is not None:
                    edge_index = cache[0]
                else:
                    edge_index, _ = add_remaining_self_loops(edge_index)
                    if self.cached:
                        self._cached_edge_index = (edge_index, None)

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if self.cached and cache is not None:
                    edge_index = cache
                else:
                    edge_index = fill_diag(edge_index, 1.0)
                    if self.cached:
                        self._cached_adj_t = edge_index

        # [num_nodes, (out_channels // num_heads) * num_bases]
        bases = self.bases_lin(x)
        # [num_nodes, num_heads * num_bases * num_aggrs]
        weightings = self.comb_lin(x)

        # [num_nodes, num_aggregators, (out_channels // num_heads) * num_bases]
        # propagate_type: (x: Tensor, symnorm_weight: OptTensor)
        aggregated = self.propagate(edge_index,
                                    x=bases,
                                    symnorm_weight=symnorm_weight,
                                    size=None)

        weightings = weightings.view(-1, self.num_heads,
                                     self.num_bases * len(self.aggregators))
        aggregated = aggregated.view(
            -1,
            len(self.aggregators) * self.num_bases,
            self.out_channels // self.num_heads,
        )

        # [num_nodes, num_heads, out_channels // num_heads]
        out = torch.matmul(weightings, aggregated)
        out = out.view(-1, self.out_channels)

        if self.bias is not None:
            out += self.bias

        return out
Example #17
    def norm(self,
             edge_index,
             num_nodes,
             edge_weight=None,
             improved=False,
             dtype=None):

        adj_dict = {}

        def add_edge(a, b):
            if a in adj_dict:
                neighbors = adj_dict[a]
            else:
                neighbors = set()
                adj_dict[a] = neighbors
            if b not in neighbors:
                neighbors.add(b)

        cpu_device = torch.device("cpu")
        gpu_device = torch.device("cuda")
        for a, b in edge_index.t().detach().to(cpu_device).numpy():
            a = int(a)
            b = int(b)
            add_edge(a, b)
            add_edge(b, a)

        adj_dict = {a: list(neighbors) for a, neighbors in adj_dict.items()}

        def sample_neighbor(a):
            neighbors = adj_dict[a]
            random_index = np.random.randint(0, len(neighbors))
            return neighbors[random_index]

        # word_counter = Counter()
        walk_counters = {}

        def norm(counter):
            s = sum(counter.values())
            new_counter = Counter()
            for a, count in counter.items():
                new_counter[a] = counter[a] / s
            return new_counter

        for _ in tqdm(range(40)):
            for a in adj_dict:
                current_a = a
                current_path_len = np.random.randint(1, self.path_len + 1)
                for _ in range(current_path_len):
                    b = sample_neighbor(current_a)
                    if a in walk_counters:
                        walk_counter = walk_counters[a]
                    else:
                        walk_counter = Counter()
                        walk_counters[a] = walk_counter

                    walk_counter[b] += 1

                    current_a = b

        normed_walk_counters = {
            a: norm(walk_counter)
            for a, walk_counter in walk_counters.items()
        }

        prob_sums = Counter()

        for a, normed_walk_counter in normed_walk_counters.items():
            for b, prob in normed_walk_counter.items():
                prob_sums[b] += prob

        ppmis = {}

        for a, normed_walk_counter in normed_walk_counters.items():
            for b, prob in normed_walk_counter.items():
                ppmi = np.log(prob / prob_sums[b] * len(prob_sums) /
                              self.path_len)
                ppmis[(a, b)] = ppmi

        new_edge_index = []
        edge_weight = []
        for (a, b), ppmi in ppmis.items():
            new_edge_index.append([a, b])
            edge_weight.append(ppmi)

        edge_index = torch.tensor(new_edge_index).t().to(gpu_device)
        edge_weight = torch.tensor(edge_weight).to(gpu_device)

        fill_value = 1 if not improved else 2
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, (deg_inv_sqrt[row] * edge_weight *
                            deg_inv_sqrt[col]).type(torch.float32)
Example #18
def graph_max_pool(x, edge_index):
    edge_index, _ = add_remaining_self_loops(edge_index)
    source = edge_index[0]
    dest = edge_index[1]
    return scatter_('max', x[dest], source, dim_size=len(x))
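A toy usage of graph_max_pool (assuming scatter_ is the aggregation helper from older torch_geometric versions): each source node takes the elementwise max of x over its out-neighbors, and the added self-loops guarantee it also sees its own features:

import torch

x = torch.tensor([[1.], [5.], [2.]])
edge_index = torch.tensor([[0, 1], [1, 2]])  # edges 0->1 and 1->2
out = graph_max_pool(x, edge_index)
# out: [[5.], [5.], [2.]] -- node 0 takes max(x[0], x[1]),
# node 1 takes max(x[1], x[2]), node 2 only sees itself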
Example #19
    def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        num_graphs = batch.max().item() + 1
        row, col = edge_index
        total_num_edges = edge_index.shape[1]
        N_size = x.shape[0]

        if edge_dropout is not None:
            edge_index = dropout_adj(
                edge_index,
                edge_attr=(torch.ones(edge_index.shape[1],
                                      device=device)).long(),
                p=edge_dropout,
                force_undirected=True)[0]
            edge_index = add_remaining_self_loops(edge_index,
                                                  num_nodes=batch.shape[0])[0]

        reduced_num_edges = edge_index.shape[1]
        current_edge_percentage = (reduced_num_edges / total_num_edges)
        no_loop_index, _ = remove_self_loops(edge_index)
        no_loop_row, no_loop_col = no_loop_index

        xinit = x.clone()
        x = x.unsqueeze(-1)
        mask = get_mask(x, edge_index, 1).to(x.dtype)
        x = F.leaky_relu(self.conv1(x, edge_index))  # +x
        x = x * mask
        x = self.gnorm(x)
        x = self.bn1(x)

        for conv, bn in zip(self.convs, self.bns):
            if (x.dim() > 1):
                x = x + F.leaky_relu(conv(x, edge_index))
                mask = get_mask(mask, edge_index, 1).to(x.dtype)
                x = x * mask
                x = self.gnorm(x)
                x = bn(x)

        xpostconvs = x.detach()
        #
        x = F.leaky_relu(self.lin1(x))
        x = x * mask

        xpostlin1 = x.detach()
        x = F.leaky_relu(self.lin2(x))
        x = x * mask

        #calculate min and max
        batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
        batch_max = torch.index_select(batch_max, 0, batch)
        batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
        batch_min = torch.index_select(batch_min, 0, batch)

        #min-max normalize
        x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
        probs = x

        #calculating the terms for the expected distance between clique and graph
        pairwise_prodsums = torch.zeros(num_graphs, device=device)
        for graph in range(num_graphs):
            batch_graph = (batch == graph)
            pairwise_prodsums[graph] = (torch.conv1d(
                probs[batch_graph].unsqueeze(-1),
                probs[batch_graph].unsqueeze(-1))).sum() / 2

        ###calculate loss terms
        self_sums = scatter_add((probs * probs), batch, 0, dim_size=num_graphs)
        expected_weight_G = scatter_add(
            probs[no_loop_row] * probs[no_loop_col],
            batch[no_loop_row],
            0,
            dim_size=num_graphs) / 2.
        expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) -
                                  self_sums) / 1.
        expected_distance = (expected_clique_weight - expected_weight_G)

        ###calculate loss
        expected_loss = (penalty_coefficient
                         ) * expected_distance * 0.5 - 0.5 * expected_weight_G

        loss = expected_loss

        retdict = {}

        retdict["output"] = [probs.squeeze(-1), "hist"]  #output
        retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
        retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
        retdict["Expected maximum weight"] = [
            expected_clique_weight.mean(), "sequence"
        ]
        retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
        retdict["loss"] = [loss.mean().squeeze(), "sequence"]  #final loss

        return retdict
Example #20
def main():
    global device
    global graphname

    print(socket.gethostname())
    seed = 0

    if not download:
        mp.set_start_method('spawn', force=True)
        outputs = None
        if "OMPI_COMM_WORLD_RANK" in os.environ.keys():
            os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
        # Initialize distributed environment with SLURM
        if "SLURM_PROCID" in os.environ.keys():
            os.environ["RANK"] = os.environ["SLURM_PROCID"]

        if "SLURM_NTASKS" in os.environ.keys():
            os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"]

        if "MASTER_ADDR" not in os.environ.keys():
            os.environ["MASTER_ADDR"] = "127.0.0.1"

        os.environ["MASTER_PORT"] = "1234"
        dist.init_process_group(backend='nccl')
        rank = dist.get_rank()
        size = dist.get_world_size()
        print("Processes: " + str(size))

        # device = torch.device('cpu')
        devid = rank_to_devid(rank, acc_per_rank)
        device = torch.device('cuda:{}'.format(devid))
        torch.cuda.set_device(device)
        curr_devid = torch.cuda.current_device()
        # print(f"curr_devid: {curr_devid}", flush=True)
        devcount = torch.cuda.device_count()

    if graphname == "Cora":
        path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', graphname)
        dataset = Planetoid(path, graphname, T.NormalizeFeatures())
        data = dataset[0]
        data = data.to(device)
        data.x.requires_grad = True
        inputs = data.x.to(device)
        inputs.requires_grad = True
        data.y = data.y.to(device)
        edge_index = data.edge_index
        num_features = dataset.num_features
        num_classes = dataset.num_classes
    elif graphname == "Reddit":
        path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', graphname)
        dataset = Reddit(path, T.NormalizeFeatures())
        data = dataset[0]
        data = data.to(device)
        data.x.requires_grad = True
        inputs = data.x.to(device)
        inputs.requires_grad = True
        data.y = data.y.to(device)
        edge_index = data.edge_index
        num_features = dataset.num_features
        num_classes = dataset.num_classes
    elif graphname == 'Amazon':
        # path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', graphname)
        # edge_index = torch.load(path + "/processed/amazon_graph.pt")
        # edge_index = torch.load("/gpfs/alpine/bif115/scratch/alokt/Amazon/processed/amazon_graph_jsongz.pt")
        # edge_index = edge_index.t_()
        print(f"Loading coo...", flush=True)
        edge_index = torch.load("../data/Amazon/processed/data.pt")
        print(f"Done loading coo", flush=True)
        # n = 9430088
        n = 14249639
        num_features = 300
        num_classes = 24
        # mid_layer = 24
        inputs = torch.rand(n, num_features)
        data = Data()
        data.y = torch.rand(n).uniform_(0, num_classes - 1).long()
        data.train_mask = torch.ones(n).long()
        # edge_index = edge_index.to(device)
        print(f"edge_index.size: {edge_index.size()}", flush=True)
        print(f"edge_index: {edge_index}", flush=True)
        data = data.to(device)
        # inputs = inputs.to(device)
        inputs.requires_grad = True
        data.y = data.y.to(device)
    elif graphname == 'subgraph3':
        # path = "/gpfs/alpine/bif115/scratch/alokt/HipMCL/"
        # print(f"Loading coo...", flush=True)
        # edge_index = torch.load(path + "/processed/subgraph3_graph.pt")
        # print(f"Done loading coo", flush=True)
        print(f"Loading coo...", flush=True)
        edge_index = torch.load("../data/subgraph3/processed/data.pt")
        print(f"Done loading coo", flush=True)
        n = 8745542
        num_features = 128
        # mid_layer = 512
        # mid_layer = 64
        num_classes = 256
        inputs = torch.rand(n, num_features)
        data = Data()
        data.y = torch.rand(n).uniform_(0, num_classes - 1).long()
        data.train_mask = torch.ones(n).long()
        print(f"edge_index.size: {edge_index.size()}", flush=True)
        data = data.to(device)
        inputs.requires_grad = True
        data.y = data.y.to(device)

    if download:
        exit()

    if normalization:
        adj_matrix, _ = add_remaining_self_loops(edge_index, num_nodes=inputs.size(0))
    else:
        adj_matrix = edge_index

    init_process(rank, size, inputs, adj_matrix, data, num_features,
                 num_classes, device, outputs, run)

    if outputs is not None:
        return outputs[0]
Example #21
File: layers.py Project: hujilin1229/GMI
    def forward(self, x, edge_index, edge_attr, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        # replace with MI
        x_information_score = self.calc_information_score(
            x, edge_index, edge_attr)
        score = torch.sum(torch.abs(x_information_score), dim=1)

        # Graph Pooling
        original_x = x
        perm = topk(score, self.ratio, batch)
        x = x[perm]
        batch = batch[perm]
        induced_edge_index, induced_edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        # Discard structure learning layer, directly return
        if self.sl is False:
            return x, induced_edge_index, induced_edge_attr, batch

        # Structure Learning
        if self.sample:
            # A fast mode for large graphs.
            # In large graphs, learning the possible edge weights between each pair of nodes is time consuming.
            # To accelerate this process, we sample its k-hop neighbors for each node and then learn the
            # edge weights between them.
            k_hop = 3
            if edge_attr is None:
                edge_attr = torch.ones((edge_index.size(1), ),
                                       dtype=torch.float,
                                       device=edge_index.device)

            hop_data = Data(x=original_x,
                            edge_index=edge_index,
                            edge_attr=edge_attr)
            for _ in range(k_hop - 1):
                hop_data = self.neighbor_augment(hop_data)
            hop_edge_index = hop_data.edge_index
            hop_edge_attr = hop_data.edge_attr
            new_edge_index, new_edge_attr = filter_adj(hop_edge_index,
                                                       hop_edge_attr,
                                                       perm,
                                                       num_nodes=score.size(0))

            new_edge_index, new_edge_attr = add_remaining_self_loops(
                new_edge_index, new_edge_attr, 0, x.size(0))
            row, col = new_edge_index
            weights = (torch.cat([x[row], x[col]], dim=1) *
                       self.att).sum(dim=-1)
            weights = F.leaky_relu(
                weights, self.negative_slop) + new_edge_attr * self.lamb
            adj = torch.zeros((x.size(0), x.size(0)),
                              dtype=torch.float,
                              device=x.device)
            adj[row, col] = weights
            new_edge_index, weights = dense_to_sparse(adj)
            row, col = new_edge_index
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, x.size(0))
            # filter out zero weight edges
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            # release gpu memory
            del adj
            torch.cuda.empty_cache()
        else:
            # Learning the possible edge weights between each pair of nodes in the pooled subgraph; relatively slower.
            if edge_attr is None:
                induced_edge_attr = torch.ones(
                    (induced_edge_index.size(1), ),
                    dtype=x.dtype,
                    device=induced_edge_index.device)
            num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
            shift_cum_num_nodes = torch.cat(
                [num_nodes.new_zeros(1),
                 num_nodes.cumsum(dim=0)[:-1]], dim=0)
            cum_num_nodes = num_nodes.cumsum(dim=0)
            adj = torch.zeros((x.size(0), x.size(0)),
                              dtype=torch.float,
                              device=x.device)
            # Construct batch fully connected graph in block diagonal matrix format
            for idx_i, idx_j in zip(shift_cum_num_nodes, cum_num_nodes):
                adj[idx_i:idx_j, idx_i:idx_j] = 1.0
            new_edge_index, _ = dense_to_sparse(adj)
            row, col = new_edge_index

            weights = (torch.cat([x[row], x[col]], dim=1) *
                       self.att).sum(dim=-1)
            weights = F.leaky_relu(weights, self.negative_slop)
            adj[row, col] = weights
            induced_row, induced_col = induced_edge_index

            adj[induced_row, induced_col] += induced_edge_attr * self.lamb
            weights = adj[row, col]
            if self.sparse:
                new_edge_attr = self.sparse_attention(weights, row)
            else:
                new_edge_attr = softmax(weights, row, x.size(0))
            # filter out zero weight edges
            adj[row, col] = new_edge_attr
            new_edge_index, new_edge_attr = dense_to_sparse(adj)
            # release gpu memory
            del adj
            torch.cuda.empty_cache()

        return x, new_edge_index, new_edge_attr, batch
Example #22
    def forward(self,
                x,
                pos_edge_index,
                neg_edge_index,
                return_attention_weights=True):
        """"""
        # hyper linear
        pos_edge_index = add_remaining_self_loops(pos_edge_index,
                                                  num_nodes=x.size(0))[0]

        x = self.manifolds.proj(self.manifolds.expmap0(
            self.manifolds.proj_tan0(x, self.c), c=self.c),
                                c=self.c)
        if self.manifolds.name != 'PoincareBall':
            drop_weight = F.dropout(self.weight,
                                    self.dropout,
                                    training=self.training)
            mv = self.manifolds.mobius_matvec(drop_weight, x, self.c)
            res = self.manifolds.proj(mv, self.c)
        else:
            res = x
        if torch.isnan(res).any():
            print("check here")
        # assert not torch.isnan(res).any()
        if self.use_bias:
            bias = self.manifolds.proj_tan0(self.bias.view(1, -1), self.c)
            hyp_bias = self.manifolds.expmap0(bias, self.c)
            hyp_bias = self.manifolds.proj(hyp_bias, self.c)
            res = self.manifolds.mobius_add(res, hyp_bias, c=self.c)
            res = self.manifolds.proj(res, self.c)
        torch.cuda.empty_cache()
        x = (self.manifolds.logmap0(res, c=self.c)).cuda() + 1e-15

        if self.first_aggr:
            if self.manifolds.name == 'Hyperboloid':
                assert x.size(1) == self.in_channels - 1
            else:
                assert x.size(1) == self.in_channels

            if return_attention_weights:
                x_trans_pos = (self.lin_pos_agg(x), self.lin_pos_agg(x))
                x_trans_neg = (self.lin_neg_agg(x), self.lin_neg_agg(x))
            else:
                x_trans_pos = x
                x_trans_neg = x

            x_pos = torch.cat([
                self.propagate(
                    pos_edge_index,
                    x=x_trans_pos,
                    size=None,
                    return_attention_weights=return_attention_weights), x
            ],
                              dim=1)
            x_neg = torch.cat([
                self.propagate(
                    neg_edge_index,
                    x=x_trans_neg,
                    size=None,
                    return_attention_weights=return_attention_weights), x
            ],
                              dim=1)

        else:
            assert x.size(1) == 2 * self.in_channels

            x_1, x_2 = x[:, :self.in_channels], x[:, self.in_channels:]

            x_pos = torch.cat([
                self.propagate(
                    pos_edge_index,
                    x=(self.lin_pos_agg(x_1), self.lin_pos_agg(x_1)),
                    size=None,
                    return_attention_weights=return_attention_weights),
                self.propagate(
                    neg_edge_index,
                    x=(self.lin_neg_agg(x_2), self.lin_neg_agg(x_2)),
                    size=None,
                    return_attention_weights=return_attention_weights),
                x_1,
            ],
                              dim=1)

            x_neg = torch.cat([
                self.propagate(
                    pos_edge_index,
                    x=(self.lin_pos_agg(x_2), self.lin_pos_agg(x_2)),
                    size=None,
                    return_attention_weights=return_attention_weights),
                self.propagate(
                    neg_edge_index,
                    x=(self.lin_neg_agg(x_1), self.lin_neg_agg(x_1)),
                    size=None,
                    return_attention_weights=return_attention_weights),
                x_2,
            ],
                              dim=1)
        # to ensure numerical stability
        x_pos = x_pos + 1e-15
        x_neg = x_neg + 1e-15
        assert not torch.isnan(x_pos).any()
        assert not torch.isnan(x_neg).any()
        x_pos = self.manifolds.proj(self.manifolds.expmap0(self.lin_pos(x_pos),
                                                           c=self.c),
                                    c=self.c)
        x_neg = self.manifolds.proj(self.manifolds.expmap0(self.lin_neg(x_neg),
                                                           c=self.c),
                                    c=self.c)

        x_out = torch.cat([x_pos, x_neg], dim=1)

        xt = self.act(self.manifolds.logmap0(x_out, c=self.c),
                      self.negative_slope)
        xt = self.manifolds.proj_tan0(xt, c=self.c)
        xt = self.manifolds.proj(self.manifolds.expmap0(xt, c=self.c),
                                 c=self.c)
        if torch.isnan(xt).any():
            print("check here")
        assert not torch.isnan(xt).any()

        return xt
Example #23
        mean = torch.mean(mean, dim=-1, keepdim=True)
        var = scatter_mean((x[col] - mean[row])**2,
                           row,
                           dim=0,
                           dim_size=x.size(0))
        var = torch.mean(var, dim=-1, keepdim=True)
        # std = scatter_std(x[col], row, dim=0, dim_size=x.size(0))
        out = (x[col] - mean[row]) / (var[row] + self.eps).sqrt()
        # out = (x[col] - mean[row]) / (std[row]**2 + self.eps).sqrt()
        out = self.gamma * out + self.beta
        return out


if __name__ == '__main__':
    from torch_geometric.data import Data
    from torch_geometric.utils import add_remaining_self_loops

    edge_index = torch.tensor(
        [[0, 1], [1, 0], [1, 2], [2, 1], [2, 3], [2, 4], [3, 2], [4, 2]],
        dtype=torch.long)
    x = torch.tensor([[-1, 2, 3], [3, 2, 1], [1, 6, 9], [2, 3, 6], [3, 2, 8]],
                     dtype=torch.float)
    data = Data(x=x, edge_index=edge_index.t().contiguous())
    edge_index, _ = add_remaining_self_loops(data.edge_index)
    row, col = edge_index
    x = data.x
    print(x[col])
    neighbornorm = NeighborNorm(3)
    y = neighbornorm.forward(x, edge_index)
    print(y)
Example #24
    def forward(self, x, edge_index, edge_weight=None, batch=None):

        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        # NxF
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        # Add Self Loops
        fill_value = 1
        num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index=edge_index,
            edge_weight=edge_weight,
            fill_value=fill_value,
            num_nodes=num_nodes.sum())

        N = x.size(0)  # total num of nodes in batch

        # ExF
        x_pool = self.gnn_intra_cluster(x=x,
                                        edge_index=edge_index,
                                        edge_weight=edge_weight)
        x_pool_j = x_pool[edge_index[1]]
        x_j = x[edge_index[1]]

        #---Master query formation---
        # NxF
        X_q, _ = scatter_max(x_pool_j, edge_index[0], dim=0)
        # NxF
        M_q = self.lin_q(X_q)
        # ExF
        M_q = M_q[edge_index[0].tolist()]

        score = self.gat_att(torch.cat((M_q, x_pool_j), dim=-1))
        score = F.leaky_relu(score, self.negative_slope)
        score = softmax(score, edge_index[0], num_nodes=num_nodes.sum())

        # Sample attention coefficients stochastically.
        score = F.dropout(score, p=self.dropout_att, training=self.training)
        # ExF
        v_j = x_j * score.view(-1, 1)
        #---Aggregation---
        # NxF
        out = scatter_add(v_j, edge_index[0], dim=0)

        #---Cluster Selection
        # Nx1
        fitness = torch.sigmoid(self.gnn_score(x=out,
                                               edge_index=edge_index)).view(-1)
        perm = topk(x=fitness, ratio=self.ratio, batch=batch)
        x = out[perm] * fitness[perm].view(-1, 1)

        #---Maintaining Graph Connectivity
        batch = batch[perm]
        edge_index, edge_weight = graph_connectivity(device=x.device,
                                                     perm=perm,
                                                     edge_index=edge_index,
                                                     edge_weight=edge_weight,
                                                     score=score,
                                                     ratio=self.ratio,
                                                     batch=batch,
                                                     N=N)

        return x, edge_index, edge_weight, batch, perm
Example #25
def get_diracs(data,
               N,
               n_diracs=1,
               sparse=False,
               flat=False,
               replace=True,
               receptive_field=7,
               effective_volume_range=0.1,
               max_iterations=20,
               complement=False):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if not sparse:
        graphcount = data.num_nodes  #number of graphs in data/batch object
        totalnodecount = data.x.shape[1]  #number of total nodes for each graph
        actualnodecount = 0  #cumulative number of nodes
        diracmatrix = torch.zeros((graphcount, totalnodecount, N),
                                  device=device)  #matrix with dirac pulses

        for k in range(graphcount):
            graph_nodes = data.mask[k].sum()  #number of nodes in the graph
            actualnodecount += graph_nodes  #might not need this, we'll see
            probabilities = torch.ones(
                (graph_nodes.item(), 1),
                device=device) / graph_nodes  #uniform probs
            node_distribution = OneHotCategorical(
                probs=probabilities.squeeze())
            node_sample = node_distribution.sample(sample_shape=(N, ))
            node_sample = torch.cat(
                (node_sample,
                 torch.zeros((N, totalnodecount - node_sample.shape[1]),
                             device=device)),
                -1)  #concat zeros to fit dataset shape
            diracmatrix[k, :] = torch.transpose(
                node_sample, dim0=-1,
                dim1=-2)  #add everything to the final matrix

        return diracmatrix

    else:
        if not is_undirected(data.edge_index):
            data.edge_index = to_undirected(data.edge_index)

        original_batch_index = data.batch
        original_edge_index = add_remaining_self_loops(
            data.edge_index, num_nodes=data.batch.shape[0])[0]
        batch_index = original_batch_index

        graphcount = data.num_graphs
        batch_prime = torch.zeros(0, device=device).long()

        r, c = original_edge_index

        global_offset = 0
        all_nodecounts = scatter_add(
            torch.ones_like(batch_index, device=device), batch_index, 0)
        recfield_vols = torch.zeros(graphcount, device=device)
        total_vols = torch.zeros(graphcount, device=device)

        for j in range(n_diracs):
            diracmatrix = torch.zeros(0, device=device)
            locationmatrix = torch.zeros(0, device=device).long()
            for k in range(graphcount):
                #get edges of current graph, remember to subtract offset
                graph_nodes = all_nodecounts[k]
                if graph_nodes == 0:
                    print("all nodecounts: ", all_nodecounts)
                graph_edges = (batch_index[r] == k)
                graph_edge_index = (original_edge_index[:, graph_edges] -
                                    global_offset)
                gr, gc = graph_edge_index

                #get dirac
                randInt = np.random.choice(range(graph_nodes),
                                           N,
                                           replace=replace)
                node_sample = torch.zeros(N * graph_nodes, device=device)
                offs = torch.arange(N, device=device) * graph_nodes
                dirac_locations = (offs + torch.from_numpy(randInt).to(device))
                node_sample[dirac_locations] = 1
                #calculate receptive field volume and compare to total volume
                mask = get_mask(node_sample, graph_edge_index.detach(),
                                receptive_field).float()
                deg_graph = degree(gr, (graph_nodes.item()))
                total_volume = deg_graph.sum()
                recfield_volume = (mask * deg_graph).sum()
                volume_range = recfield_volume / total_volume
                total_vols[k] = total_volume
                recfield_vols[k] = recfield_volume
                #if receptive field volume is less than x% of total volume, resample
                for iteration in range(max_iterations):
                    randInt = np.random.choice(range(graph_nodes),
                                               N,
                                               replace=replace)
                    node_sample = torch.zeros(N * graph_nodes, device=device)
                    offs = torch.arange(N, device=device) * graph_nodes
                    dirac_locations = (offs +
                                       torch.from_numpy(randInt).to(device))
                    node_sample[dirac_locations] = 1

                    mask = get_mask(node_sample, graph_edge_index,
                                    receptive_field).float()
                    recfield_volume = (mask * deg_graph).sum()
                    volume_range = recfield_volume / total_volume

                    if volume_range > effective_volume_range:
                        recfield_vols[k] = recfield_volume
                        total_vols[k] = total_volume
                        break
                dirac_locations2 = torch.from_numpy(randInt).to(
                    device) + global_offset
                global_offset += graph_nodes

                diracmatrix = torch.cat((diracmatrix, node_sample), 0)
                locationmatrix = torch.cat((locationmatrix, dirac_locations2),
                                           0)
        locationmatrix = diracmatrix.nonzero()
        if complement:
            return Batch(batch=batch_index,
                         x=diracmatrix,
                         edge_index=original_edge_index,
                         y=data.y,
                         locations=locationmatrix,
                         volume_range=volume_range,
                         recfield_vol=recfield_vols,
                         total_vol=total_vols,
                         complement_edge_index=data.complement_edge_index)
        else:
            return Batch(batch=batch_index,
                         x=diracmatrix,
                         edge_index=original_edge_index,
                         y=data.y,
                         locations=locationmatrix,
                         volume_range=volume_range,
                         recfield_vol=recfield_vols,
                         total_vol=total_vols)
Example #26
    def forward(self, x, edge_idx, n, d):
        edge_idx, _ = add_remaining_self_loops(edge_idx)
        x = spmm(x, torch.ones_like(x[0]), n, d, self.weight)
        return self.propagate(edge_idx, x=x)
Example #27
    def forward(self, data, edge_dropout = None, penalty_coefficient = 0.25):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        num_graphs = batch.max().item() + 1
        row, col = edge_index     
        total_num_edges = edge_index.shape[1]
        N_size = x.shape[0]


        if edge_dropout is not None:
            edge_index = dropout_adj(
                edge_index,
                edge_attr=(torch.ones(edge_index.shape[1],
                                      device=device)).long(),
                p=edge_dropout,
                force_undirected=True)[0]
            edge_index = add_remaining_self_loops(edge_index,
                                                  num_nodes=batch.shape[0])[0]

        reduced_num_edges = edge_index.shape[1]
        current_edge_percentage = (reduced_num_edges / total_num_edges)
        no_loop_index, _ = remove_self_loops(edge_index)
        no_loop_row, no_loop_col = no_loop_index

        xinit = x.clone()
        x = x.unsqueeze(-1)
        mask = get_mask(x, edge_index, 1).to(x.dtype)
        x = F.leaky_relu(self.conv1(x, edge_index))  # +x
        x = x * mask
        x = self.gnorm(x)
        x = self.bn1(x)

        for conv, bn in zip(self.convs, self.bns):
            if (x.dim() > 1):
                x = x + F.leaky_relu(conv(x, edge_index))
                mask = get_mask(mask, edge_index, 1).to(x.dtype)
                x = x * mask
                x = self.gnorm(x)
                x = bn(x)

        xpostconvs = x.detach()
        #
        x = F.leaky_relu(self.lin1(x))
        x = x * mask

        xpostlin1 = x.detach()
        x = F.leaky_relu(self.lin2(x))
        x = x * mask

        #calculate min and max
        batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
        batch_max = torch.index_select(batch_max, 0, batch)
        batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
        batch_min = torch.index_select(batch_min, 0, batch)

        #min-max normalize
        x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
        probs = x

        x2 = x.detach()
        deg = degree(row).unsqueeze(-1)
        totalvol = scatter_add(deg.detach() * torch.ones_like(x, device=device),
                               batch, 0) + 1e-6
        totalcard = scatter_add(torch.ones_like(x, device=device), batch,
                                0) + 1e-6
        x2 = ((x2 - torch.rand_like(x, device=device)) > 0).float()
        vol_1 = scatter_add(probs * deg, batch, 0) + 1e-6
        card_1 = scatter_add(probs, batch, 0)
        set_size = scatter_add(x2, batch, 0)
        vol_hard = scatter_add(deg * x2, batch, 0,
                               dim_size=batch.max().item() + 1) + 1e-6
        total_vol_ratio = vol_hard / totalvol
        
        
        #calculating the terms for the expected distance between clique and graph
        pairwise_prodsums = torch.zeros(num_graphs, device=device)
        for graph in range(num_graphs):
            batch_graph = (batch == graph)
            pairwise_prodsums[graph] = (torch.conv1d(
                probs[batch_graph].unsqueeze(-1),
                probs[batch_graph].unsqueeze(-1))).sum() / 2

        ###calculate loss terms
        self_sums = scatter_add((probs * probs), batch, 0, dim_size=num_graphs)
        expected_weight_G = scatter_add(
            probs[no_loop_row] * probs[no_loop_col],
            batch[no_loop_row],
            0,
            dim_size=num_graphs) / 2.
        expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) -
                                  self_sums) / 1.
        expected_distance = (expected_clique_weight - expected_weight_G)

        ###useful numbers
        max_set_weight = (scatter_add(torch.ones_like(x)[no_loop_row],
                                      batch[no_loop_row],
                                      0,
                                      dim_size=num_graphs) / 2).squeeze(-1)
        set_weight = (scatter_add(x2[no_loop_row] * x2[no_loop_col],
                                  batch[no_loop_row],
                                  0,
                                  dim_size=num_graphs) / 2) + 1e-6
        clique_edges_hard = (set_size * (set_size - 1) / 2) + 1e-6
        clique_dist_hard = set_weight / clique_edges_hard
        clique_check = ((clique_edges_hard != clique_edges_hard))
        setedge_check = ((set_weight != set_weight))

        assert ((clique_dist_hard >= 1.1).sum()) <= 1e-6, \
            "Invalid set vol/clique vol ratio."

        ###calculate loss
        expected_loss = (penalty_coefficient
                         ) * expected_distance * 0.5 - 0.5 * expected_weight_G

        loss = expected_loss

        retdict = {}

        retdict["output"] = [probs.squeeze(-1), "hist"]  #output
        retdict["Expected_cardinality"] = [card_1.mean(), "sequence"]
        retdict["Expected_cardinality_hist"] = [card_1, "hist"]
        retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
        retdict["Set sizes"] = [set_size.squeeze(-1), "hist"]
        retdict["volume_hard"] = [vol_hard.mean(), "aux"]  #volume2
        retdict["cardinality_hard"] = [set_size[0], "sequence"]  #volumeq
        retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
        retdict["Expected maximum weight"] = [
            expected_clique_weight.mean(), "sequence"
        ]
        retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
        retdict["Currvol/Cliquevol"] = [clique_dist_hard.mean(), 'sequence']
        retdict["Currvol/Cliquevol all graphs in batch"] = [
            clique_dist_hard.squeeze(-1), 'hist'
        ]
        retdict["Average ratio of total volume"] = [
            total_vol_ratio.mean(), 'sequence'
        ]
        # 'cardinalities' was undefined in this scope; card_1 (the expected
        # cardinality per graph) is the closest defined quantity
        retdict["cardinalities"] = [card_1.squeeze(-1), "hist"]
        retdict["Current edge percentage"] = [
            torch.tensor(current_edge_percentage), 'sequence'
        ]
        retdict["loss"] = [loss.mean().squeeze(), "sequence"]  #final loss

        return retdict
Example #28
def prepare_data_for_link_prediction(datalist,
                                     train_ratio=0.8,
                                     neg_to_pos_edge_ratio=1,
                                     rnd_labeled_edges=True):
    """For each graph it splits the edges in training and testing (both with also a
    negative set of examples).
    rnd_labeled_edges=True means that the positive and negative edges for training are choosen at random (at 
    different epochs, the same graph can have different positive/negative edges chosen for training)."""
    train_data_list = []
    test_data_list = []
    for graph in datalist:
        train_graph = graph
        test_graph = train_graph.clone()

        # Create Negative edges examples
        ei_without_double_edges = remove_double_edges(graph.edge_index)
        ei_with_self_loops, _ = add_remaining_self_loops(
            ei_without_double_edges, num_nodes=graph.num_nodes)

        neg_edge_index = negative_sampling(
            edge_index=ei_with_self_loops,
            num_nodes=graph.num_nodes,
            num_neg_samples=neg_to_pos_edge_ratio *
            ei_without_double_edges.size(1),
            shuffle_neg_egdes=rnd_labeled_edges)

        num_train_pos_edges = math.floor(
            ei_without_double_edges.size(1) * train_ratio)
        num_train_neg_edges = math.floor(neg_edge_index.size(1) * train_ratio)

        # Split Positive edges
        if rnd_labeled_edges:
            perm = torch.randperm(ei_without_double_edges.size(1))
            row, col = ei_without_double_edges[0][
                perm], ei_without_double_edges[1][perm]
        else:
            row, col = ei_without_double_edges[0], ei_without_double_edges[1]
        train_graph.pos_edge_index = torch.stack(
            [row[:num_train_pos_edges], col[:num_train_pos_edges]], dim=0)
        test_graph.pos_edge_index = torch.stack(
            [row[num_train_pos_edges:], col[num_train_pos_edges:]], dim=0)

        # Update edge_index for message-passing for link prediction (no test edges)
        train_graph.edge_index = to_undirected(train_graph.pos_edge_index,
                                               num_nodes=train_graph.num_nodes)
        test_graph.edge_index = train_graph.edge_index

        # Split Negative edges
        if rnd_labeled_edges:
            perm = torch.randperm(neg_edge_index.size(1))
            row, col = neg_edge_index[0][perm], neg_edge_index[1][perm]
        else:
            row, col = neg_edge_index[0], neg_edge_index[1]
        train_graph.neg_edge_index = torch.stack(
            [row[:num_train_neg_edges], col[:num_train_neg_edges]], dim=0)
        test_graph.neg_edge_index = torch.stack(
            [row[num_train_neg_edges:], col[num_train_neg_edges:]], dim=0)

        train_data_list.append(train_graph)
        test_data_list.append(test_graph)

    return train_data_list, test_data_list
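A hedged usage sketch (remove_double_edges and the shuffle_neg_egdes keyword belong to this project's own negative_sampling variant, so both must be in scope along with a list of PyG Data graphs):

train_list, test_list = prepare_data_for_link_prediction(datalist,
                                                         train_ratio=0.8)
# each train graph gets pos_edge_index / neg_edge_index label sets, and its
# edge_index is rebuilt from the training positives only, so no test edge
# leaks into message passing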