Пример #1
0
    def test_transform(self, device, dtype):
        """Check output shapes of a ``QmfCore`` analysis/synthesis round trip.

        Builds ``M`` bipartite graphs on ``N`` vertices. Every graph after the
        first has the entries already used by earlier graphs zeroed out, so
        the graphs are edge-disjoint. Their partitions are stacked column-wise
        into ``beta``. The test then runs ``analyze`` and ``synthesize`` on a
        random signal of length ``N`` and prints the reconstruction error and
        SNR.
        """
        M, N = 2, 32
        K = 30
        Bs = []
        beta = []

        bpg, bt = rand_bipartite(N // 2,
                                 N - N // 2,
                                 p=0.2,
                                 dtype=dtype,
                                 device=device,
                                 return_partition=True)
        Bs.append(bpg)
        beta.append(bt)
        # Union of the non-zero entries used by the graphs built so far.
        mask = Bs[0].to_dense() != 0
        for i in range(M - 1):
            bpg, bt = rand_bipartite(
                N // 2,
                N - N // 2,
                p=0.2,
                dtype=dtype,
                device=device,
                return_partition=True,
            )
            dense = bpg.to_dense()
            temp_mask = dense != 0

            # Remove entries already present in an earlier graph so the
            # graphs stay edge-disjoint.
            dense[mask] = 0
            # Both masks are boolean; `|` makes the set-union intent explicit.
            mask = temp_mask | mask
            Bs.append(SparseTensor.from_dense(dense))
            # Even loop indices keep the partition as drawn, odd ones use its
            # complement (`~bt`; presumably a boolean indicator — confirm).
            beta.append(bt if i % 2 == 0 else ~bt)

        beta = torch.stack(beta).T

        qmf = QmfCore(bipartite_graphs=Bs, beta=beta, order=K)
        x = torch.rand(N, dtype=dtype, device=device)
        y = qmf.analyze(x)
        assert y.shape == (2**M, N, 1)

        z = qmf.synthesize(y)
        assert z.shape == (2**M, N, 1)
        # NOTE(review): this asserts the reconstruction error is NON-zero,
        # i.e. that the round trip is not exact — confirm this is intended.
        assert (z.sum(0).squeeze() - x).abs().sum() != 0

        z.squeeze_()
        f_hat = z.sum(0)
        dis = (f_hat - x).abs()
        pf = snr(f_hat, x)
        ppprint(dis, pf)
Пример #2
0
    def test_dkl(self, p, dt):
        """Smoke-test ``dkl`` on a random 20x10 bipartite graph and print it.

        ``p`` and ``dt`` are passed as the third and fourth positional
        arguments of ``rand_bipartite`` (presumably edge probability and
        dtype).
        """
        delta = 0.1
        graph = rand_bipartite(20, 10, p, dt)
        b_lil = graph.to_scipy("csr").tolil()
        sigma = compute_sigma(b_lil, delta).tocsc()
        result = dkl(b_lil, sigma, delta)
        print(result)
Пример #3
0
    def test_bipartite_fix_th(self):
        """``is_bipartite_fix`` (without fixing) separates a random undirected
        graph from a random bipartite graph, using dense inputs."""
        n_left, n_right = 4, 6
        non_bip = rand_udg(n_left + n_right, 0.6)
        is_bip, _, _ = is_bipartite_fix(non_bip.to_dense(), fix_flag=False)
        # With fix_flag enabled the result would always be bipartite, so the
        # fix is turned off to observe the raw check.
        assert not is_bip

        bip = rand_bipartite(n_left, n_right, 0.6)
        is_bip, _, _ = is_bipartite_fix(bip.to_dense(), fix_flag=False)
        assert is_bip
Пример #4
0
 def test_bipartite(self):
     """Repeatedly check ``is_bipartite_fix`` on random graphs (CSR input).

     Each round draws a random bipartite graph, which must be reported as
     bipartite. It also draws a dense random undirected graph, which must be
     reported as non-bipartite.
     """
     n_sample = 7
     N1, N2 = 4, 6
     N = N1 + N2
     for _ in range(n_sample):
         B = rand_bipartite(N1, N2)
         flag, _, _ = is_bipartite_fix(B.to_scipy("csr"), fix_flag=False)
         assert flag
         # rand_udg(N, 0.8) is a random graph (0.8 is presumably the edge
         # probability), not a complete graph. At this density it is only
         # expected, not guaranteed, to contain an odd cycle and so be
         # non-bipartite.
         NB = rand_udg(N, 0.8)
         flag, _, _ = is_bipartite_fix(NB.to_scipy("csr"), fix_flag=False)
         assert not flag
Пример #5
0
    def test_one_level(self, dtype, device):
        """One-level ``QmfOperator`` round trip on an all-ones signal.

        Transforms and then inverse-transforms ``x``. Prints the SNR (dB) and
        the summed absolute reconstruction error.
        """
        n1, n2 = 6, 4
        graph, partition = rand_bipartite(n1, n2, 0.5, dtype=dtype,
                                          return_partition=True)
        basis = QmfOperator(
            [graph.to_scipy("csr")],
            partition.view(-1, 1),
            order=20,
            device=device,
        )

        x = torch.ones(n1 + n2, 1, dtype=dtype, device=device)
        z = basis.inverse_transform(basis.transform(x))
        snr_db = snr(z.permute(-1, -2), x.permute(-1, -2)).item()
        print("\nsnr: ", snr_db, "dB.")
        print("dis: ", (z - x).abs().sum())
Пример #6
0
    def test_one_level(self, dtype, device):
        """One-level ``BiorthOperator`` round trip on an all-ones signal.

        Prints the SNR (dB) and the summed absolute reconstruction error.
        Then calls ``self.display_density`` on the forward and inverse
        operators.
        """
        n1, n2 = 60, 40
        graph, partition = rand_bipartite(n1, n2, 0.2, dtype=dtype,
                                          return_partition=True)
        basis = BiorthOperator(
            [graph.to_scipy("csr")],
            partition.view(-1, 1),
            k=2,
            device=device,
        )

        x = torch.ones(n1 + n2, 1, dtype=dtype, device=device)
        z = basis.inverse_transform(basis.transform(x))
        snr_db = snr(z.permute(-1, -2), x.permute(-1, -2)).item()
        print("\nsnr: ", snr_db, "dB.")
        print("dis: ", (z - x).abs().sum())
        self.display_density(basis.operator, basis.inv_operator)
Пример #7
0
def test_is_bipartite2(device):
    """A random 4x6 bipartite graph built on ``device`` passes ``is_bipartite``."""
    graph = rand_bipartite(4, 6, device=device)
    result = is_bipartite(graph)
    assert result[0]
Пример #8
0
 def test_rand_bipartite(self, device, dtype, density):
     """Smoke-test that ``rand_bipartite`` runs for the parametrized inputs."""
     sizes = (6, 7)
     rand_bipartite(*sizes, density, dtype, device)