def decode_svd(codebook, domain_blocks):
    """Reconstruct range blocks from per-block SVD basis weights.

    A basis is rebuilt from the domain blocks with a randomized SVD. Each
    codebook entry is a weight vector over that basis. The weighted
    combination of basis vectors is reshaped into a square block and stored
    in a fresh range-block partition.

    Parameters
    ----------
    codebook : sequence of 1-D numpy arrays
        One weight vector per range block. ``len(codebook)`` must be a
        perfect square ``N**2``; ``N`` is passed to
        ``utils.init_range_image``.
    domain_blocks : utils.Partition
        Domain blocks used to rebuild the basis. Its ``mode`` is reused for
        the output partition.

    Returns
    -------
    utils.Partition
        Range-block partition with block ``i`` set from ``codebook[i]``.

    Raises
    ------
    Exception
        If ``len(codebook)`` is not a perfect square (same check and message
        as ``decode``).
    """
    import math

    # Validate before the (comparatively expensive) SVD. Use exact integer
    # arithmetic rather than comparing a float sqrt.
    num_codes = len(codebook)
    N = math.isqrt(num_codes)
    if N * N != num_codes:
        raise Exception("Codebook size must correspond to a square range image")

    # Stack every flattened domain block as one column of X (float64).
    num_weights = len(domain_blocks[0].ravel())
    X = np.zeros([num_weights, len(domain_blocks)])
    for i in range(len(domain_blocks)):
        X[:, i] = np.ravel(domain_blocks[i])

    # Columns of u form the basis the codebook weights refer to.
    # NOTE(review): random_state=None means the basis is not seeded; confirm
    # the encoder derives an identical basis.
    u, s, vh = randomized_svd(X,
                              n_components=min(X.shape[1], 1000),
                              n_iter=5,
                              random_state=None)

    img = utils.init_range_image(N, domain_blocks)
    range_blocks = utils.Partition(img, mode=domain_blocks.mode)

    # Blocks are assumed square: num_weights pixels -> sqrt x sqrt.
    nrows = int(np.sqrt(num_weights))
    ncols = nrows
    for ridx, weights in enumerate(codebook):
        range_blocks[ridx] = np.dot(u, weights[:, np.newaxis])\
                             .reshape(nrows, ncols)

    return range_blocks
def decode_nmf(codebook, domain_blocks):
    """Reconstruct range blocks from per-block NMF basis weights.

    An NMF model (16 components) is fitted to the absolute values of the
    stacked domain blocks. Each codebook entry is a weight vector that is
    multiplied with the resulting basis ``u``; the product is reshaped into
    a square block.

    Parameters
    ----------
    codebook : sequence of 1-D numpy arrays
        One weight vector per range block. ``len(codebook)`` must be a
        perfect square ``N**2``; ``N`` is passed to
        ``utils.init_range_image``.
    domain_blocks : utils.Partition
        Domain blocks used to fit the NMF model. Its ``mode`` is reused for
        the output partition.

    Returns
    -------
    utils.Partition
        Range-block partition with block ``i`` set from ``codebook[i]``.

    Raises
    ------
    Exception
        If ``len(codebook)`` is not a perfect square (same check and message
        as ``decode``).
    """
    import math

    # Validate before fitting the model; exact integer check.
    num_codes = len(codebook)
    N = math.isqrt(num_codes)
    if N * N != num_codes:
        raise Exception("Codebook size must correspond to a square range image")

    # Stack every flattened domain block as one column of X (float64).
    num_weights = len(domain_blocks[0].ravel())
    X = np.zeros([num_weights, len(domain_blocks)])
    for i in range(len(domain_blocks)):
        X[:, i] = np.ravel(domain_blocks[i])
    # NMF requires non-negative input.
    X = np.abs(X)

    model = NMF(n_components=16, init='random', random_state=0)
    # Only the fitted components_ are used below; the transformed output is
    # intentionally discarded.
    model.fit_transform(X)
    # NOTE(review): per sklearn, components_ has shape
    # (n_components, n_features) == (16, len(domain_blocks)), so u is at most
    # 16x16 and the reshape below only works when num_weights == 16. The
    # per-pixel basis would be the fit_transform output instead. Do not change
    # this without changing the matching encoder, which must use the
    # identical basis — confirm intent.
    u = model.components_[:, :16]

    img = utils.init_range_image(N, domain_blocks)
    range_blocks = utils.Partition(img, mode=domain_blocks.mode)

    # Blocks are assumed square: num_weights pixels -> sqrt x sqrt.
    nrows = int(np.sqrt(num_weights))
    ncols = nrows
    for ridx, weights in enumerate(codebook):
        range_blocks[ridx] = np.dot(u, weights[:, np.newaxis])\
                             .reshape(nrows, ncols)

    return range_blocks
def decode(codebook, domain_blocks):
    """Decodes the codebook into estimates of the range blocks.

    Each codebook entry ``(didx, permtype, alpha, t0)`` selects domain block
    ``didx``. The entry indexes ``permute(...)`` of that block with
    ``permtype``, then applies the affine intensity map ``alpha * block + t0``
    to produce range block ``ridx``.

    Parameters
    ----------
    codebook : sequence of 4-tuples / rows
        ``(didx, permtype, alpha, t0)`` per range block. Indices may be
        stored as floats (e.g. a float ndarray); they are cast to ``int``.
        ``len(codebook)`` must be a perfect square ``N**2``; ``N`` is passed
        to ``utils.init_range_image``.
    domain_blocks : utils.Partition
        Domain blocks referenced by ``didx``. Its ``mode`` is reused for the
        output partition.

    Returns
    -------
    utils.Partition
        Range-block partition with every block filled in.

    Raises
    ------
    Exception
        If ``len(codebook)`` is not a perfect square.
    """
    import math

    # Exact integer square check; avoids comparing an int with a float sqrt.
    num_codes = len(codebook)
    N = math.isqrt(num_codes)
    if N * N != num_codes:
        raise Exception("Codebook size must correspond to a square range image")

    # create a range block partition to populate
    img = utils.init_range_image(N, domain_blocks)
    range_blocks = utils.Partition(img, mode=domain_blocks.mode)

    for ridx, (didx, permtype, alpha, t0) in enumerate(codebook):
        # Cast both indices: a codebook stored as a float array yields float
        # didx/permtype, which cannot be used as indices directly.
        dref = domain_blocks[int(didx)]
        # NOTE(review): permute() presumably returns the set of block
        # transforms (isometries) indexed by permtype — confirm.
        db = permute(dref)[int(permtype)]
        range_blocks[ridx] = (alpha * db) + t0

    return range_blocks