def __init__(self, alpha, parts, numpy=False, sparse=True):
    '''
    alpha: tuple of ints, the weak partition of 8 into 3 parts
    parts: tuple of tuples, the partitions of each part of alpha
    numpy: bool, load the dense (numpy) cached S_8 mod S_alpha irreps
    sparse: bool, load the sparse cached S_8 mod S_alpha irreps
    '''
    self.alpha = alpha
    self.parts = parts
    self.cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    self.cyc_irrep_func = cyclic_irreps(alpha)
    self.yor_dict = None

    # cache orientation tuple -> cyclic irrep
    self.cyclic_irreps_re = {}
    self.cyclic_irreps_im = {}
    self.fill_cyclic_irreps()

    # load the cached irreps
    if numpy:
        pkl_loc = IRREP_LOC_FMT.format(alpha, parts)
        self.yor_dict = load_pkl(pkl_loc)
    elif sparse:
        pkl_loc = IRREP_SP_LOC_FMT.format(alpha, parts)
        self.yor_dict = load_sparse_pkl(pkl_loc)
    else:
        # TODO: deprecate
        print('neither sparse nor numpy')
        pkl_loc = IRREP_LOC_FMT.format(alpha, parts)
        self.np_yor_dict = load_pkl(pkl_loc)
def text_split_transform(fsplit_lst, irrep_dict, alpha, parts, mem_dict=None):
    '''
    fsplit_lst: list of text split file names holding the distance values
        for a chunk of the total distance values
    irrep_dict: dict mapping permutation tuple -> irrep matrix blocks
    alpha: weak partition
    parts: list/iterable of partitions of the parts of alpha
    mem_dict: optional dict for recording peak memory usage per process
    '''
    print(' Computing transform on splits: {}'.format(fsplit_lst))
    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    save_dict = {}
    cyc_irrep_func = cyclic_irreps(alpha)
    pid = os.getpid()

    for split_f in fsplit_lst:
        with open(split_f, 'r') as f:
            for line in tqdm(f):
                otup, perm_tup, dist = clean_line(line)
                # perm_rep is a dict of (i, j) -> matrix block
                perm_rep = irrep_dict[perm_tup]
                block_cyclic_rep = block_cyclic_irreps(otup, cos_reps, cyc_irrep_func)
                mult_yor_block(perm_rep, dist, block_cyclic_rep, save_dict)

        if mem_dict is not None:
            mem_dict[pid] = max(check_memory(verbose=False), mem_dict.get(pid, 0))

    block_size = wreath_dim(parts)
    n_cosets = coset_size(alpha)
    mat = convert_yor_matrix(save_dict, block_size, n_cosets)
    return mat
def split_transform(fsplit_lst, irrep_dict, alpha, parts, mem_dict=None):
    '''
    fsplit_lst: list of pkl file names holding the distance values
        for a chunk of the total distance values
    irrep_dict: dict mapping permutation tuple -> irrep matrix blocks
    alpha: weak partition
    parts: list/iterable of partitions of the parts of alpha
    mem_dict: optional dict for recording peak memory usage per process
    '''
    print(' Computing transform on splits: {}'.format(fsplit_lst))
    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    save_dict = {}
    cyc_irrep_func = cyclic_irreps(alpha)
    pid = os.getpid()

    for fsplit_pkl in fsplit_lst:
        # dict of function values: perm tuple -> {orientation tuple -> list of dists}
        pkl_dict = load_pkl(fsplit_pkl)
        for perm_tup, tup_dict in pkl_dict.items():
            for tup, dists in tup_dict.items():
                dist_tot = sum(dists)
                # perm_rep is a dict of (i, j) -> matrix block
                perm_rep = irrep_dict[perm_tup]
                block_cyclic_rep = block_cyclic_irreps(tup, cos_reps, cyc_irrep_func)
                mult_yor_block(perm_rep, dist_tot, block_cyclic_rep, save_dict)

        if mem_dict is not None:
            mem_dict[pid] = max(check_memory(verbose=False), mem_dict.get(pid, 0))
        del pkl_dict

    block_size = wreath_dim(parts)
    n_cosets = coset_size(alpha)
    mat = convert_yor_matrix(save_dict, block_size, n_cosets)
    return mat
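# Hypothetical driver sketch (not part of the original module): one way the split
# transforms above might be fanned out over worker processes, with each worker's peak
# memory recorded in a shared mem_dict. The helper name `run_split_transforms`, the
# pool-based parallelism, and the round-robin chunking are assumptions for illustration.
def run_split_transforms(split_files, irrep_dict, alpha, parts, n_workers=4):
    import multiprocessing as mp
    manager = mp.Manager()
    mem_dict = manager.dict()
    # round-robin the split files across workers
    chunks = [split_files[i::n_workers] for i in range(n_workers)]
    with mp.Pool(n_workers) as pool:
        mats = pool.starmap(
            split_transform,
            [(chunk, irrep_dict, alpha, parts, mem_dict) for chunk in chunks])
    # each split covers a disjoint set of group elements, so the partial transforms add up
    return sum(mats), dict(mem_dict)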
def par_cube_ift(rank, size, alpha, parts):
    start = time.time()
    try:
        df = load_df('/scratch/hopan/cube/')
        irrep_dict = load_irrep('/scratch/hopan/cube/', alpha, parts)
        fhat = np.load('/scratch/hopan/cube/fourier/{}/{}.npy'.format(alpha, parts))
    except Exception as e:
        print('rank {} | memory usg: {} | exception {}'.format(
            rank, check_memory(verbose=False), e))

    print('Rank {:3d} / {} | load irrep: {:.2f}s | mem: {:.2f}mb | {} {}'.format(
        rank, size, time.time() - start, check_memory(verbose=False), alpha, parts))
    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    cyc_irrep_func = cyclic_irreps(alpha)

    chunk_size = len(df) // size
    start_idx = chunk_size * rank
    mat = np.zeros(chunk_size, dtype=fhat.dtype)
    fhat_t_ravel = fhat.T.ravel()

    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | fhat shape: {} | done load | {} {}'.format(
            rank, time.time() - start, check_memory(verbose=False), fhat.shape, alpha, parts))

    for idx in range(start_idx, start_idx + chunk_size):
        row = df.loc[idx]
        otup = tuple(int(i) for i in row[0])
        perm_tup = tuple(int(i) for i in row[1])

        # we actually want the rep of the inverse group element
        wmat = wreath_rep(otup, perm_tup, irrep_dict, cos_reps, cyc_irrep_func)
        wmat_inv = wmat.conj().T
        # trace(rho(ginv) fhat) = trace(fhat rho(ginv)) = vec(fhat.T).dot(vec(rho(ginv)))
        feval = np.dot(fhat_t_ravel, wmat_inv.ravel())
        mat[idx - start_idx] = fhat.shape[0] * feval

    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done add'.format(
            rank, time.time() - start, check_memory(verbose=False)))

    del irrep_dict
    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done matrix conversion'.format(
            rank, time.time() - start, check_memory(verbose=False)))

    return mat
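# Quick numerical check (illustration only, not part of the pipeline) of the identity
# used in par_cube_ift: trace(fhat @ rho) == vec(fhat.T) . vec(rho) with row-major
# ravel(). The matrix size and random inputs are arbitrary.
def _check_trace_identity(n=5, seed=0):
    rng = np.random.default_rng(seed)
    fhat = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
    rho = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
    lhs = np.trace(fhat @ rho)
    rhs = np.dot(fhat.T.ravel(), rho.ravel())
    assert np.allclose(lhs, rhs)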
def par_cube_ft(rank, size, alpha, parts):
    start = time.time()
    try:
        df = load_df('/scratch/hopan/cube/')
        irrep_dict = load_irrep('/scratch/hopan/cube/', alpha, parts)
    except Exception as e:
        print('rank {} | memory usg: {} | exception {}'.format(
            rank, check_memory(verbose=False), e))

    print('Rank {:3d} / {} | load irrep: {:.2f}s | mem: {:.2f}mb'.format(
        rank, size, time.time() - start, check_memory(verbose=False)))
    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    save_dict = {}
    cyc_irrep_func = cyclic_irreps(alpha)

    chunk_size = len(df) // size
    start_idx = chunk_size * rank

    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done load'.format(
            rank, time.time() - start, check_memory(verbose=False)))

    for idx in range(start_idx, start_idx + chunk_size):
        row = df.loc[idx]
        otup = tuple(int(i) for i in row[0])
        perm_tup = tuple(int(i) for i in row[1])
        dist = int(row[2])

        # perm_rep is a dict of (i, j) -> matrix block
        perm_rep = irrep_dict[perm_tup]
        block_cyclic_rep = block_cyclic_irreps(otup, cos_reps, cyc_irrep_func)
        mult_yor_block(perm_rep, dist, block_cyclic_rep, save_dict)

    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done add'.format(
            rank, time.time() - start, check_memory(verbose=False)))

    del irrep_dict
    block_size = wreath_dim(parts)
    n_cosets = coset_size(alpha)
    mat = convert_yor_matrix(save_dict, block_size, n_cosets)

    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done matrix conversion'.format(
            rank, time.time() - start, check_memory(verbose=False)))

    return mat
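# Hypothetical driver sketch (not in the original code): par_cube_ft is written to be run
# once per rank, so one way to dispatch it locally is a process pool over all ranks, summing
# the per-rank partial transforms at the end. The driver name and pool-based dispatch
# (rather than MPI) are assumptions for illustration.
def run_par_cube_ft(alpha, parts, size=8):
    import multiprocessing as mp
    with mp.Pool(size) as pool:
        mats = pool.starmap(par_cube_ft,
                            [(rank, size, alpha, parts) for rank in range(size)])
    # each rank processes a disjoint chunk of rows, so the partial transforms add up
    return sum(mats)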
def __init__(self, alpha, parts, pickledir=None):
    self.alpha = alpha
    self.parts = parts
    # num blocks = num cosets
    self.cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    self.cyc_irrep_func = cyclic_irreps(alpha)
    # cache cyclic irreps
    self.cyc_irreps = self.compute_cyclic_irreps()

    # load induced perm irrep dict
    try:
        fname = os.path.join(pickledir, str(alpha), str(parts) + '.pkl')
        print('Loading pkl from: {}'.format(fname))
        self.ind_irrep_dict = load_pkl(fname)
        self.block_size = self.ind_irrep_dict[(1, 2, 3, 4, 5, 6, 7, 8)].shape[0] // len(self.cos_reps)
    except Exception:
        raise Exception(f'Cannot load cube irrep {alpha}, {parts}')
def test_wreath_full(self):
    o1, p1 = get_wreath('YYRMRMWWRWRYWMYMGGGGBBBB')  # 14
    o2, p2 = get_wreath('YYBWGYRWMRBWMRMGYBRBGGMW')  # 3
    o3, p3 = get_wreath('GGBWMGBBMRYGMRYRYBYRWWWM')  # 4
    w1 = WreathCycSn.from_tup(o1, p1, order=3)
    w2 = WreathCycSn.from_tup(o2, p2, order=3)
    w3 = WreathCycSn.from_tup(o3, p3, order=3)

    prod12 = w1 * w2
    prod13 = w1 * w3
    o12 = prod12.cyc.cyc
    o13 = prod13.cyc.cyc
    perm12 = prod12.perm.tup_rep
    perm13 = prod13.perm.tup_rep

    # load the cached irrep dict for this (alpha, parts)
    alpha = (2, 3, 3)
    parts = ((1, 1), (1, 1, 1), (2, 1))
    cos_reps = coset_reps(perm2.sn(8), young_subgroup_perm(alpha))
    cyc_irrep_func = cyclic_irreps(alpha)

    start = time.time()
    print('Loading {} | {}'.format(alpha, parts))
    yor_dict = load_irrep('/local/hopan/cube/', alpha, parts)
    if yor_dict is None:
        exit()
    print('Done loading | {:.2f}s'.format(time.time() - start))

    wreath1 = wreath_rep(o1, p1, yor_dict, cos_reps, cyc_irrep_func, alpha)
    wreath2 = wreath_rep(o2, p2, yor_dict, cos_reps, cyc_irrep_func, alpha)
    wreath3 = wreath_rep(o3, p3, yor_dict, cos_reps, cyc_irrep_func, alpha)
    w12 = np.matmul(wreath1, wreath2)
    w13 = np.matmul(wreath1, wreath3)

    # the rep of a product should equal the product of the reps
    wd12 = wreath_rep(o12, perm12, yor_dict, cos_reps, cyc_irrep_func, alpha)
    wd13 = wreath_rep(o13, perm13, yor_dict, cos_reps, cyc_irrep_func, alpha)
    self.assertTrue(np.allclose(w12, wd12))
    self.assertTrue(np.allclose(w13, wd13))
def test_main(alpha, parts):
    '''
    Computes the Fourier transform evaluations via both the sparse wreath rep and the
    dense wreath rep to double check that the sparse wreath rep is actually correct,
    and benchmarks the two.
    '''
    _start = time.time()
    st = time.time()
    sp_irrep_dict = load_pkl('/scratch/hopan/cube/pickles_sparse/{}/{}.pkl'.format(alpha, parts))
    print('Loading sparse irrep dict: {:.2f}s'.format(time.time() - st))
    check_memory()

    st = time.time()
    irrep_dict = load_irrep('/scratch/hopan/cube/', alpha, parts)
    print('Loading irrep dict: {:.2f}s'.format(time.time() - st))
    check_memory()

    st = time.time()
    df = load_df('/scratch/hopan/cube/')
    fhat = np.load('/scratch/hopan/cube/fourier/{}/{}.npy'.format(alpha, parts))
    print('Loading df: {:.2f}s'.format(time.time() - st))
    check_memory()

    cyc_irrep_func = cyclic_irreps(alpha)
    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))

    st = time.time()
    cyc_irrs = all_cyc_irreps(cos_reps, cyc_irrep_func)
    print('Time to compute all cyc irreps: {:.5f}s'.format(time.time() - st))

    sp_times = []
    sp_mult_times = []
    sp_results = np.zeros(len(df), dtype=np.complex128)
    coo_times = []
    th_sp_times = []
    times = []
    mult_times = []
    z3_irreps = []
    results = np.zeros(len(df), dtype=np.complex128)
    fhat_t_ravel = fhat.T.ravel()

    loop_start = time.time()
    for idx in range(len(df)):
        row = df.loc[idx]
        otup = tuple(int(i) for i in row[0])
        perm_tup = tuple(int(i) for i in row[1])

        # compute dense wreath rep
        st = time.time()
        wmat = wreath_rep(otup, perm_tup, irrep_dict, cos_reps, cyc_irrep_func)
        reg_time = time.time() - st

        # dense wreath rep multiply
        st = time.time()
        wmat_inv = wmat.conj().T
        feval = np.dot(fhat_t_ravel, wmat_inv.ravel())
        reg_mult_time = time.time() - st
        results[idx] = feval

        # compute sparse wreath rep
        st = time.time()
        wmat_sp = wreath_rep_sp(otup, perm_tup, sp_irrep_dict, cos_reps, cyc_irrep_func, cyc_irrs)
        sp_time = time.time() - st
        if not np.allclose(wmat, wmat_sp.todense()):
            print('unequal! | idx = {}'.format(idx))
            pdb.set_trace()

        # sparse wreath rep multiply
        st = time.time()
        wmat_inv_sp = wmat_sp.conj().T
        feval_sp = (wmat_inv_sp.multiply(fhat.T)).sum()
        sp_mult_time = time.time() - st
        sp_results[idx] = feval_sp

        times.append(reg_time)
        sp_times.append(sp_time)
        mult_times.append(reg_mult_time)
        sp_mult_times.append(sp_mult_time)

        # convert to coo and build torch sparse tensors for the real and imaginary parts
        st = time.time()
        coo = wmat_sp.tocoo()
        coo_times.append(time.time() - st)

        st = time.time()
        ix = torch.LongTensor([coo.row, coo.col])
        th_sp_re = torch.sparse.FloatTensor(ix, torch.FloatTensor(coo.data.real), torch.Size(coo.shape))
        th_sp_im = torch.sparse.FloatTensor(ix, torch.FloatTensor(coo.data.imag), torch.Size(coo.shape))
        th_sp_times.append(time.time() - st)

        st = time.time()
        block_scalars = block_cyclic_irreps(otup, cos_reps, cyc_irrep_func)
        z3_irreps.append(time.time() - st)

        if idx > 200:
            break

    print('Normal time: {:.6f}s | Sparse time:  {:.6f}s'.format(np.mean(times), np.mean(sp_times)))
    print('Mult time:   {:.6f}s | Spmult time:  {:.6f}s'.format(np.mean(mult_times), np.mean(sp_mult_times)))
    print('To coo time: {:.6f}s | Torchsp time: {:.6f}s'.format(np.mean(coo_times), np.mean(th_sp_times)))
    print('irrep time:  {:.6f}s'.format(np.mean(z3_irreps)))
    print('Loop time:   {:.2f}s'.format(time.time() - loop_start))
    print('Total time:  {:.2f}s'.format(time.time() - _start))