예제 #1
0
def decompress(quantbits, nz, gpu, state, nblocks):
    """Decode ``nblocks`` 32x32 RGB blocks from an ANS ``state`` using Bit-Swap.

    For every block the top latent is decoded with the standard-logistic
    prior; then, from the top latent layer down to the pixels, the next
    variable is decoded with the generative model and the latent above it is
    encoded back onto the state with the inference model.

    Args:
        quantbits: forwarded to ``discretize`` and to ``ANS`` as the
            ``quantbits`` of the latent variables (pixels always use 8).
        nz: number of latent layers; forwarded to ``Model`` and used as the
            depth of the per-block loop.
        gpu: CUDA device index; a negative value selects the CPU.
        state: ANS state handed to ``ANS.decode`` / ``ANS.encode``
            (presumably the list returned by ``compress`` -- confirm).
        nblocks: number of blocks to decode.

    Returns:
        ``np.uint8`` array of shape ``(nblocks, 32, 32, 3)``. Blocks are
        written back-to-front: the first decoded block lands in the last slot.
    """
    # model and compression params
    zdim = 8 * 16 * 16
    zrange = torch.arange(zdim)
    xdim = 32**2 * 3
    xrange = torch.arange(xdim)
    ansbits = 31  # ANS precision
    dtype = torch.float64  # datatype throughout compression
    device = "cpu" if gpu < 0 else f"cuda:{gpu}"  # gpu

    # set up the different channel dimension
    reswidth = 256

    def bin_masses(cdfs):
        """Convert CDF values at the bin endpoints (one row per dimension)
        into per-bin probabilities; the first and last bins take the tails."""
        lower_tail = cdfs[:, 0].unsqueeze(1)
        upper_tail = 1. - cdfs[:, -1].unsqueeze(1)
        inner = cdfs[:, 1:] - cdfs[:, :-1]
        return torch.cat((lower_tail, inner, upper_tail), dim=1)

    # <=== MODEL ===>
    model = Model(xs=(3, 32, 32),
                  nz=nz,
                  zchannels=8,
                  nprocessing=4,
                  kernel_size=3,
                  resdepth=8,
                  reswidth=reswidth).to(device)
    # NOTE(review): the checkpoint name is fixed to "nz4" regardless of `nz`
    # -- confirm this is intended.
    model.load_state_dict(
        torch.load('model/params/imagenetcrop/nz4',
                   map_location=lambda storage, location: storage))
    model.eval()

    # get discretization bins for latent variables
    zendpoints, zcentres = discretize(nz, quantbits, dtype, device, model,
                                      "imagenetcrop")

    # get discretization bins for discretized logistic
    xbins = ImageBins(dtype, device, xdim)
    xendpoints = xbins.endpoints()
    xcentres = xbins.centres()

    # < ===== COMPRESSION ===>
    # initialize compression
    model.compress()

    # compression experiment params
    blocks = np.zeros((nblocks, 32, 32, 3), dtype=np.uint8)

    # The prior over the top latent depends only on the fixed bins, so its
    # bin masses are computed once instead of once per block.
    prior_pmfs = bin_masses(
        logistic_cdf(zendpoints[-1].t(),
                     torch.zeros(1, device=device, dtype=dtype),
                     torch.ones(1, device=device, dtype=dtype)).t())

    # <===== RECEIVER =====>
    for xi in tqdm(range(nblocks), desc="Decompression"):
        # decode the top latent with the prior
        state, zsymtop = ANS(prior_pmfs, bits=ansbits,
                             quantbits=quantbits).decode(state)

        # < ===== Bit-Swap ====>
        for zi in reversed(range(nz)):
            below_is_latent = zi > 0

            # generative model: distribution of the layer below given z_zi
            z = zcentres[zi, zrange, zsymtop]
            mu, scale = model.generate(zi)(given=z)
            endpoints = zendpoints[zi - 1] if below_is_latent else xendpoints
            pmfs = bin_masses(logistic_cdf(endpoints.t(), mu, scale).t())

            # decode z_{zi-1}, or the pixels x when zi == 0
            state, sym = ANS(
                pmfs, bits=ansbits,
                quantbits=quantbits if below_is_latent else 8).decode(state)

            # inference model: distribution of z_zi given the decoded layer
            if below_is_latent:
                given = zcentres[zi - 1, zrange, sym]
            else:
                given = xcentres[xrange, sym]
            mu, scale = model.infer(zi)(given=given)
            pmfs = bin_masses(logistic_cdf(zendpoints[zi].t(), mu, scale).t())

            # encode z_zi back onto the state
            state = ANS(pmfs, bits=ansbits,
                        quantbits=quantbits).encode(state, zsymtop)

            zsymtop = sym

        # after the loop zsymtop holds the pixel symbols, channel-first;
        # reshape to a 32x32 pixel-block with 3 color channels last
        im = zsymtop.clone().view(3, 32, 32).detach().cpu()
        blocks[blocks.shape[0] - xi - 1] = np.array(
            im, dtype=np.uint8).transpose((1, 2, 0))

    return blocks
예제 #2
0
def compress(quantbits, nz, gpu, blocks):
    """Compress an array of 32x32 RGB blocks onto an ANS state with Bit-Swap.

    The state starts as 10000 random 32-bit words. For every block and every
    latent layer, a latent is decoded from the state with the inference model
    and the layer below (the pixels for the first layer) is encoded with the
    generative model; finally the top latent is encoded with the
    standard-logistic prior.

    Args:
        quantbits: forwarded to ``discretize`` and to ``ANS`` as the
            ``quantbits`` of the latent variables (pixels always use 8).
        nz: number of latent layers; forwarded to ``Model`` and used as the
            depth of the per-block loop.
        gpu: CUDA device index; a negative value selects the CPU.
        blocks: array whose first axis indexes blocks; each ``blocks[i]`` is
            passed to ``Image.fromarray`` and flattened to ``32*32*3`` values.

    Returns:
        The final state list, with the leading words that were never consumed
        by any decode step removed.
    """
    # model and compression params
    zdim = 8 * 16 * 16
    zrange = torch.arange(zdim)
    xdim = 32**2 * 3
    xrange = torch.arange(xdim)
    ansbits = 31  # ANS precision
    dtype = torch.float64  # datatype throughout compression
    device = "cpu" if gpu < 0 else f"cuda:{gpu}"  # gpu

    # set up the different channel dimension
    reswidth = 256

    def bin_masses(cdfs):
        """Convert CDF values at the bin endpoints (one row per dimension)
        into per-bin probabilities; the first and last bins take the tails."""
        lower_tail = cdfs[:, 0].unsqueeze(1)
        upper_tail = 1. - cdfs[:, -1].unsqueeze(1)
        inner = cdfs[:, 1:] - cdfs[:, :-1]
        return torch.cat((lower_tail, inner, upper_tail), dim=1)

    # <=== MODEL ===>
    model = Model(xs=(3, 32, 32),
                  nz=nz,
                  zchannels=8,
                  nprocessing=4,
                  kernel_size=3,
                  resdepth=8,
                  reswidth=reswidth).to(device)
    # NOTE(review): the checkpoint name is fixed to "nz4" regardless of `nz`
    # -- confirm this is intended.
    model.load_state_dict(
        torch.load('model/params/imagenetcrop/nz4',
                   map_location=lambda storage, location: storage))
    model.eval()

    # get discretization bins for latent variables
    zendpoints, zcentres = discretize(nz, quantbits, dtype, device, model,
                                      "imagenetcrop")

    class ToInt:
        """Rescale the output of ``ToTensor`` by 255."""

        def __call__(self, pic):
            return pic * 255

    transform_ops = transforms.Compose([transforms.ToTensor(), ToInt()])

    # get discretization bins for discretized logistic
    xbins = ImageBins(dtype, device, xdim)
    xendpoints = xbins.endpoints()
    xcentres = xbins.centres()

    # compression experiment params
    nblocks = blocks.shape[0]

    # < ===== COMPRESSION ===>
    # initialize compression
    model.compress()
    excess_state_len = 10000
    # fill state list with 'random' bits
    state = [
        int(word) for word in np.random.randint(low=1 << 16,
                                                high=(1 << 32) - 1,
                                                size=excess_state_len,
                                                dtype=np.uint32)
    ]
    state[-1] = state[-1] << 32

    # The prior over the top latent depends only on the fixed bins, so its
    # bin masses are computed once instead of once per block.
    prior_pmfs = bin_masses(
        logistic_cdf(zendpoints[-1].t(),
                     torch.zeros(1, device=device, dtype=dtype),
                     torch.ones(1, device=device, dtype=dtype)).t())

    # <===== SENDER =====>
    for xi in tqdm(range(nblocks), desc="Bit-Swap"):
        x = transform_ops(Image.fromarray(blocks[xi])).to(device).view(xdim)
        xsym = x.long()  # pixel symbols, reused for lookup and encoding

        # < ===== Bit-Swap ====>
        for zi in range(nz):
            above_pixels = zi > 0

            # inference model: distribution of z_zi given the layer below
            if above_pixels:
                given = zcentres[zi - 1, zrange, zsym]
            else:
                given = xcentres[xrange, xsym]
            mu, scale = model.infer(zi)(given=given)
            pmfs = bin_masses(logistic_cdf(zendpoints[zi].t(), mu, scale).t())

            # decode z_zi from the state
            state, zsymtop = ANS(pmfs, bits=ansbits,
                                 quantbits=quantbits).decode(state)

            # remember the shortest the state has ever been, so the words
            # that were never touched can be stripped at the end
            excess_state_len = min(excess_state_len, len(state))

            # generative model: distribution of the layer below given z_zi
            z = zcentres[zi, zrange, zsymtop]
            mu, scale = model.generate(zi)(given=z)
            endpoints = zendpoints[zi - 1] if above_pixels else xendpoints
            pmfs = bin_masses(logistic_cdf(endpoints.t(), mu, scale).t())

            # encode z_{zi-1}, or the pixels x when zi == 0
            state = ANS(pmfs,
                        bits=ansbits,
                        quantbits=(quantbits if above_pixels else 8)).encode(
                            state, zsym if above_pixels else xsym)

            zsym = zsymtop

        # encode the top latent with the prior
        state = ANS(prior_pmfs, bits=ansbits,
                    quantbits=quantbits).encode(state, zsymtop)

    # remove excess streams
    del state[0:excess_state_len - 1]

    return state
예제 #3
0
def compress(quantbits, nz, gpu, blocks):
    """Compress image blocks with Bit-Swap and report the bits per dimension.

    The ANS state starts as 10000 seeded random 32-bit words. For every block
    and every latent layer, a latent is decoded from the state with the
    inference model and the layer below (the pixels for the first layer) is
    encoded with the generative model; finally the top latent is encoded with
    the standard-logistic prior.

    Args:
        quantbits: forwarded to ``discretize`` and to ``ANS`` as the
            ``quantbits`` of the latent variables (pixels always use 8).
        nz: number of latent layers; forwarded to ``Model`` and used as the
            depth of the per-block loop.
        gpu: CUDA device index; a negative value selects the CPU.
        blocks: array of shape ``(nblocks, h, w, c)``; each ``blocks[i]`` is
            passed to ``Image.fromarray`` and flattened to ``32*32*3`` values.

    Returns:
        ``(bitsperdim, state)``. ``bitsperdim`` is 32 times the number of
        state words beyond the length measured right after the very first
        decode (minus one), divided by ``nblocks * h * w * c``.

    Raises:
        ValueError: if ``blocks`` contains no blocks.
    """
    # compression experiment params
    nblocks, h, w, c = blocks.shape
    if nblocks == 0:
        # bits per dimension is undefined without data (previously this
        # surfaced as an UnboundLocalError after the model had been loaded)
        raise ValueError("blocks must contain at least one block")

    # model and compression params
    zdim = 8*16*16
    zrange = torch.arange(zdim)
    xdim = 32**2 * 3
    xrange = torch.arange(xdim)
    ansbits = 31 # ANS precision
    dtype = torch.float64 # datatype throughout compression
    device = "cpu" if gpu < 0 else f"cuda:{gpu}" # gpu

    # set up the different channel dimension
    reswidth = 256

    def bin_masses(cdfs):
        """Convert CDF values at the bin endpoints (one row per dimension)
        into per-bin probabilities; the first and last bins take the tails."""
        lower_tail = cdfs[:, 0].unsqueeze(1)
        upper_tail = 1. - cdfs[:, -1].unsqueeze(1)
        return torch.cat((lower_tail, cdfs[:, 1:] - cdfs[:, :-1], upper_tail), dim=1)

    # seed for replicating experiment and stability
    np.random.seed(100)
    random.seed(50)
    torch.manual_seed(50)
    torch.cuda.manual_seed(50)
    torch.backends.cudnn.deterministic = True
    # NOTE(review): PyTorch's reproducibility notes recommend
    # cudnn.benchmark = False for run-to-run determinism; left unchanged
    # here -- confirm whether True is intentional.
    torch.backends.cudnn.benchmark = True

    # <=== MODEL ===>
    model = Model(xs=(3, 32, 32), nz=nz, zchannels=8, nprocessing=4, kernel_size=3, resdepth=8, reswidth=reswidth).to(device)
    # NOTE(review): the checkpoint name is fixed to "nz4" regardless of `nz`
    # -- confirm this is intended.
    model.load_state_dict(
        torch.load('model/params/imagenetcrop/nz4',
                   map_location=lambda storage, location: storage
                   )
    )
    model.eval()

    # get discretization bins for latent variables
    # NOTE(review): the dataset tag here is "imagenet" while the checkpoint
    # path says "imagenetcrop" -- verify the two are meant to differ.
    zendpoints, zcentres = discretize(nz, quantbits, dtype, device, model, "imagenet")

    # get discretization bins for discretized logistic
    xbins = ImageBins(dtype, device, xdim)
    xendpoints = xbins.endpoints()
    xcentres = xbins.centres()

    # <=== DATA ===>
    class ToInt:
        """Rescale the output of ``ToTensor`` by 255."""

        def __call__(self, pic):
            return pic * 255
    transform_ops = transforms.Compose([transforms.ToTensor(), ToInt()])

    # < ===== COMPRESSION ===>
    # initialize compression
    model.compress()
    # fill state list with 'random' bits
    state = [int(word) for word in np.random.randint(low=1 << 16, high=(1 << 32) - 1, size=10000, dtype=np.uint32)]
    state[-1] = state[-1] << 32
    baseline_len = None  # state length right after the very first decode

    # The prior over the top latent depends only on the fixed bins, so its
    # bin masses are computed once instead of once per block.
    prior_pmfs = bin_masses(
        logistic_cdf(zendpoints[-1].t(),
                     torch.zeros(1, device=device, dtype=dtype),
                     torch.ones(1, device=device, dtype=dtype)).t())

    # <===== SENDER =====>
    for xi in tqdm(range(nblocks), desc="Compression"):
        x = transform_ops(Image.fromarray(blocks[xi])).to(device).view(xdim)
        xsym = x.long()  # pixel symbols, reused for lookup and encoding

        # < ===== Bit-Swap ====>
        for zi in range(nz):
            above_pixels = zi > 0

            # inference model: distribution of z_zi given the layer below
            given = zcentres[zi - 1, zrange, zsym] if above_pixels else xcentres[xrange, xsym]
            mu, scale = model.infer(zi)(given=given)
            pmfs = bin_masses(logistic_cdf(zendpoints[zi].t(), mu, scale).t())

            # decode z_zi from the state
            state, zsymtop = ANS(pmfs, bits=ansbits, quantbits=quantbits).decode(state)

            # record how much of the initial state survives the first decode;
            # only the length is needed for the bit count
            if xi == zi == 0:
                baseline_len = len(state)
                assert baseline_len > 1, "too few initial bits" # otherwise initial state consists of too few bits

            # generative model: distribution of the layer below given z_zi
            z = zcentres[zi, zrange, zsymtop]
            mu, scale = model.generate(zi)(given=z)
            endpoints = zendpoints[zi - 1] if above_pixels else xendpoints
            pmfs = bin_masses(logistic_cdf(endpoints.t(), mu, scale).t())

            # encode z_{zi-1}, or the pixels x when zi == 0
            state = ANS(pmfs, bits=ansbits, quantbits=(quantbits if above_pixels else 8)).encode(state, zsym if above_pixels else xsym)

            zsym = zsymtop

        # encode the top latent with the prior
        state = ANS(prior_pmfs, bits=ansbits, quantbits=quantbits).encode(state, zsymtop)

    # calculating bits: only the final state length matters, so this is done
    # once after the last block
    totalbits = (len(state) - (baseline_len - 1)) * 32

    bitsperdim = totalbits / (nblocks * h * w * c)
    return bitsperdim, state
예제 #4
0
def decompress(quantbits, nz, gpu, state, nblocks=30):
    """Decode 32x32 RGB blocks from an ANS ``state`` using Bit-Swap.

    For every block the top latent is decoded with the standard-logistic
    prior; then, from the top latent layer down to the pixels, the next
    variable is decoded with the generative model and the latent above it is
    encoded back onto the state with the inference model.

    Args:
        quantbits: forwarded to ``discretize`` and to ``ANS`` as the
            ``quantbits`` of the latent variables (pixels always use 8).
        nz: number of latent layers; forwarded to ``Model`` and used as the
            depth of the per-block loop.
        gpu: CUDA device index; a negative value selects the CPU.
        state: ANS state handed to ``ANS.decode`` / ``ANS.encode``
            (presumably the list returned by ``compress`` -- confirm).
        nblocks: number of blocks to decode; defaults to 30, the value that
            used to be hard-coded.

    Returns:
        ``float32`` numpy array of shape ``(nblocks, 32, 32, 3)`` holding the
        decoded pixel symbols, channels last. Blocks are written
        back-to-front: the first decoded block lands in the last slot.
    """
    # model and compression params
    zdim = 8*16*16
    zrange = torch.arange(zdim)
    xdim = 32**2 * 3
    xrange = torch.arange(xdim)
    ansbits = 31 # ANS precision
    dtype = torch.float64 # datatype throughout compression
    device = "cpu" if gpu < 0 else f"cuda:{gpu}" # gpu

    # set up the different channel dimension
    reswidth = 256

    def bin_masses(cdfs):
        """Convert CDF values at the bin endpoints (one row per dimension)
        into per-bin probabilities; the first and last bins take the tails."""
        lower_tail = cdfs[:, 0].unsqueeze(1)
        upper_tail = 1. - cdfs[:, -1].unsqueeze(1)
        return torch.cat((lower_tail, cdfs[:, 1:] - cdfs[:, :-1], upper_tail), dim=1)

    # seed for replicating experiment and stability
    np.random.seed(100)
    random.seed(50)
    torch.manual_seed(50)
    torch.cuda.manual_seed(50)
    torch.backends.cudnn.deterministic = True
    # NOTE(review): PyTorch's reproducibility notes recommend
    # cudnn.benchmark = False for run-to-run determinism; left unchanged
    # here -- confirm whether True is intentional.
    torch.backends.cudnn.benchmark = True

    # <=== MODEL ===>
    model = Model(xs=(3, 32, 32), nz=nz, zchannels=8, nprocessing=4, kernel_size=3, resdepth=8, reswidth=reswidth).to(device)
    # NOTE(review): the checkpoint name is fixed to "nz4" regardless of `nz`
    # -- confirm this is intended.
    model.load_state_dict(
        torch.load('model/params/imagenetcrop/nz4',
                   map_location=lambda storage, location: storage
                   )
    )
    model.eval()

    # get discretization bins for latent variables
    # NOTE(review): the dataset tag here is "imagenet" while the checkpoint
    # path says "imagenetcrop" -- verify the two are meant to differ.
    zendpoints, zcentres = discretize(nz, quantbits, dtype, device, model, "imagenet")

    # get discretization bins for discretized logistic
    xbins = ImageBins(dtype, device, xdim)
    xendpoints = xbins.endpoints()
    xcentres = xbins.centres()

    # < ===== COMPRESSION ===>
    # initialize compression
    model.compress()

    # compression experiment params
    h, w, c = 32, 32, 3
    blocks = torch.zeros(nblocks, h, w, c)

    # The prior over the top latent depends only on the fixed bins, so its
    # bin masses are computed once, before the block loop.
    prior_pmfs = bin_masses(
        logistic_cdf(zendpoints[-1].t(),
                     torch.zeros(1, device=device, dtype=dtype),
                     torch.ones(1, device=device, dtype=dtype)).t())

    # <===== RECEIVER =====>
    for xi in tqdm(range(nblocks), desc="Decompression"):
        # Every block needs its own top latent decoded from the prior.
        # Decoding it only once before the loop left `zsymtop` holding the
        # previous block's pixel symbols (length xdim), which cannot index
        # the latent bins together with `zrange` (length zdim).
        state, zsymtop = ANS(prior_pmfs, bits=ansbits, quantbits=quantbits).decode(state)

        # < ===== Bit-Swap ====>
        for zi in reversed(range(nz)):
            below_is_latent = zi > 0

            # generative model: distribution of the layer below given z_zi
            z = zcentres[zi, zrange, zsymtop]
            mu, scale = model.generate(zi)(given=z)
            endpoints = zendpoints[zi - 1] if below_is_latent else xendpoints
            pmfs = bin_masses(logistic_cdf(endpoints.t(), mu, scale).t())

            # decode z_{zi-1}, or the pixels x when zi == 0
            state, sym = ANS(pmfs, bits=ansbits, quantbits=quantbits if below_is_latent else 8).decode(state)

            # inference model: distribution of z_zi given the decoded layer
            given = zcentres[zi - 1, zrange, sym] if below_is_latent else xcentres[xrange, sym]
            mu, scale = model.infer(zi)(given=given)
            pmfs = bin_masses(logistic_cdf(zendpoints[zi].t(), mu, scale).t())

            # encode z_zi back onto the state
            state = ANS(pmfs, bits=ansbits, quantbits=quantbits).encode(state, zsymtop)

            zsymtop = sym

        # zsymtop now holds the flattened pixel symbols. The encoder flattens
        # the channel-first output of ToTensor, so restore (c, h, w) first and
        # then move channels last. Decoding undoes the encoder's steps in
        # reverse, so blocks come out last-first and are stored back-to-front.
        blocks[nblocks - 1 - xi] = zsymtop.view(c, h, w).permute(1, 2, 0)

    # return as numpy array
    return blocks.cpu().detach().numpy()