    def __init__(self, alpha, parts, numpy=False, sparse=True):
        """Set up the irrep data for the weak partition ``alpha`` and ``parts``.

        alpha: tuple of ints, the weak partition of 8 into 3 parts
        parts: tuple of ints, the partitions of each part of alpha
        numpy: if True, load the dense pickle at IRREP_LOC_FMT into
            self.yor_dict
        sparse: if True (and numpy is False), load the sparse pickle at
            IRREP_SP_LOC_FMT into self.yor_dict
        If both flags are False (deprecated path), the dense pickle is loaded
        into self.np_yor_dict instead and self.yor_dict stays None.
        """
        self.alpha = alpha
        self.parts = parts
        # coset representatives of S_8 mod the Young subgroup S_alpha
        self.cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
        self.cyc_irrep_func = cyclic_irreps(alpha)
        self.yor_dict = None

        # cache orientation tuple -> cyclic irrep (real / imaginary parts);
        # presumably filled lazily by other methods of this class
        self.cyclic_irreps_re = {}
        self.cyclic_irreps_im = {}

        if numpy:
            pkl_loc = IRREP_LOC_FMT.format(alpha, parts)
            self.yor_dict = load_pkl(pkl_loc)
        elif sparse:
            pkl_loc = IRREP_SP_LOC_FMT.format(alpha, parts)
            self.yor_dict = load_sparse_pkl(pkl_loc)
        else:
            # TODO: deprecate
            print('neither sparse nor numpy')
            pkl_loc = IRREP_LOC_FMT.format(alpha, parts)
            self.np_yor_dict = load_pkl(pkl_loc)
def benchmark(n):
    """Benchmark time/memory usage for generating all yor matrices of S_n.

    For every partition p of n, computes yor(FerrersDiagram(p), perm) for each
    permutation of S_n and pickles the dict ``perm.tup_rep -> yor matrix`` to
    ``/local/hopan/irreps/s_{n}/{p}.pkl``. Partitions whose pickle already
    exists are skipped. Prints per-partition and total elapsed times, plus the
    mean of up to 1000 randomly sampled single-yor timings.
    """
    tstart = time.time()
    _partitions = partitions(n)
    s_n = sn(n)
    save_fmt = '/local/hopan/irreps/s_{}/{}.pkl'
    times = []  # sampled per-permutation yor() timings, capped at 1000
    for idx, p in enumerate(_partitions):
        save_loc = save_fmt.format(n, p)
        # check the same path we write to below (previously hard-coded s_9)
        if os.path.exists(save_loc):
            print('Skipping {}'.format(p))
            continue

        part_start = time.time()
        sdict = {}
        ferrers = FerrersDiagram(p)
        for perm in s_n:
            # separate timer so the per-partition timer is not clobbered
            yor_start = time.time()
            y = yor(ferrers, perm)
            yor_end = time.time()
            if random.random() > 0.1 and len(times) < 1000:
                times.append(yor_end - yor_start)
            sdict[perm.tup_rep] = y

        done = time.time() - part_start
        print('Elapsed: {:.2f}mins | Done {} / {} | Partition: {}'.format(done / 60., idx, len(_partitions), p))

        with open(save_loc, 'wb') as f:
            pickle.dump(sdict, f, protocol=pickle.HIGHEST_PROTOCOL)

    if times:
        print('Mean yor time over {} samples: {:.6f}s'.format(len(times), sum(times) / len(times)))
    tend = time.time() - tstart
    print('Total time compute yor matrices for S_{}: {:3f}'.format(n, tend))
def text_split_transform(fsplit_lst, irrep_dict, alpha, parts, mem_dict=None):
    """Accumulate the transform over a chunk of text split files.

    fsplit_lst: list of split file names of the distance values for a chunk of
        the total distance values; each line is parsed by clean_line into
        (orientation tuple, permutation tuple, distance)
    irrep_dict: irrep dict mapping permutation tuple -> dict of (i, j) -> matrix
    alpha: weak partition
    parts: list/iterable of partitions of the parts of alpha
    mem_dict: optional dict; when given, mem_dict[pid] is updated after each
        file with the max memory usage seen so far by this process
    Returns the matrix produced by convert_yor_matrix from the accumulated blocks.
    """
    print('     Computing transform on splits: {}'.format(fsplit_lst))
    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    save_dict = {}
    cyc_irrep_func = cyclic_irreps(alpha)
    pid = os.getpid()

    for split_f in fsplit_lst:
        with open(split_f, 'r') as f:
            for line in tqdm(f):
                otup, perm_tup, dist = clean_line(line)
                perm_rep = irrep_dict[perm_tup]  # perm_rep is a dict of (i, j) -> matrix
                block_cyclic_rep = block_cyclic_irreps(otup, cos_reps, cyc_irrep_func)
                # presumably accumulates the weighted blocks into save_dict in place
                mult_yor_block(perm_rep, dist, block_cyclic_rep, save_dict)

        if mem_dict is not None:
            mem_dict[pid] = max(check_memory(verbose=False), mem_dict.get(pid, 0))

    block_size = wreath_dim(parts)
    n_cosets = coset_size(alpha)
    return convert_yor_matrix(save_dict, block_size, n_cosets)
def wreath_yor(alpha, _parts, prefix='/local/hopan/'):
    """Build the block representation of S_n induced from the Young subgroup.

    alpha: weak partition of n = sum(alpha) (8 in the original use case)
    _parts: list of partitions of each part of alpha, e.g. alpha = (4, 4),
        _parts = [(2, 2), (3, 1)]
    prefix: root directory passed to perm2.sn; the Young subgroup irreps are
        read from os.path.join(prefix, 'irreps')

    Return a dict mapping group element in S_n (its tup_rep) -> rep.
    The rep is a dictionary of tuples (i, j) -> matrix, where i, j denote
    the (i, j) block in the matrix. Block (i, j) is present only when
    t_i^{-1} g t_j lies in the Young subgroup (t_k the coset reps), and then
    equals the Young subgroup irrep of that element.
    """
    n = sum(alpha)
    _sn = perm2.sn(n, prefix)
    young_sub = young_subgroup_perm(alpha)
    young_sub_set = tup_set(young_sub)
    young_yor = young_subgroup_yor(alpha, _parts, os.path.join(prefix, 'irreps'))
    reps = coset_reps(_sn, young_sub)
    # t_i^{-1} depends only on the coset rep: invert each one once
    rep_invs = [t_i.inv() for t_i in reps]
    rep_dict = {}

    # independent per g, so this loop can be parallelized (see wreath_yor_par)
    for g in _sn:
        g_rep = {}
        for i, ti_inv in enumerate(rep_invs):
            # same left-to-right product as t_i.inv() * g * t_j
            ti_inv_g = ti_inv * g
            for j, t_j in enumerate(reps):
                tiinv_g_tj = ti_inv_g * t_j
                if tiinv_g_tj.tup_rep in young_sub_set:
                    g_rep[(i, j)] = young_yor[tiinv_g_tj.tup_rep]

        rep_dict[g.tup_rep] = g_rep

    return rep_dict
def split_transform(fsplit_lst, irrep_dict, alpha, parts, mem_dict=None):
    """Accumulate the transform over a chunk of pickled split files.

    fsplit_lst: list of pkl file names of the distance values for a chunk of
        the total distance values; each pickle is a dict
        perm_tup -> {orientation tuple -> iterable of distances}
    irrep_dict: irrep dict mapping permutation tuple -> dict of (i, j) -> matrix
    alpha: weak partition
    parts: list/iterable of partitions of the parts of alpha
    mem_dict: optional dict; when given, mem_dict[pid] is updated after each
        file with the max memory usage seen so far by this process
    Returns the matrix produced by convert_yor_matrix from the accumulated blocks.
    """
    print('     Computing transform on splits: {}'.format(fsplit_lst))
    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    save_dict = {}
    cyc_irrep_func = cyclic_irreps(alpha)
    pid = os.getpid()

    for fsplit_pkl in fsplit_lst:
        # load_pkl opens the file itself; no separate (text-mode) open needed
        pkl_dict = load_pkl(fsplit_pkl)
        for perm_tup, tup_dict in pkl_dict.items():
            if not tup_dict:
                continue
            # same irrep for every orientation of this permutation
            perm_rep = irrep_dict[perm_tup]  # perm_rep is a dict of (i, j) -> matrix
            for tup, dists in tup_dict.items():
                dist_tot = sum(dists)
                block_cyclic_rep = block_cyclic_irreps(tup, cos_reps, cyc_irrep_func)
                mult_yor_block(perm_rep, dist_tot, block_cyclic_rep, save_dict)
        if mem_dict is not None:
            mem_dict[pid] = max(check_memory(verbose=False), mem_dict.get(pid, 0))
        # free this split before loading the next one
        del pkl_dict

    block_size = wreath_dim(parts)
    n_cosets = coset_size(alpha)
    return convert_yor_matrix(save_dict, block_size, n_cosets)
def wreath_yor_par(alpha, _parts, prefix='/local/hopan/', par=8):
    """Multiprocess version of wreath_yor.

    alpha: weak partition of n = sum(alpha) (8 in the original use case)
    _parts: list of partitions of each part of alpha, e.g. alpha = (4, 4),
        _parts = [(2, 2), (3, 1)]
    prefix: root directory passed to perm2.sn; Young subgroup irreps are
        read from os.path.join(prefix, 'irreps')
    par: number of pieces S_n is split into by chunk(); one worker process
        (running _proc_yor) is started per piece

    Return a dict mapping group element in S_n (its tup_rep) -> rep, where the
    rep is a dictionary of tuples (i, j) -> matrix for the (i, j) block.
    Results are gathered in a Manager dict and copied to a plain dict before
    the manager is shut down.
    """
    n = sum(alpha)
    _sn = perm2.sn(n, prefix)
    young_sub = young_subgroup_perm(alpha)
    young_sub_set = tup_set(young_sub)
    young_yor = young_subgroup_yor(alpha, _parts, os.path.join(prefix, 'irreps'))
    reps = coset_reps(_sn, young_sub)
    sn_chunks = chunk(_sn, par)

    with Manager() as manager:
        shared_dict = manager.dict()
        procs = [
            Process(target=_proc_yor,
                    args=[perms, young_yor, young_sub_set, reps, shared_dict])
            for perms in sn_chunks
        ]
        for proc in procs:
            proc.start()
        for proc in procs:
            proc.join()
        # one IPC round-trip; the plain dict outlives the manager process
        rep_dict = shared_dict.copy()

    return rep_dict
def par_cube_ift(rank, size, alpha, parts):
    """Evaluate one irrep's inverse-Fourier term on this worker's cube states.

    rank: index of this worker; size: total number of workers. Rank `rank`
        handles rows [chunk_size * rank, chunk_size * (rank + 1)) of the
        dataframe, where chunk_size = len(df) // size.
    alpha, parts: the irrep (weak partition and partitions of its parts)
    Returns an array of length chunk_size (dtype of fhat) whose entries are
    fhat.shape[0] * trace(fhat @ wmat.conj().T) for each row's wreath rep wmat.
    Re-raises any exception from loading the data after logging it.
    """
    start = time.time()
    try:
        df = load_df('/scratch/hopan/cube/')
        irrep_dict = load_irrep('/scratch/hopan/cube/', alpha, parts)
        fhat = np.load('/scratch/hopan/cube/fourier/{}/{}.npy'.format(
            alpha, parts))
    except Exception as e:
        print('rank {} | memory usg: {} | exception {}'.format(
            rank, check_memory(verbose=False), e))
        # nothing below can run without df / irrep_dict / fhat
        raise

    print(
        'Rank {:3d} / {} | load irrep: {:.2f}s | mem: {:.2f}mb | {} {}'.format(
            rank, size,
            time.time() - start, check_memory(verbose=False), alpha, parts))

    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    cyc_irrep_func = cyclic_irreps(alpha)

    # NOTE(review): the trailing len(df) % size rows are not assigned to any rank
    chunk_size = len(df) // size
    start_idx = chunk_size * rank
    mat = np.zeros(chunk_size, dtype=fhat.dtype)
    # trace(rho(ginv) fhat) = trace(fhat rho(ginv)) = vec(fhat.T).dot(vec(rho(ginv)))
    fhat_t_ravel = fhat.T.ravel()
    if rank == 0:
        print(
            'Rank {} | elapsed: {:.2f}s | {:.2f}mb | mat shape: {} | done load | {} {}'.format(
                rank, time.time() - start, check_memory(verbose=False),
                fhat.shape, alpha, parts))

    for idx in range(start_idx, start_idx + chunk_size):
        row = df.loc[idx]
        otup = tuple(int(i) for i in row[0])
        perm_tup = tuple(int(i) for i in row[1])
        wmat = wreath_rep(otup, perm_tup, irrep_dict, cos_reps, cyc_irrep_func)
        # conjugate transpose used as the inverse: assumes the rep is unitary
        wmat_inv = wmat.conj().T
        feval = np.dot(fhat_t_ravel, wmat_inv.ravel())
        mat[idx - start_idx] = fhat.shape[0] * feval

    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done add'.format(
            rank, time.time() - start, check_memory(verbose=False)))

    del irrep_dict
    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done matrix conversion'.format(
            rank, time.time() - start, check_memory(verbose=False)))

    return mat
def par_cube_ft(rank, size, alpha, parts):
    """Accumulate one worker's share of the Fourier transform for one irrep.

    rank: index of this worker; size: total number of workers. Rank `rank`
        processes rows [chunk_size * rank, chunk_size * (rank + 1)) of the
        cube dataframe, where chunk_size = len(df) // size.
    alpha, parts: the irrep (weak partition and partitions of its parts)
    Returns the partial matrix built by convert_yor_matrix from the blocks
    accumulated over this worker's rows.
    Re-raises any exception from loading the data after logging it.
    """
    start = time.time()
    try:
        df = load_df('/scratch/hopan/cube/')
        irrep_dict = load_irrep('/scratch/hopan/cube/', alpha, parts)
    except Exception as e:
        print('rank {} | memory usg: {} | exception {}'.format(rank, check_memory(verbose=False), e))
        # nothing below can run without df / irrep_dict
        raise

    print('Rank {:3d} / {} | load irrep: {:.2f}s | mem: {}mb'.format(rank, size, time.time() - start, check_memory(verbose=False)))

    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    save_dict = {}
    cyc_irrep_func = cyclic_irreps(alpha)

    # NOTE(review): the trailing len(df) % size rows are not assigned to any rank
    chunk_size = len(df) // size
    start_idx = chunk_size * rank
    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done load'.format(rank, time.time() - start, check_memory(verbose=False)))

    for idx in range(start_idx, start_idx + chunk_size):
        row = df.loc[idx]
        otup = tuple(int(i) for i in row[0])
        perm_tup = tuple(int(i) for i in row[1])
        dist = int(row[2])
        perm_rep = irrep_dict[perm_tup]  # perm_rep is a dict of (i, j) -> matrix
        block_cyclic_rep = block_cyclic_irreps(otup, cos_reps, cyc_irrep_func)
        mult_yor_block(perm_rep, dist, block_cyclic_rep, save_dict)

    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done add'.format(rank, time.time() - start, check_memory(verbose=False)))

    del irrep_dict
    block_size = wreath_dim(parts)
    n_cosets = coset_size(alpha)
    mat = convert_yor_matrix(save_dict, block_size, n_cosets)
    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done matrix conversion'.format(rank, time.time() - start, check_memory(verbose=False)))

    return mat
    def __init__(self, alpha, parts, pickledir=None):
        """Set up the induced wreath-product representation for (alpha, parts).

        alpha: weak partition
        parts: partitions of each part of alpha
        pickledir: directory containing <pickledir>/<alpha>/<parts>.pkl, the
            pickled induced permutation irrep dict; required
        Raises ValueError (an Exception subclass) if pickledir is not given.
        """
        self.alpha = alpha
        self.parts = parts

        self.cos_reps = coset_reps(
            sn(8), young_subgroup_perm(alpha))  # num blocks = num cosets
        self.cyc_irrep_func = cyclic_irreps(alpha)

        # cache cyclic irreps
        self.cyc_irreps = self.compute_cyclic_irreps()

        # load induced perm irrep dict
        if pickledir is None:
            raise ValueError(f'Cant load cube irrep {alpha}, {parts}')

        fname = os.path.join(pickledir, str(alpha), str(parts) + '.pkl')
        print('Loading pkl from: {}'.format(fname))
        self.ind_irrep_dict = load_pkl(fname)
        # rep of (1, ..., 8) (presumably the identity) is square with one
        # block row per coset, so its side / num cosets = block size
        self.block_size = self.ind_irrep_dict[
            (1, 2, 3, 4, 5, 6, 7, 8)].shape[0] // len(self.cos_reps)
    def test_wreath_full(self):
        """Check wreath_rep(w1) @ wreath_rep(w2) == wreath_rep(w1 * w2).

        Uses three cube states converted to wreath elements via get_wreath and
        the (2, 3, 3) / ((1, 1), (1, 1, 1), (2, 1)) irrep. Skipped when the
        irrep pickle cannot be loaded.
        """
        o1, p1 = get_wreath('YYRMRMWWRWRYWMYMGGGGBBBB')  # 14
        o2, p2 = get_wreath('YYBWGYRWMRBWMRMGYBRBGGMW')  # 3
        o3, p3 = get_wreath('GGBWMGBBMRYGMRYRYBYRWWWM')  # 4
        w1 = WreathCycSn.from_tup(o1, p1, order=3)
        w2 = WreathCycSn.from_tup(o2, p2, order=3)
        w3 = WreathCycSn.from_tup(o3, p3, order=3)

        prod12 = w1 * w2
        prod13 = w1 * w3
        o12 = prod12.cyc.cyc
        o13 = prod13.cyc.cyc
        perm12 = prod12.perm.tup_rep
        perm13 = prod13.perm.tup_rep

        # load some pickle
        alpha = (2, 3, 3)
        parts = ((1, 1), (1, 1, 1), (2, 1))
        cos_reps = coset_reps(perm2.sn(8), young_subgroup_perm(alpha))
        cyc_irrep_func = cyclic_irreps(alpha)

        start = time.time()
        print('Loading {} | {}'.format(alpha, parts))
        yor_dict = load_irrep('/local/hopan/cube/', alpha, parts)
        if yor_dict is None:
            self.skipTest('Could not load irrep {} | {}'.format(alpha, parts))
        print('Done loading | {:.2f}s'.format(time.time() - start))

        wreath1 = wreath_rep(o1, p1, yor_dict, cos_reps, cyc_irrep_func, alpha)
        wreath2 = wreath_rep(o2, p2, yor_dict, cos_reps, cyc_irrep_func, alpha)
        wreath3 = wreath_rep(o3, p3, yor_dict, cos_reps, cyc_irrep_func, alpha)
        w12 = np.matmul(wreath1, wreath2)
        w13 = np.matmul(wreath1, wreath3)
        wd12 = wreath_rep(o12, perm12, yor_dict, cos_reps, cyc_irrep_func,
                          alpha)
        wd13 = wreath_rep(o13, perm13, yor_dict, cos_reps, cyc_irrep_func,
                          alpha)

        self.assertTrue(np.allclose(w12, wd12))
        self.assertTrue(np.allclose(w13, wd13))
 def test_young_coset(self):
     """Run check_coset on S_7 and the Young subgroup for alpha = (2, 2, 2, 1)."""
     weak_partition = (2, 2, 2, 1)
     full_group = perm2.sn(7)
     young = wreath.young_subgroup_perm(weak_partition)
     self.check_coset(full_group, young)
def write_feval(feval, saveloc, n=9):
    """Write ``feval`` evaluated on every permutation of S_n to a text file.

    feval: callable taking a permutation tuple (``perm.tup_rep``)
    saveloc: output path; one line ``<tup_to_str(perm)>,<feval(perm)>`` per
        permutation, overwriting any existing file
    n: size of the symmetric group; defaults to 9 (previously hard-coded)
    """
    # build the group before opening, so a failure here leaves saveloc untouched
    perms = sn(n)
    with open(saveloc, 'w') as f:
        for p in tqdm(perms):
            ptup = p.tup_rep
            f.write('{},{}\n'.format(tup_to_str(ptup), feval(ptup)))
def test_main(alpha, parts):
    """Compare the dense and sparse wreath reps on the cube data set.

    Computes the ft via the sparse wreath rep and the non-sparse wreath rep
    to double check that the sparse wreath rep is actually correct (prints
    "unequal!" for any row where they differ). Also reports mean per-row
    timings for building each rep, the trace-multiply against fhat, COO
    conversion, torch sparse tensor construction and block_cyclic_irreps.
    Only the first ~200 rows of the dataframe are processed.
    """
    _start = time.time()
    st = time.time()
    sp_irrep_dict = load_pkl(
        '/scratch/hopan/cube/pickles_sparse/{}/{}.pkl'.format(alpha, parts))
    print('Loading sparse irrep dict: {:.2f}s'.format(time.time() - st))

    st = time.time()
    irrep_dict = load_irrep('/scratch/hopan/cube/', alpha, parts)
    print('Loading irrep dict: {:.2f}s'.format(time.time() - st))

    st = time.time()
    df = load_df('/scratch/hopan/cube/')
    fhat = np.load('/scratch/hopan/cube/fourier/{}/{}.npy'.format(
        alpha, parts))
    print('Loading df: {:.2f}s'.format(time.time() - st))

    cyc_irrep_func = cyclic_irreps(alpha)
    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    st = time.time()
    cyc_irrs = all_cyc_irreps(cos_reps, cyc_irrep_func)
    print('Time to compute all cyc irreps: {:.5f}s'.format(time.time() - st))

    times = []          # dense wreath rep construction
    mult_times = []     # dense trace-multiply
    sp_times = []       # sparse wreath rep construction
    sp_mult_times = []  # sparse trace-multiply
    coo_times = []
    th_sp_times = []
    z3_irreps = []
    results = np.zeros(len(df), dtype=np.complex128)
    sp_results = np.zeros(len(df), dtype=np.complex128)
    fhat_t_ravel = fhat.T.ravel()
    loop_start = time.time()
    for idx in range(len(df)):
        row = df.loc[idx]
        otup = tuple(int(i) for i in row[0])
        perm_tup = tuple(int(i) for i in row[1])

        # compute wreath rep
        st = time.time()
        wmat = wreath_rep(otup, perm_tup, irrep_dict, cos_reps, cyc_irrep_func)
        times.append(time.time() - st)

        # compute wreath rep multiply: vec(fhat.T) . vec(wmat^H)
        st = time.time()
        wmat_inv = wmat.conj().T
        feval = np.dot(fhat_t_ravel, wmat_inv.ravel())
        mult_times.append(time.time() - st)
        results[idx] = feval

        # compute sparse wreath rep
        st = time.time()
        wmat_sp = wreath_rep_sp(otup, perm_tup, sp_irrep_dict, cos_reps,
                                cyc_irrep_func, cyc_irrs)
        sp_times.append(time.time() - st)

        if not np.allclose(wmat, wmat_sp.todense()):
            print('unequal! | idx = {}'.format(idx))

        # compute sparse wreath rep multiply (elementwise product, then sum)
        st = time.time()
        wmat_inv_sp = wmat_sp.conj().T
        feval_sp = (wmat_inv_sp.multiply(fhat.T)).sum()
        sp_mult_times.append(time.time() - st)
        sp_results[idx] = feval_sp

        st = time.time()
        coo = wmat_sp.tocoo()
        coo_times.append(time.time() - st)

        # torch sparse tensors for the real and imaginary parts;
        # torch.tensor copies, so read-only / strided views are fine
        st = time.time()
        ix = torch.tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
        th_sp_re = torch.sparse.FloatTensor(
            ix, torch.tensor(coo.data.real, dtype=torch.float32),
            torch.Size(coo.shape))
        th_sp_cplx = torch.sparse.FloatTensor(
            ix, torch.tensor(coo.data.imag, dtype=torch.float32),
            torch.Size(coo.shape))
        th_sp_times.append(time.time() - st)

        st = time.time()
        block_cyclic_irreps(otup, cos_reps, cyc_irrep_func)
        z3_irreps.append(time.time() - st)
        if idx > 200:
            break

    print('Normal time: {:.6f}s | Sparse time: {:.6f}s'.format(
        np.mean(times), np.mean(sp_times)))
    print('Mult time:   {:.6f}s | Spmult time: {:.6f}s'.format(
        np.mean(mult_times), np.mean(sp_mult_times)))
    print('To coo time: {:.6f}s | Torchsptime: {:.6f}s'.format(
        np.mean(coo_times), np.mean(th_sp_times)))
    print('irrep time:  {:.6f}s'.format(np.mean(z3_irreps)))
    print('Loop time: {:.2f}s'.format(time.time() - loop_start))
    print('Total time: {:.2f}s'.format(time.time() - _start))
    # NOTE(review): this is a fragment of a coset-representative search. It
    # reads g_map, to_visit, H, G and reps and returns reps, but the enclosing
    # `def` and the lines initialising those names are not visible here. The
    # bodies of the `for _gh in left_coset(...)` loop and of the final `if`
    # also appear to be missing — restore them before relying on this code.
    # Walk every group element (g_map presumably maps tup_rep -> element).
    for g_tup in g_map.keys():
        # grab a group element, hit each H element
        # g * h for h \in H
        for h in H:
            # g does not depend on h; it could be looked up once per g_tup
            g = g_map[g_tup]
            gh = g * h
            if gh.tup_rep in to_visit:
                # then all of these are good and this is a coset rep
                for _gh in left_coset(g, H):
                # this is a repeat and we can stop
        # continue
        # presumably stops once len(G) / len(H) reps (the number of left
        # cosets of H in G) have been found — body not visible
        if len(reps) == len(G) / len(H):

    return reps

if __name__ == '__main__':
    # Small sanity check: list each left coset of the Young subgroup for
    # alpha = (2, 2) in S_4, printed as its representative and its elements.
    G = perm2.sn(4)
    subgroup = young_subgroup_perm((2, 2))
    for rep in coset_reps(G, subgroup):
        print('Coset rep: {}'.format(rep.tup_rep))
        for g in left_coset(rep, subgroup):
            print('    {}'.format(g.tup_rep))