Example no. 1
0
def normalized_cut_2d(edge_index, pos):
    """Normalized cut of a graph with 2D node positions.

    Edge weights are the Euclidean (L2) distances between the two
    endpoints of each edge. ``edge_index`` is indexed directly rather
    than tuple-unpacked, avoiding an intermediate expansion.
    """
    src = edge_index[0]
    dst = edge_index[1]
    weights = torch.norm(pos[src] - pos[dst], p=2, dim=1)
    return normalized_cut(edge_index, weights, num_nodes=pos.size(0))
def test_normalized_cut():
    """Compare normalized_cut on a small fixed graph to hand-computed values."""
    edge_index = torch.stack(
        [
            torch.LongTensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4]),
            torch.LongTensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3]),
        ],
        dim=0,
    )
    edge_attr = torch.Tensor([3, 3, 6, 3, 6, 1, 3, 2, 1, 2])
    expected = [4, 4, 5, 2.5, 5, 1, 2.5, 2, 1, 2]

    result = normalized_cut(edge_index, edge_attr)
    assert result.tolist() == expected
Example no. 3
0
def normalized_cut_2d(edge_index, pos):
    """Return the normalized cut, weighting each edge by the Euclidean
    distance between its endpoint positions."""
    src, dst = edge_index
    dist = (pos[src] - pos[dst]).norm(p=2, dim=1)
    return normalized_cut(edge_index, dist, num_nodes=pos.size(0))
Example no. 4
0
def normalized_cut_2d(edge_index: torch.LongTensor,
                      num_nodes: int) -> torch.FloatTensor:
    """Normalized cut with uniform (unit) edge weights.

    Every edge gets weight 1, so the cut depends only on graph topology.
    """
    num_edges = edge_index.shape[1]
    uniform = torch.ones(num_edges, device=edge_index.device)
    return normalized_cut(edge_index, uniform, num_nodes=num_nodes)
Example no. 5
0
 def normalized_cut_2d(self, edge_index, pos):
     """Compute the normalized cut, with each edge weighted by the
     Euclidean distance between its two endpoint positions."""
     src, dst = edge_index
     edge_len = (pos[src] - pos[dst]).norm(dim=1)
     return normalized_cut(edge_index, edge_len, num_nodes=pos.size(0))