def test_ScalingNshifting_saveload():
    """Save/load round trip for two independently drawn ScalingNshifting flows.

    Both flows get their own Gaussian prior and their own random integer
    scale (1..10) and shift (0..10); the pair is then handed to ``saveload``.
    """

    def randomFlow():
        # Keep the original call order: prior first, then scale, then shift.
        prior = source.Gaussian([2, 2])
        scale = float(np.random.randint(1, 11))
        shift = float(np.random.randint(0, 11))
        layer = flow.ScalingNshifting(scale, shift)
        layer.prior = prior
        return layer

    f = randomFlow()
    blankf = randomFlow()

    saveload(f, blankf)
def test_grad():
    """Gradients through DiscreteNICE must depend on the rounding function.

    The same flow is evaluated once with ``utils.roundingWidentityGradient``
    and once (on a deep copy) with ``torch.round``; the gradient of the last
    parameter must not be identical between the two.
    """
    decimal = flow.ScalingNshifting(256, -128)
    p = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal, utils.roundingWidentityGradient)

    halfCount = 3 * 32 * 16
    masks = []
    mask = None
    for layerNo in range(4):
        if layerNo % 2 != 0:
            # Odd layers use the complement of the preceding mask.
            mask = 1 - mask
        else:
            # Even layers draw a fresh random half-zeros / half-ones mask.
            halves = torch.cat([torch.zeros(halfCount), torch.ones(halfCount)])
            shuffled = halves[torch.randperm(3 * 32 * 32)]
            mask = shuffled.reshape(1, 3, 32, 32)
        masks.append(mask)
    maskList = torch.cat(masks, 0).to(torch.float32)

    tList = []
    for _ in range(4):
        tList.append(utils.SimpleMLPreshape([halfCount, 200, 500, halfCount], [nn.ELU(), nn.ELU(), None]))
    f = flow.DiscreteNICE(maskList, tList, decimal, utils.roundingWidentityGradient, p)

    # Same weights, but a different rounding function.
    fcopy = deepcopy(f)
    fcopy.rounding = torch.round

    field = p.sample(100).detach()
    cfield = deepcopy(field).requires_grad_()
    field.requires_grad_()

    xfield, _ = f.inverse(field)
    xcfield, _ = fcopy.inverse(cfield)
    loss = xfield.sum()
    lossCopy = xcfield.sum()
    loss.backward()
    lossCopy.backward()

    lastGrad = list(f.parameters())[-1].grad.detach().numpy()
    lastGradCopy = list(fcopy.parameters())[-1].grad.detach().numpy()
    assert not np.all(lastGrad == lastGradCopy)
def test_saveload():
    """Save/load round trip for two independently initialised DiscreteNICE flows."""
    decimal = flow.ScalingNshifting(256, -128)

    def buildFlow():
        # Construction order (prior, masks, networks, flow) matches the
        # original so random-number consumption is unchanged.
        prior = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal, utils.roundingWidentityGradient)

        halfCount = 3 * 32 * 16
        masks = []
        mask = None
        for layerNo in range(4):
            if layerNo % 2 != 0:
                # Odd layers use the complement of the preceding mask.
                mask = 1 - mask
            else:
                halves = torch.cat([torch.zeros(halfCount), torch.ones(halfCount)])
                mask = halves[torch.randperm(3 * 32 * 32)].reshape(1, 3, 32, 32)
            masks.append(mask)
        maskList = torch.cat(masks, 0).to(torch.float32)

        tList = []
        for _ in range(4):
            tList.append(utils.SimpleMLPreshape([halfCount, 200, 500, halfCount], [nn.ELU(), nn.ELU(), None]))
        return flow.DiscreteNICE(maskList, tList, decimal, utils.roundingWidentityGradient, prior)

    f = buildFlow()
    blankf = buildFlow()

    saveload(f, blankf)
def test_ScalingNshifting():
    """Run the ``bijective`` check on a ScalingNshifting flow with random parameters."""
    prior = source.Gaussian([2, 2])
    # Integer-valued scale in 1..10 and shift in 0..10, passed as floats.
    scale = float(np.random.randint(1, 11))
    shift = float(np.random.randint(0, 11))

    layer = flow.ScalingNshifting(scale, shift)
    layer.prior = prior
    bijective(layer)
def test_bijective():
    """SimpleMERA must reconstruct its input via inverse followed by forward.

    The check is done twice with the same networks: once with ``depth=None``
    and once with ``depth=2``.
    """
    decimal = flow.ScalingNshifting(256, -128)

    def couplingNet():
        # 9 -> 9 -> 9 -> 3 channel convolution stack.
        return torch.nn.Sequential(
            torch.nn.Conv2d(9, 9, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 9, 1, padding=0),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 3, 3, padding=1),
        )

    def featureNet():
        # 3 -> 9 -> 9 channel convolution stack ending in a ReLU.
        return torch.nn.Sequential(
            torch.nn.Conv2d(3, 9, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(9, 9, 1, padding=0),
            torch.nn.ReLU(inplace=True),
        )

    layerList = [couplingNet() for _ in range(4 * 2)]
    meanNNlist = [featureNet()]
    scaleNNlist = [featureNet()]

    def roundTrip(model):
        # NOTE(review): torch.randint's upper bound is exclusive, so the
        # value 255 is never drawn here.
        batch = torch.randint(0, 255, (100, 3, 8, 8)).float()
        latent, _ = model.inverse(batch)
        rebuilt, _ = model.forward(latent)
        return batch, rebuilt

    fullModel = flow.SimpleMERA(8, layerList, meanNNlist, scaleNNlist, 2, None,
                                5, decimal, utils.roundingWidentityGradient)
    samples, rcnSamples = roundTrip(fullModel)
    # The log-probability is evaluated only to exercise the code path.
    fullModel.logProbability(samples)
    assert_allclose(samples.detach().numpy(), rcnSamples.detach().numpy())

    # Test the depth argument (the log-probability call stays disabled here).
    shallowModel = flow.SimpleMERA(8, layerList, meanNNlist, scaleNNlist, 2, 2,
                                   5, decimal, utils.roundingWidentityGradient)
    samples, rcnSamples = roundTrip(shallowModel)
    assert_allclose(samples.detach().numpy(), rcnSamples.detach().numpy())
def test_integer():
    """Samples, latents and reconstructions of DiscreteNICE must be integer valued."""

    def assertIntegerValued(tensor):
        # Every entry must have a zero remainder modulo 1.
        remainder = np.mod(tensor.detach().numpy(), 1)
        assert np.all(np.equal(remainder, 0))

    decimal = flow.ScalingNshifting(256, -128)
    p = source.MixtureDiscreteLogistic([3, 32, 32], 5, decimal, utils.roundingWidentityGradient)

    halfCount = 3 * 32 * 16
    masks = []
    mask = None
    for layerNo in range(4):
        if layerNo % 2 != 0:
            # Odd layers use the complement of the preceding mask.
            mask = 1 - mask
        else:
            halves = torch.cat([torch.zeros(halfCount), torch.ones(halfCount)])
            mask = halves[torch.randperm(3 * 32 * 32)].reshape(1, 3, 32, 32)
        masks.append(mask)
    maskList = torch.cat(masks, 0).to(torch.float32)

    tList = []
    for _ in range(4):
        tList.append(utils.SimpleMLPreshape([halfCount, 200, 500, halfCount], [nn.ELU(), nn.ELU(), None]))
    f = flow.DiscreteNICE(maskList, tList, decimal, utils.roundingWidentityGradient, p)

    x, _ = f.sample(100)
    assertIntegerValued(x)

    zx, _ = f.inverse(x)
    assertIntegerValued(zx)

    xzx, _ = f.forward(zx)
    assertIntegerValued(xzx)
def test_grad():
    """Every parameter of the priors inside a HierarchyPrior must receive a gradient.

    The mean log-probability of 100 samples is back-propagated and the summed
    gradient of each prior's ``mean`` and ``logscale`` (and the mixture's
    ``mixing``) is required to be non-zero.
    """
    length = 8
    channel = 3
    decimal = flow.ScalingNshifting(256, -128)
    rounding = utils.roundingWidentityGradient

    p1 = source.DiscreteLogistic([channel, 16, 3], decimal, rounding=rounding)
    p2 = source.DiscreteLogistic([channel, 4, 3], decimal, rounding=rounding)
    p3 = source.MixtureDiscreteLogistic([channel, 1, 4], 5, decimal, rounding=rounding)
    priors = (p1, p2, p3)

    P = source.HierarchyPrior(channel, length, list(priors), repeat=2)

    x = P.sample(100)
    loss = P.logProbability(x).mean()
    loss.backward()

    # Same order as before: all means first, then all log-scales.
    for attribute in ("mean", "logscale"):
        for prior in priors:
            assert getattr(prior, attribute).grad.sum().detach().item() != 0

    assert p3.mixing.grad.sum().detach().item() != 0
def test_saveload():
    """Check that OneToTwoMERA state survives a save/load round trip.

    Two independently initialised flows are built. The state returned by
    ``t.save()`` is written with ``torch.save``, read back with ``torch.load``
    and passed to ``tt.load``. Afterwards both flows must map the same samples
    to the same latents and reconstruct the same outputs.

    The checkpoint is written inside a temporary directory so the test does
    not leave a ``testsaving.saving`` file in the working directory and cannot
    collide with other tests that use the same file name.
    """
    import os
    import tempfile

    shapeList2D = [3] + [12] * (1 + 1) + [3 * 3]
    shapeList1D = [3] + [12] * (1 + 1) + [3]

    def buildLayers2D(shapeList):
        # Returns a plain list of layers: 3x3 convolutions for the first and
        # last layer, 1x1 convolutions in between, ReLU after all but the last.
        layers = []
        for no, chn in enumerate(shapeList[:-1]):
            if no != 0 and no != len(shapeList) - 2:
                layers.append(torch.nn.Conv2d(chn, shapeList[no + 1], 1))
            else:
                layers.append(
                    torch.nn.Conv2d(chn, shapeList[no + 1], 3, padding=1))
            if no != len(shapeList) - 2:
                layers.append(torch.nn.ReLU(inplace=True))
        return layers

    def buildLayers1D(shapeList):
        # Same layout as buildLayers2D but with Conv1d, wrapped in Sequential.
        # The last layer keeps its default initialisation.
        layers = []
        for no, chn in enumerate(shapeList[:-1]):
            if no != 0 and no != len(shapeList) - 2:
                layers.append(torch.nn.Conv1d(chn, shapeList[no + 1], 1))
            else:
                layers.append(
                    torch.nn.Conv1d(chn, shapeList[no + 1], 3, padding=1))
            if no != len(shapeList) - 2:
                layers.append(torch.nn.ReLU(inplace=True))
        return torch.nn.Sequential(*layers)

    def buildFlow():
        # Construction order (decimal, 1D layers, mean net, scale net, flow)
        # is kept so each flow consumes random numbers exactly as before.
        decimal = flow.ScalingNshifting(256, 0)
        layerList = [buildLayers1D(shapeList1D) for _ in range(2 * 2)]
        meanNNlist = [torch.nn.Sequential(*buildLayers2D(shapeList2D))]
        scaleNNlist = [torch.nn.Sequential(*buildLayers2D(shapeList2D))]
        return flow.OneToTwoMERA(8,
                                 layerList,
                                 meanNNlist,
                                 scaleNNlist,
                                 2,
                                 None,
                                 5,
                                 decimal=decimal,
                                 rounding=utils.roundingWidentityGradient)

    t = buildFlow()
    tt = buildFlow()

    samples = torch.randint(0, 255, (100, 3, 8, 8)).float()

    # Transfer the state of ``t`` into ``tt`` through a file on disk.
    with tempfile.TemporaryDirectory() as tmpDir:
        checkpoint = os.path.join(tmpDir, "testsaving.saving")
        torch.save(t.save(), checkpoint)
        tt.load(torch.load(checkpoint))

    tzSamples, _ = t.inverse(samples)
    ttzSamples, _ = tt.inverse(samples)

    rcnSamples, _ = t.forward(tzSamples)
    ttrcnSamples, _ = tt.forward(ttzSamples)

    assert_allclose(tzSamples.detach().numpy(), ttzSamples.detach().numpy())
    assert_allclose(samples.detach().numpy(), rcnSamples.detach().numpy())
    assert_allclose(rcnSamples.detach().numpy(), ttrcnSamples.detach().numpy())
def test_wavelet():
    """Compare ``OneToTwoMERA.inverse`` with a hand-computed reference transform.

    A reference is built from the test image with explicit matrix products
    (``buildTransMatrix``) and 2x2 regrouping (``im2grp`` / ``grp2im``), first
    over ``int(log2(width))`` levels and then over two levels. A
    ``flow.OneToTwoMERA`` whose first two layers come from
    ``buildWaveletLayers``, whose remaining networks end in a zero-initialised
    layer, and whose rounding is the identity must reproduce each reference.
    """

    # NOTE: back01, batchNorm and renormFn are referenced only from the
    # disabled plotting code (the string literal further down) in this
    # function.
    def back01(tensor):
        # Rescale every (batch, channel) slice: subtract its minimum, then
        # divide by its (shifted) maximum.
        ten = tensor.clone().float()
        ten = ten.view(ten.shape[0] * ten.shape[1], -1)
        ten -= ten.min(1, keepdim=True)[0]
        ten /= ten.max(1, keepdim=True)[0]
        ten = ten.view(tensor.shape)
        return ten

    # Alternative renormalisation: non-affine BatchNorm2d plus a constant offset.
    def batchNorm(tensor, base=1.0):
        m = nn.BatchNorm2d(tensor.shape[1], affine=False)
        return m(tensor).float() + base

    renormFn = lambda x: back01(batchNorm(x))

    def im2grp(t):
        # Split the last two axes into 2x2 blocks and return a tensor of shape
        # (batch, channel, number of blocks, 4). Within a block the four
        # entries are ordered top-left, top-right, bottom-left, bottom-right.
        return t.reshape(t.shape[0], t.shape[1], t.shape[2] // 2, 2,
                         t.shape[3] // 2,
                         2).permute([0, 1, 2, 4, 3,
                                     5]).reshape(t.shape[0], t.shape[1], -1, 4)

    def grp2im(t):
        # Reverse of im2grp; the number of blocks is treated as a perfect
        # square (its integer square root is used for both block-grid axes).
        return t.reshape(t.shape[0], t.shape[1], int(t.shape[2]**0.5),
                         int(t.shape[2]**0.5), 2,
                         2).permute([0, 1, 2, 4, 3,
                                     5]).reshape(t.shape[0], t.shape[1],
                                                 int(t.shape[2]**0.5) * 2,
                                                 int(t.shape[2]**0.5) * 2)

    decimal = flow.ScalingNshifting(256, 0)

    # Identity in place of a rounding function, so the flow does no rounding.
    psudoRounding = torch.nn.Identity()

    # Load the image and bring it to (1, channel, height, width) float layout.
    IMG = Image.open('./etc/lena512color.tiff')
    IMG = torch.from_numpy(np.array(IMG)).permute([2, 0, 1])
    IMG = IMG.reshape(1, *IMG.shape).float()

    v = IMG

    def buildTransMatrix(n):
        # Build an n x n block-diagonal matrix with the 2x2 core
        # [[0.5, 0.5], [-1, 1]] repeated along the diagonal: even rows average
        # a pair of neighbouring entries, odd rows take their difference.
        # (Worked out for even n.)
        core = torch.tensor([[0.5, 0.5], [-1, 1]])
        gap = torch.zeros(2, n)
        return torch.cat([core if i % 2 == 0 else gap for i in range(n - 1)],
                         -1).reshape(2, n // 2, n).permute([1, 0,
                                                            2]).reshape(n, n)

    # Number of levels for the full-depth reference.
    depth = int(math.log(v.shape[-1], 2))
    up = v

    blockSize = v.shape[-1]

    # Per-level top-right, bottom-left and bottom-right block entries.
    UR = []
    DL = []
    DR = []

    for i in range(depth):
        transMatrix = buildTransMatrix(blockSize)
        # Transform the last axis, swap the last two axes, and repeat: both
        # spatial axes are transformed and the axis order is restored.
        for _ in range(2):
            up = torch.matmul(up, transMatrix.t())
            up = up.permute([0, 1, 3, 2])
        blockSize //= 2

        # Separate the four entries of every 2x2 block into their own images.
        _x = im2grp(up)
        ul = _x[:, :, :, 0].reshape(*_x.shape[:2], int(_x.shape[2]**0.5),
                                    int(_x.shape[2]**0.5)).contiguous()
        ur = _x[:, :, :, 1].reshape(*_x.shape[:2], int(_x.shape[2]**0.5),
                                    int(_x.shape[2]**0.5)).contiguous()
        dl = _x[:, :, :, 2].reshape(*_x.shape[:2], int(_x.shape[2]**0.5),
                                    int(_x.shape[2]**0.5)).contiguous()
        dr = _x[:, :, :, 3].reshape(*_x.shape[:2], int(_x.shape[2]**0.5),
                                    int(_x.shape[2]**0.5)).contiguous()

        UR.append(ur)
        DL.append(dl)
        DR.append(dr)

        # Only the top-left entries are transformed again at the next level.
        up = ul

    ul = up

    # Re-interleave the stored entries, from the last level back to the first,
    # so every coefficient returns to its position in the full-size image.
    for no in reversed(range(depth)):
        ur = UR[no].reshape(*ul.shape, 1)
        dl = DL[no].reshape(*ul.shape, 1)
        dr = DR[no].reshape(*ul.shape, 1)
        ul = ul.reshape(*ul.shape, 1)

        _x = torch.cat([ul, ur, dl, dr], -1).reshape(*ul.shape[:2], -1, 4)
        ul = grp2im(_x).contiguous()

    # Reference result for the full-depth comparison.
    transV = ul
    # The string literal below is disabled plotting code; it is never executed.
    '''
    ul = ul
    UR = []
    DL = []
    DR = []
    for _ in range(depth):
        _x = im2grp(ul)
        ul = _x[:, :, :, 0].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        ur = _x[:, :, :, 1].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        dl = _x[:, :, :, 2].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        dr = _x[:, :, :, 3].reshape(*_x.shape[:2], int(_x.shape[2] ** 0.5), int(_x.shape[2] ** 0.5)).contiguous()
        UR.append(renormFn(ur))
        DL.append(renormFn(dl))
        DR.append(renormFn(dr))

    #ul = back01(backMeanStd(batchNorm(ul, 0)))
    ul = renormFn(ul)
    #ul = back01(clip(backMeanStd(batchNorm(ul))))

    for no in reversed(range(depth)):

        ur = UR[no]
        dl = DL[no]
        dr = DR[no]

        upper = torch.cat([ul, ur], -1)
        down = torch.cat([dl, dr], -1)
        ul = torch.cat([upper, down], -2)

    # convert zremaoin to numpy array
    zremain = ul.permute([0, 2, 3, 1]).detach().cpu().numpy()

    waveletPlot = plt.figure(figsize=(8, 8))
    waveletAx = waveletPlot.add_subplot(111)
    waveletAx.imshow(zremain[0])
    plt.axis('off')
    plt.savefig('./testWavelet.pdf', bbox_inches="tight", pad_inches=0)
    plt.close()
    '''

    # Initialisers for the first two layers. harrInitMethod1/2 and
    # buildWaveletLayers are defined outside this function; presumably they
    # set up the wavelet weights - confirm against their definitions.
    initMethods = []
    initMethods.append(lambda: harrInitMethod1(3))
    initMethods.append(lambda: harrInitMethod2(3))

    orders = [True, False]

    layerList = []
    for j in range(2):
        layerList.append(
            buildWaveletLayers(initMethods[j], 3, 12, 1, orders[j]))

    shapeList2D = [3] + [12] * (1 + 1) + [3 * 3]
    shapeList1D = [3] + [12] * (1 + 1) + [3]

    def buildLayers2D(shapeList):
        # Returns a plain list of layers: 3x3 convolutions for the first and
        # last layer, 1x1 convolutions in between, ReLU after all but the last.
        layers = []
        for no, chn in enumerate(shapeList[:-1]):
            if no != 0 and no != len(shapeList) - 2:
                layers.append(torch.nn.Conv2d(chn, shapeList[no + 1], 1))
            else:
                layers.append(
                    torch.nn.Conv2d(chn, shapeList[no + 1], 3, padding=1))
            if no != len(shapeList) - 2:
                layers.append(torch.nn.ReLU(inplace=True))
        return layers

    def buildLayers1D(shapeList):
        # Same layout with Conv1d, wrapped in Sequential; the last layer's
        # weight and bias are set to zero.
        layers = []
        for no, chn in enumerate(shapeList[:-1]):
            if no != 0 and no != len(shapeList) - 2:
                layers.append(torch.nn.Conv1d(chn, shapeList[no + 1], 1))
            else:
                layers.append(
                    torch.nn.Conv1d(chn, shapeList[no + 1], 3, padding=1))
            if no != len(shapeList) - 2:
                layers.append(torch.nn.ReLU(inplace=True))
        layers = torch.nn.Sequential(*layers)
        torch.nn.init.zeros_(layers[-1].weight)
        torch.nn.init.zeros_(layers[-1].bias)
        return layers

    # Append two more layers (built by buildLayers1D) after the two wavelet layers.
    for _ in range(2):
        layerList.append(buildLayers1D(shapeList1D))

    meanNNlist = []
    scaleNNlist = []
    layers = buildLayers2D(shapeList2D)
    meanNNlist.append(torch.nn.Sequential(*layers))
    layers = buildLayers2D(shapeList2D)
    scaleNNlist.append(torch.nn.Sequential(*layers))
    # Zero the last layer of the mean and scale networks as well.
    torch.nn.init.zeros_(meanNNlist[-1][-1].weight)
    torch.nn.init.zeros_(meanNNlist[-1][-1].bias)
    torch.nn.init.zeros_(scaleNNlist[-1][-1].weight)
    torch.nn.init.zeros_(scaleNNlist[-1][-1].bias)

    f = flow.OneToTwoMERA(v.shape[-1],
                          layerList,
                          meanNNlist,
                          scaleNNlist,
                          repeat=2,
                          depth=depth,
                          nMixing=5,
                          decimal=decimal,
                          rounding=psudoRounding.forward)

    vpp = f.inverse(v)[0]

    # The flow's latent must equal the full-depth reference.
    assert_allclose(vpp.detach().numpy(), transV.detach().numpy())

    # Repeat the comparison with the depth limited to two levels.
    vp = IMG

    depth = 2
    up = vp
    blockSize = v.shape[-1]

    UR = []
    DL = []
    DR = []

    for i in range(depth):
        transMatrix = buildTransMatrix(blockSize)
        # Same two-axis transform as in the full-depth reference above.
        for _ in range(2):
            up = torch.matmul(up, transMatrix.t())
            up = up.permute([0, 1, 3, 2])
        blockSize //= 2

        _x = im2grp(up)
        ul = _x[:, :, :, 0].reshape(*_x.shape[:2], int(_x.shape[2]**0.5),
                                    int(_x.shape[2]**0.5)).contiguous()
        ur = _x[:, :, :, 1].reshape(*_x.shape[:2], int(_x.shape[2]**0.5),
                                    int(_x.shape[2]**0.5)).contiguous()
        dl = _x[:, :, :, 2].reshape(*_x.shape[:2], int(_x.shape[2]**0.5),
                                    int(_x.shape[2]**0.5)).contiguous()
        dr = _x[:, :, :, 3].reshape(*_x.shape[:2], int(_x.shape[2]**0.5),
                                    int(_x.shape[2]**0.5)).contiguous()

        UR.append(ur)
        DL.append(dl)
        DR.append(dr)

        up = ul

    ul = up

    # Re-interleave the two stored levels back into a full-size image.
    for no in reversed(range(depth)):
        ur = UR[no].reshape(*ul.shape, 1)
        dl = DL[no].reshape(*ul.shape, 1)
        dr = DR[no].reshape(*ul.shape, 1)
        ul = ul.reshape(*ul.shape, 1)

        _x = torch.cat([ul, ur, dl, dr], -1).reshape(*ul.shape[:2], -1, 4)
        ul = grp2im(_x).contiguous()

    # Reference result for the depth-2 comparison.
    transVp = ul

    # NOTE(review): the first argument is 8 here, whereas the full-depth flow
    # above was given v.shape[-1]; confirm this is intended for a depth-limited
    # flow applied to the full-size image.
    fp = flow.OneToTwoMERA(8,
                           layerList,
                           meanNNlist,
                           scaleNNlist,
                           2,
                           depth,
                           5,
                           decimal=decimal,
                           rounding=psudoRounding.forward)

    vpp = fp.inverse(vp)[0]

    assert_allclose(vpp.detach().numpy(), transVp.detach().numpy())
    '''
Example #10
0
else:
    # Load the saved parameters and decode them into memory
    with open(rootFolder + "/parameter.json", 'r') as f:
        config = json.load(f)
        locals().update(config)

# Building the target dataset
if target == "CIFAR":
    # Define dimensions
    targetSize = [3, 32, 32]
    dimensional = 2
    channel = targetSize[0]
    blockLength = targetSize[-1]

    # Define normalization and decimal
    decimal = flow.ScalingNshifting(256, 0)
    rounding = utils.roundingWidentityGradient

    # Building train & test datasets
    lambd = lambda x: (x * 255).byte().to(torch.float32).to(device)
    trainsetTransform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Lambda(lambd)])
    trainTarget = torchvision.datasets.CIFAR10(root='./data/cifar', train=True, download=True, transform=trainsetTransform)
    testTarget = torchvision.datasets.CIFAR10(root='./data/cifar', train=False, download=True, transform=trainsetTransform)
    targetTrainLoader = torch.utils.data.DataLoader(trainTarget, batch_size=batch, shuffle=True)
    targetTestLoader = torch.utils.data.DataLoader(testTarget, batch_size=batch, shuffle=False)
elif target == "ImageNet32":
    # Define dimensions
    targetSize = [3, 32, 32]
    dimensional = 2
    channel = targetSize[0]
    blockLength = targetSize[-1]
Example #11
0
# Per-image transform applied after ToTensor: scale by 255, truncate to byte,
# convert to float32 and move to `device`.
if HUE:
    lambd = lambda x: (x * 255).byte().to(torch.float32).to(device)
else:
    # Additionally passes the byte-valued image through utils.rgb2ycc
    # (presumably an RGB -> YCbCr colour conversion; confirm in utils).
    lambd = lambda x: utils.rgb2ycc(
        (x * 255).byte().float(), True).to(torch.float32).to(device)

# Building the target dataset
if target == "CIFAR":
    # Define dimensions
    targetSize = [3, 32, 32]
    dimensional = 2
    channel = targetSize[0]
    blockLength = targetSize[-1]

    # Define normalization and decimal
    decimal = flow.ScalingNshifting(256, -128)
    rounding = utils.roundingWidentityGradient

    # Building train & test datasets; both use the same transform and are
    # downloaded into ./data/cifar when missing.
    trainsetTransform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambd)
    ])
    trainTarget = torchvision.datasets.CIFAR10(root='./data/cifar',
                                               train=True,
                                               download=True,
                                               transform=trainsetTransform)
    testTarget = torchvision.datasets.CIFAR10(root='./data/cifar',
                                              train=False,
                                              download=True,
                                              transform=trainsetTransform)
def test_hierarchyPrior():
    """Exercise HierarchyPrior's factor-out bookkeeping and log-probability.

    First, the parts dispatched from a sample with ``utils.dispatch`` must
    rebuild that sample when put back with ``utils.collect``. Second, with
    constant-valued stub priors the sample layout and the total
    log-probability are compared with hand-computed values, for ``repeat=2``
    and ``repeat=1``.
    """
    import math

    class UniTestPrior(source.Source):
        """Stub prior whose samples are constant and whose energy is known."""

        def __init__(self, nvars, element, name="UniTestPrior"):
            super().__init__(nvars, 1.0, name)
            value = torch.tensor(element)
            self.element = torch.nn.Parameter(value, requires_grad=False)

        def sample(self, batchSize):
            # A tensor filled with ``element``.
            ones = torch.ones([batchSize] + self.nvars)
            return ones.to(self.element).float() * self.element

        def _energy(self, z):
            # 2 ** element times the number of entries in the trailing axes.
            trailingCount = np.prod(z.shape[2:])
            return torch.tensor([2])**self.element * trailingCount

    channel = 3
    decimal = flow.ScalingNshifting(256, -128)
    rounding = utils.roundingWidentityGradient

    # --- dispatch / collect round trip on a 32x32 hierarchy -----------------
    length = 32
    priors = [
        source.DiscreteLogistic([channel, blocks, 3], decimal, rounding=rounding)
        for blocks in (256, 64, 16, 4)
    ]
    priors.append(source.MixtureDiscreteLogistic([channel, 1, 4], 5, decimal, rounding=rounding))

    P = source.HierarchyPrior(channel, length, priors, repeat=1)

    x = P.sample(100)
    # Evaluated only to exercise the code path; the value is not checked.
    P.logProbability(x)

    levels = int(math.log(length, 2))

    zparts = []
    for level in range(levels):
        _, parts = utils.dispatch(P.factorOutIList[level], P.factorOutJList[level], x)
        zparts.append(parts)

    rcnX = torch.zeros_like(x)
    for level, part in enumerate(zparts):
        rcnX = utils.collect(P.factorOutIList[level], P.factorOutJList[level], rcnX, part)

    assert_allclose(x.detach(), rcnX.detach())

    # --- constant priors on an 8x8 hierarchy --------------------------------
    length = 8

    def checkConstantPriors(repeat):
        p1 = UniTestPrior([channel, 16, 3], 1)
        p2 = UniTestPrior([channel, 4, 3], 2)
        p3 = UniTestPrior([channel, 1, 4], 3)

        hierarchy = source.HierarchyPrior(channel, length, [p1, p2, p3], repeat=repeat)

        x = hierarchy.sample(1)
        logp = hierarchy.logProbability(x)

        # Expected layout of the constants 1, 2 and 3 in the first channel.
        target = np.array([
            [3, 1, 2, 1, 3, 1, 2, 1],
            [1, 1, 1, 1, 1, 1, 1, 1],
            [2, 1, 2, 1, 2, 1, 2, 1],
            [1, 1, 1, 1, 1, 1, 1, 1],
            [3, 1, 2, 1, 3, 1, 2, 1],
            [1, 1, 1, 1, 1, 1, 1, 1],
            [2, 1, 2, 1, 2, 1, 2, 1],
            [1, 1, 1, 1, 1, 1, 1, 1],
        ])
        assert_allclose(x[0, 0].detach().numpy(), target)
        assert logp == -(16 * 3 * 2**1 + 4 * 3 * 2**2 + 1 * 4 * 2**3)

    checkConstantPriors(2)
    checkConstantPriors(1)
Example #13
0
def test_saveload():
    """Check that SimpleMERA state survives a save/load round trip.

    Two independently initialised flows are built. The state returned by
    ``t.save()`` is written with ``torch.save``, read back with ``torch.load``
    and passed to ``tt.load``. Afterwards both flows must map the same samples
    to the same latents and reconstruct the same outputs.

    The checkpoint is written inside a temporary directory so the test does
    not leave a ``testsaving.saving`` file in the working directory and cannot
    collide with other tests that use the same file name.
    """
    import os
    import tempfile

    def couplingNet():
        # 9 -> 9 -> 9 -> 3 channel convolution stack.
        return torch.nn.Sequential(torch.nn.Conv2d(9, 9, 3, padding=1),
                                   torch.nn.ReLU(inplace=True),
                                   torch.nn.Conv2d(9, 9, 1, padding=0),
                                   torch.nn.ReLU(inplace=True),
                                   torch.nn.Conv2d(9, 3, 3, padding=1))

    def featureNet():
        # 3 -> 9 -> 9 channel convolution stack ending in a ReLU.
        return torch.nn.Sequential(torch.nn.Conv2d(3, 9, 3, padding=1),
                                   torch.nn.ReLU(inplace=True),
                                   torch.nn.Conv2d(9, 9, 1, padding=0),
                                   torch.nn.ReLU(inplace=True))

    def buildFlow():
        # Construction order (decimal, coupling nets, mean net, scale net,
        # flow) is kept so each flow consumes random numbers exactly as before.
        decimal = flow.ScalingNshifting(256, -128)
        layerList = [couplingNet() for _ in range(4)]
        meanNNlist = [featureNet()]
        scaleNNlist = [featureNet()]
        return flow.SimpleMERA(8, layerList, meanNNlist, scaleNNlist, 1, None,
                               5, decimal, utils.roundingWidentityGradient)

    t = buildFlow()
    tt = buildFlow()

    samples = torch.randint(0, 255, (100, 3, 8, 8)).float()

    # Transfer the state of ``t`` into ``tt`` through a file on disk.
    with tempfile.TemporaryDirectory() as tmpDir:
        checkpoint = os.path.join(tmpDir, "testsaving.saving")
        torch.save(t.save(), checkpoint)
        tt.load(torch.load(checkpoint))

    tzSamples, _ = t.inverse(samples)
    ttzSamples, _ = tt.inverse(samples)

    rcnSamples, _ = t.forward(tzSamples)
    ttrcnSamples, _ = tt.forward(ttzSamples)

    assert_allclose(tzSamples.detach().numpy(), ttzSamples.detach().numpy())
    assert_allclose(samples.detach().numpy(), rcnSamples.detach().numpy())
    assert_allclose(rcnSamples.detach().numpy(), ttrcnSamples.detach().numpy())