示例#1
0
文件: multi.py 项目: dt1483/SnFFT
def text_split_transform(fsplit_lst, irrep_dict, alpha, parts, mem_dict=None):
    '''
    Accumulate the transform (via mult_yor_block) over a list of text split files.

    fsplit_lst: list of split file names; each file holds the distance values for a
        chunk of the total distance values, one record per line, which clean_line
        parses into (orientation tuple, permutation tuple, distance)
    irrep_dict: irrep dict keyed by permutation tuple; each value is a dict of
        (i, j) -> matrix
    alpha: weak partition
    parts: list/iterable of partitions of the parts of alpha
    mem_dict: optional dict; if given, after each file the largest
        check_memory() reading seen so far is stored under this process's pid
    Returns: the matrix built by convert_yor_matrix from the accumulated blocks
    '''
    print('     Computing transform on splits: {}'.format(fsplit_lst))
    # These depend only on alpha, so build them once for all files.
    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    save_dict = {}
    cyc_irrep_func = cyclic_irreps(alpha)
    pid = os.getpid()

    for split_f in fsplit_lst:
        with open(split_f, 'r') as f:
            for line in tqdm(f):
                otup, perm_tup, dist = clean_line(line)
                perm_rep = irrep_dict[
                    perm_tup]  # perm_rep is a dict of (i, j) -> matrix
                block_cyclic_rep = block_cyclic_irreps(otup, cos_reps,
                                                       cyc_irrep_func)
                # Accumulates this record's contribution into save_dict in place.
                mult_yor_block(perm_rep, dist, block_cyclic_rep, save_dict)

        if mem_dict is not None:
            # Keep the peak reading for this process across all files.
            mem_dict[pid] = max(check_memory(verbose=False),
                                mem_dict.get(pid, 0))

    block_size = wreath_dim(parts)
    n_cosets = coset_size(alpha)
    mat = convert_yor_matrix(save_dict, block_size, n_cosets)
    return mat
示例#2
0
文件: multi.py 项目: dt1483/SnFFT
def split_transform(fsplit_lst, irrep_dict, alpha, parts, mem_dict=None):
    '''
    Accumulate the transform (via mult_yor_block) over a list of pickled split files.

    fsplit_lst: list of pkl file names of the distance values for a chunk of the
        total distance values; each pickle is a dict
        perm_tup -> {orientation tuple -> iterable of distances}
    irrep_dict: irrep dict keyed by permutation tuple; each value is a dict of
        (i, j) -> matrix
    alpha: weak partition
    parts: list/iterable of partitions of the parts of alpha
    mem_dict: optional dict; if given, after each file the largest
        check_memory() reading seen so far is stored under this process's pid
    Returns: the matrix built by convert_yor_matrix from the accumulated blocks
    '''
    print('     Computing transform on splits: {}'.format(fsplit_lst))
    # These depend only on alpha, so build them once for all files.
    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    save_dict = {}
    cyc_irrep_func = cyclic_irreps(alpha)
    pid = os.getpid()

    for fsplit_pkl in fsplit_lst:
        # load_pkl opens and reads the file itself, so no separate open() is
        # needed here (an unused extra handle would just waste a descriptor).
        pkl_dict = load_pkl(fsplit_pkl)
        for perm_tup, tup_dict in pkl_dict.items():
            if not tup_dict:
                # Nothing to accumulate; also avoids looking up perm_tup needlessly.
                continue
            # The permutation irrep depends only on perm_tup, so fetch it once
            # rather than once per orientation tuple.
            perm_rep = irrep_dict[perm_tup]  # perm_rep is a dict of (i, j) -> matrix
            for tup, dists in tup_dict.items():
                dist_tot = sum(dists)
                block_cyclic_rep = block_cyclic_irreps(
                    tup, cos_reps, cyc_irrep_func)
                # Accumulates this entry's contribution into save_dict in place.
                mult_yor_block(perm_rep, dist_tot, block_cyclic_rep,
                               save_dict)
        if mem_dict is not None:
            # Keep the peak reading for this process across all files.
            mem_dict[pid] = max(check_memory(verbose=False),
                                mem_dict.get(pid, 0))
        # Drop the loaded pickle before reading the next one.
        del pkl_dict

    block_size = wreath_dim(parts)
    n_cosets = coset_size(alpha)
    mat = convert_yor_matrix(save_dict, block_size, n_cosets)
    return mat
示例#3
0
    def compute_cyclic_irreps(self):
        '''
        Return a dict mapping every orientation tuple yielded by
        cube2_orientations() to its block cyclic irrep, computed with this
        object's cos_reps and cyc_irrep_func.
        '''
        return {
            orientation: block_cyclic_irreps(orientation, self.cos_reps,
                                             self.cyc_irrep_func)
            for orientation in cube2_orientations()
        }
示例#4
0
def all_cyc_irreps(cos_reps, cyc_irrep_func):
    '''
    Compute the block cyclic irrep for every valid length-8 orientation tuple.

    An orientation tuple assigns each of 8 positions a value in {0, 1, 2}; only
    tuples whose entries sum to a multiple of 3 are kept.

    cos_reps: coset representatives, passed through to block_cyclic_irreps
    cyc_irrep_func: cyclic irrep function, passed through to block_cyclic_irreps
    Returns: dict mapping each valid orientation tuple to its block cyclic irrep
        (keys are inserted in itertools.product order).
    '''
    return {
        otup: block_cyclic_irreps(otup, cos_reps, cyc_irrep_func)
        for otup in product((0, 1, 2), repeat=8)
        if sum(otup) % 3 == 0
    }
示例#5
0
 def tup_to_irrep_np(self, otup, ptup):
     '''
     Build the numpy representation matrix of an (orientation, permutation) pair.

     otup: tuple of orientations in Z/3Z
     ptup: tuple of permutation of S_8
     Returns: numpy matrix produced by get_mat from ptup, self.yor_dict and the
         block cyclic irreps of otup
     '''
     scalars = block_cyclic_irreps(otup, self.cos_reps, self.cyc_irrep_func)
     return get_mat(ptup, self.yor_dict, scalars)
示例#6
0
 def fill_cyclic_irreps(self):
     '''
     Precompute and cache the block cyclic irrep of every 2-cube orientation.

     For each tuple from cube2_orientations(), the irrep's real and imaginary
     parts are padded with self.block_pad, converted to torch FloatTensors and
     stored in self.cyclic_irreps_re / self.cyclic_irreps_im under that tuple.
     Fetching a 2-cube state representation then becomes two dictionary
     lookups and a sparse elementwise multiplication.
     '''
     def _padded_tensor(part):
         # Pad one (real or imaginary) component and wrap it as a float tensor.
         return torch.FloatTensor(self.block_pad(part))

     for orientation in cube2_orientations():
         irrep = block_cyclic_irreps(orientation, self.cos_reps,
                                     self.cyc_irrep_func)
         self.cyclic_irreps_re[orientation] = _padded_tensor(irrep.real)
         self.cyclic_irreps_im[orientation] = _padded_tensor(irrep.imag)
示例#7
0
def par_cube_ft(rank, size, alpha, parts):
    '''
    Compute this rank's share of the transform over the cube dataframe.

    The rows of the dataframe loaded from /scratch/hopan/cube/ are split into
    `size` contiguous chunks and rank `rank` accumulates the transform (via
    mult_yor_block) over its chunk. The last rank (rank == size - 1) also takes
    the len(df) % size leftover rows, so every row belongs to exactly one of the
    ranks 0..size-1.

    rank: index of this worker, expected in 0..size-1
    size: total number of workers
    alpha: weak partition
    parts: list/iterable of partitions of the parts of alpha
    Returns: the matrix built by convert_yor_matrix for this rank's rows.
    Raises: any exception from load_df / load_irrep, re-raised after it is logged.
    '''
    start = time.time()
    try:
        df = load_df('/scratch/hopan/cube/')
        irrep_dict = load_irrep('/scratch/hopan/cube/', alpha, parts)
    except Exception as e:
        print('rank {} | memory usg: {} | exception {}'.format(rank, check_memory(verbose=False), e))
        # Nothing below can run without df / irrep_dict, so propagate the error
        # instead of failing later with a NameError that hides the real cause.
        raise

    print('Rank {:3d} / {} | load irrep: {:.2f}s | mem: {}mb'.format(rank, size, time.time() - start, check_memory(verbose=False)))

    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    save_dict = {}
    cyc_irrep_func = cyclic_irreps(alpha)

    n_rows = len(df)
    chunk_size = n_rows // size
    start_idx = chunk_size * rank
    # Integer division leaves n_rows % size rows unassigned; the last rank
    # picks them up so no row is silently skipped.
    end_idx = n_rows if rank == size - 1 else start_idx + chunk_size
    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done load'.format(rank, time.time() - start, check_memory(verbose=False)))

    for idx in range(start_idx, end_idx):
        row = df.loc[idx]
        # Each row is read as (orientation, permutation, distance).
        otup = tuple(int(i) for i in row[0])
        perm_tup = tuple(int(i) for i in row[1])
        dist = int(row[2])

        perm_rep = irrep_dict[perm_tup]  # perm_rep is a dict of (i, j) -> matrix
        block_cyclic_rep = block_cyclic_irreps(otup, cos_reps, cyc_irrep_func)
        # Accumulates this row's contribution into save_dict in place.
        mult_yor_block(perm_rep, dist, block_cyclic_rep, save_dict)

    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done add'.format(rank, time.time() - start, check_memory(verbose=False)))

    # Drop this reference to the irrep dict before building the output matrix.
    del irrep_dict
    block_size = wreath_dim(parts)
    n_cosets = coset_size(alpha)
    mat = convert_yor_matrix(save_dict, block_size, n_cosets)
    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done matrix conversion'.format(rank, time.time() - start, check_memory(verbose=False)))

    return mat
示例#8
0
 def tup_to_irrep_th(self, otup, ptup):
     '''
     Build the sparse representation of an (orientation, permutation) pair.

     otup: orientation tuple, passed to block_cyclic_irreps
     ptup: permutation tuple, passed to get_sparse_mat
     Returns: the (real part, imaginary part) pair produced by get_sparse_mat
         from ptup, self.np_yor_dict and the block cyclic irreps of otup
     '''
     scalars = block_cyclic_irreps(otup, self.cos_reps, self.cyc_irrep_func)
     real_part, imag_part = get_sparse_mat(ptup, self.np_yor_dict, scalars)
     return real_part, imag_part
示例#9
0
def test_main(alpha, parts):
    '''
    Cross-check and benchmark the sparse wreath rep against the dense one.

    For the first rows of the cube dataframe (indices 0..201; the loop breaks
    once idx > 200), this:
    - computes the dense wreath rep (wreath_rep) and the sparse wreath rep
      (wreath_rep_sp);
    - drops into pdb if the two are not np.allclose;
    - times building each rep;
    - times evaluating trace(fhat @ wmat^H) with each rep;
    - times converting the sparse rep to COO and building torch sparse tensors
      from it;
    - times block_cyclic_irreps on its own.
    The mean of each timing is printed at the end.

    alpha: weak partition (used to locate the pickle / .npy files on disk)
    parts: partitions of the parts of alpha (also used in the file paths)
    Returns: None; everything is reported via print / check_memory.
    '''
    _start = time.time()
    st = time.time()
    sp_irrep_dict = load_pkl(
        '/scratch/hopan/cube/pickles_sparse/{}/{}.pkl'.format(alpha, parts))
    # NOTE(review): `end` is assigned here but never read before it is
    # reassigned inside the loop below.
    end = time.time()
    print('Loading sparse irrep dict: {:.2f}s'.format(time.time() - st))
    check_memory()

    st = time.time()
    irrep_dict = load_irrep('/scratch/hopan/cube/', alpha, parts)
    print('Loading irrep dict: {:.2f}s'.format(time.time() - st))
    check_memory()

    # Load the dataframe of rows (read below as orientation, permutation, ...)
    # and the precomputed Fourier coefficients fhat for (alpha, parts).
    st = time.time()
    df = load_df('/scratch/hopan/cube/')
    fhat = np.load('/scratch/hopan/cube/fourier/{}/{}.npy'.format(
        alpha, parts))
    print('Loading df: {:.2f}s'.format(time.time() - st))
    check_memory()

    cyc_irrep_func = cyclic_irreps(alpha)
    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    st = time.time()
    # Precomputed cyclic irreps for all orientations; consumed by wreath_rep_sp.
    cyc_irrs = all_cyc_irreps(cos_reps, cyc_irrep_func)
    print('Time to compute all cyc irreps: {:.5f}s'.format(time.time() - st))

    # Per-row timings (seconds) and evaluated values for the sparse path.
    sp_times = []
    sp_mult_times = []
    sp_results = np.zeros(len(df), dtype=np.complex128)

    # Per-row timings (seconds) and evaluated values for the dense path.
    coo_times = []
    th_sp_times = []
    times = []
    mult_times = []
    z3_irreps = []
    # NOTE(review): results / sp_results are filled but never compared or
    # returned in this function.
    results = np.zeros(len(df), dtype=np.complex128)
    # Flatten fhat.T once; reused for every row's dense evaluation.
    fhat_t_ravel = fhat.T.ravel()
    loop_start = time.time()
    for idx in range(len(df)):
        row = df.loc[idx]
        otup = tuple(int(i) for i in row[0])
        perm_tup = tuple(int(i) for i in row[1])

        # compute wreath rep
        st = time.time()
        wmat = wreath_rep(otup, perm_tup, irrep_dict, cos_reps, cyc_irrep_func)
        reg_time = time.time() - st

        # compute wreath rep multiply
        # dot(fhat.T flattened, wmat^H flattened) is the sum of elementwise
        # products, i.e. trace(fhat @ wmat^H).
        st = time.time()
        wmat_inv = wmat.conj().T
        feval = np.dot(fhat_t_ravel, wmat_inv.ravel())
        reg_mult_time = time.time() - st
        results[idx] = feval

        # compute sparse wreath rep
        st = time.time()
        wmat_sp = wreath_rep_sp(otup, perm_tup, sp_irrep_dict, cos_reps,
                                cyc_irrep_func, cyc_irrs)
        sp_time = time.time() - st

        # Correctness check: on mismatch, stop in the debugger (interactive use only).
        if not np.allclose(wmat, wmat_sp.todense()):
            print('unequal! | idx = {}'.format(idx))
            pdb.set_trace()

        # compute sparse wreath rep multiply
        # Same quantity as feval, computed via a sparse elementwise multiply + sum.
        st = time.time()
        wmat_inv_sp = wmat_sp.conj().T
        feval_sp = (wmat_inv_sp.multiply(fhat.T)).sum()
        sp_mult_time = time.time() - st
        sp_results[idx] = feval_sp

        times.append(reg_time)
        sp_times.append(sp_time)
        mult_times.append(reg_mult_time)
        sp_mult_times.append(sp_mult_time)

        # Time converting the sparse rep to COO format.
        st = time.time()
        coo = wmat_sp.tocoo()
        end = time.time()
        coo_times.append(end - st)

        # Time building torch sparse tensors from the COO data. The tensors are
        # built only for timing; th_sp_cplx holds the imaginary part.
        st = time.time()
        ix = torch.LongTensor([coo.row, coo.col])
        th_sp_re = torch.sparse.FloatTensor(ix,
                                            torch.FloatTensor(coo.data.real),
                                            torch.Size(coo.shape))
        th_sp_cplx = torch.sparse.FloatTensor(ix,
                                              torch.FloatTensor(coo.data.imag),
                                              torch.Size(coo.shape))
        end = time.time()
        th_sp_times.append(end - st)

        # Time block_cyclic_irreps alone; block_scalars itself is not used.
        st = time.time()
        block_scalars = block_cyclic_irreps(otup, cos_reps, cyc_irrep_func)
        end = time.time()
        z3_irreps.append(end - st)
        # Limit the benchmark to rows 0..201.
        if idx > 200:
            break

    print('Normal time: {:.6f}s | Sparse time: {:.6f}s'.format(
        np.mean(times), np.mean(sp_times)))
    print('Mult time:   {:.6f}s | Spmult time: {:.6f}s'.format(
        np.mean(mult_times), np.mean(sp_mult_times)))
    print('To coo time: {:.6f}s | Torchsptime: {:.6f}s'.format(
        np.mean(coo_times), np.mean(th_sp_times)))
    print('irrep time:  {:.6f}s'.format(np.mean(z3_irreps)))
    print('Loop time: {:.2f}s'.format(time.time() - loop_start))
    print('Total time: {:.2f}s'.format(time.time() - _start))