def text_split_transform(fsplit_lst, irrep_dict, alpha, parts, mem_dict=None):
    '''
    Accumulate the transform over plain-text split files and return it as a
    matrix.

    Each line of each file is parsed by clean_line into an orientation tuple,
    a permutation tuple and a distance, which are folded into one shared
    accumulator via mult_yor_block.

    fsplit_lst: list of split file names of the distance values for a chunk of
                the total distance values
    irrep_dict: irrep dict, indexed by permutation tuple; each value is a dict
                of (i, j) -> matrix
    alpha: weak partition
    parts: list/iterable of partitions of the parts of alpha
    mem_dict: optional dict; when given, the entry keyed by this process id is
              raised to the largest check_memory() reading seen so far
    Returns: the output of convert_yor_matrix on the accumulated blocks
    '''
    print(' Computing transform on splits: {}'.format(fsplit_lst))
    coset_representatives = coset_reps(sn(8), young_subgroup_perm(alpha))
    accumulated = {}
    cyclic_fn = cyclic_irreps(alpha)
    worker_pid = os.getpid()

    for split_path in fsplit_lst:
        with open(split_path, 'r') as handle:
            for raw_line in tqdm(handle):
                orientation, permutation, distance = clean_line(raw_line)
                yor_blocks = irrep_dict[permutation]
                cyclic_scalars = block_cyclic_irreps(
                    orientation, coset_representatives, cyclic_fn)
                mult_yor_block(yor_blocks, distance, cyclic_scalars, accumulated)

            # NOTE(review): memory is assumed to be sampled once per split
            # file, after its last line -- confirm against the intended layout.
            if mem_dict is None:
                continue
            mem_dict[worker_pid] = max(check_memory(verbose=False),
                                       mem_dict.get(worker_pid, 0))

    return convert_yor_matrix(accumulated, wreath_dim(parts), coset_size(alpha))
def split_transform(fsplit_lst, irrep_dict, alpha, parts, mem_dict=None):
    '''
    Accumulate the transform over pickled split files and return it as a
    matrix.

    Each pickle is loaded with load_pkl and walked as a two-level dict:
    permutation tuple -> {orientation tuple -> iterable of distances}. The
    distances for one (permutation, orientation) pair are summed before being
    folded into the accumulator via mult_yor_block.

    fsplit_lst: list of pkl file names of the distance values for a chunk of
                the total distance values
    irrep_dict: irrep dict, indexed by permutation tuple; each value is a dict
                of (i, j) -> matrix
    alpha: weak partition
    parts: list/iterable of partitions of the parts of alpha
    mem_dict: optional dict; when given, the entry keyed by this process id is
              raised to the largest check_memory() reading seen so far
    Returns: the output of convert_yor_matrix on the accumulated blocks
    '''
    print(' Computing transform on splits: {}'.format(fsplit_lst))
    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    save_dict = {}
    cyc_irrep_func = cyclic_irreps(alpha)
    pid = os.getpid()

    for fsplit_pkl in fsplit_lst:
        # load_pkl is given the file name and does its own reading, so the
        # file is not opened a second time (in text mode) here.
        pkl_dict = load_pkl(fsplit_pkl)  # dict of function values
        for perm_tup, tup_dict in pkl_dict.items():
            # perm_rep is a dict of (i, j) -> matrix; it only depends on the
            # permutation, so look it up once per permutation.
            perm_rep = irrep_dict[perm_tup]
            for tup, dists in tup_dict.items():
                dist_tot = sum(dists)
                block_cyclic_rep = block_cyclic_irreps(tup, cos_reps, cyc_irrep_func)
                mult_yor_block(perm_rep, dist_tot, block_cyclic_rep, save_dict)

        # Sample memory while the pickle is still alive, then drop it before
        # the next file is loaded.
        if mem_dict is not None:
            mem_dict[pid] = max(check_memory(verbose=False), mem_dict.get(pid, 0))
        del pkl_dict

    block_size = wreath_dim(parts)
    n_cosets = coset_size(alpha)
    mat = convert_yor_matrix(save_dict, block_size, n_cosets)
    return mat
def compute_cyclic_irreps(self):
    '''
    Build a lookup of block cyclic irreps keyed by orientation tuple.

    Every tuple produced by cube2_orientations() is passed to
    block_cyclic_irreps together with this object's cos_reps and
    cyc_irrep_func.

    Returns: dict of orientation tuple -> result of block_cyclic_irreps
    '''
    return {
        orientation: block_cyclic_irreps(orientation, self.cos_reps,
                                         self.cyc_irrep_func)
        for orientation in cube2_orientations()
    }
def all_cyc_irreps(cos_reps, cyc_irrep_func):
    '''
    Precompute block_cyclic_irreps for every admissible orientation tuple.

    An orientation tuple is a length-8 tuple over {0, 1, 2}; only those whose
    entries sum to a multiple of 3 are kept.

    cos_reps: coset representatives, forwarded to block_cyclic_irreps
    cyc_irrep_func: cyclic irrep function, forwarded to block_cyclic_irreps
    Returns: dict of orientation tuple -> result of block_cyclic_irreps
    '''
    return {
        otup: block_cyclic_irreps(otup, cos_reps, cyc_irrep_func)
        for otup in product((0, 1, 2), repeat=8)
        if sum(otup) % 3 == 0
    }
def tup_to_irrep_np(self, otup, ptup):
    '''
    Look up the representation matrix for an (orientation, permutation) pair.

    otup: tuple of orientations in Z/3Z
    ptup: tuple of permutation of S_8
    Returns: numpy matrix, as produced by get_mat from this object's yor_dict
             and the block cyclic scalars for otup
    '''
    cyclic_scalars = block_cyclic_irreps(otup, self.cos_reps, self.cyc_irrep_func)
    return get_mat(ptup, self.yor_dict, cyclic_scalars)
def fill_cyclic_irreps(self):
    '''
    Cache the cyclic irrep of every 2-cube orientation as torch tensors.

    For each tuple from cube2_orientations() the block cyclic irrep is
    computed once; its real and imaginary parts are each run through
    self.block_pad, wrapped in a torch.FloatTensor, and stored in
    self.cyclic_irreps_re / self.cyclic_irreps_im under that tuple. With the
    cache filled, fetching a 2-cube state representation becomes two
    dictionary lookups and a sparse elementwise multiplication.
    '''
    for orientation in cube2_orientations():
        scalars = block_cyclic_irreps(orientation, self.cos_reps,
                                      self.cyc_irrep_func)

        padded_real = self.block_pad(scalars.real)
        self.cyclic_irreps_re[orientation] = torch.FloatTensor(padded_real)

        padded_imag = self.block_pad(scalars.imag)
        self.cyclic_irreps_im[orientation] = torch.FloatTensor(padded_imag)
def par_cube_ft(rank, size, alpha, parts):
    '''
    Compute one worker's share of the transform and return it as a matrix.

    The rows of the loaded dataframe are cut into `size` contiguous chunks of
    len(df) // size rows; this call processes chunk number `rank`. The last
    chunk (rank == size - 1) additionally takes the len(df) % size leftover
    rows so that no row is dropped when `size` does not divide len(df).

    rank: index of the chunk handled by this call (0 <= rank < size is
          assumed -- TODO confirm against the caller)
    size: total number of chunks
    alpha: weak partition
    parts: list/iterable of partitions of the parts of alpha
    Returns: the output of convert_yor_matrix on the accumulated blocks
    Raises: whatever load_df / load_irrep raise; the failure is logged with
            the rank and current memory usage and then re-raised
    '''
    start = time.time()
    try:
        df = load_df('/scratch/hopan/cube/')
        irrep_dict = load_irrep('/scratch/hopan/cube/', alpha, parts)
    except Exception as e:
        print('rank {} | memory usg: {} | exception {}'.format(
            rank, check_memory(verbose=False), e))
        # Nothing below can run without df and irrep_dict; surface the real
        # failure instead of an unrelated NameError further down.
        raise

    print('Rank {:3d} / {} | load irrep: {:.2f}s | mem: {}mb'.format(
        rank, size, time.time() - start, check_memory(verbose=False)))

    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))
    save_dict = {}
    cyc_irrep_func = cyclic_irreps(alpha)

    chunk_size = len(df) // size
    start_idx = chunk_size * rank
    # The final chunk runs to the end of the dataframe to pick up the
    # remainder rows left over by the floor division above.
    end_idx = len(df) if rank == size - 1 else start_idx + chunk_size

    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done load'.format(
            rank, time.time() - start, check_memory(verbose=False)))

    for idx in range(start_idx, end_idx):
        row = df.loc[idx]
        otup = tuple(int(i) for i in row[0])
        perm_tup = tuple(int(i) for i in row[1])
        dist = int(row[2])

        perm_rep = irrep_dict[perm_tup]  # perm_rep is a dict of (i, j) -> matrix
        block_cyclic_rep = block_cyclic_irreps(otup, cos_reps, cyc_irrep_func)
        mult_yor_block(perm_rep, dist, block_cyclic_rep, save_dict)

    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done add'.format(
            rank, time.time() - start, check_memory(verbose=False)))

    # The irreps are no longer needed; release them before building the matrix.
    del irrep_dict
    block_size = wreath_dim(parts)
    n_cosets = coset_size(alpha)
    mat = convert_yor_matrix(save_dict, block_size, n_cosets)

    if rank == 0:
        print('Rank {} | elapsed: {:.2f}s | {:.2f}mb | done matrix conversion'.format(
            rank, time.time() - start, check_memory(verbose=False)))
    return mat
def tup_to_irrep_th(self, otup, ptup):
    '''
    Look up the sparse representation for an (orientation, permutation) pair.

    otup: orientation tuple, forwarded to block_cyclic_irreps
    ptup: permutation tuple, forwarded to get_sparse_mat
    Returns: the (real, imaginary) pair produced by get_sparse_mat from this
             object's np_yor_dict and the block cyclic scalars for otup
    '''
    cyclic_scalars = block_cyclic_irreps(otup, self.cos_reps, self.cyc_irrep_func)
    real_part, imag_part = get_sparse_mat(ptup, self.np_yor_dict, cyclic_scalars)
    return real_part, imag_part
def test_main(alpha, parts):
    '''
    Computes the ft via the sparse wreath rep and the non-sparse wreath rep to
    double check that the sparse wreath rep is actually correct, timing each
    stage along the way.

    Rows of the dataframe are processed in order and the loop stops after the
    row with idx 201. If the dense and sparse wreath reps of a row disagree, a
    message is printed and pdb is entered. Mean timings for each stage are
    printed at the end.

    alpha: weak partition
    parts: list/iterable of partitions of the parts of alpha
    '''
    _start = time.time()

    st = time.time()
    sp_irrep_dict = load_pkl(
        '/scratch/hopan/cube/pickles_sparse/{}/{}.pkl'.format(alpha, parts))
    print('Loading sparse irrep dict: {:.2f}s'.format(time.time() - st))
    check_memory()

    st = time.time()
    irrep_dict = load_irrep('/scratch/hopan/cube/', alpha, parts)
    print('Loading irrep dict: {:.2f}s'.format(time.time() - st))
    check_memory()

    # generate a random group element?
    st = time.time()
    df = load_df('/scratch/hopan/cube/')
    fhat = np.load('/scratch/hopan/cube/fourier/{}/{}.npy'.format(
        alpha, parts))
    print('Loading df: {:.2f}s'.format(time.time() - st))
    check_memory()

    cyc_irrep_func = cyclic_irreps(alpha)
    cos_reps = coset_reps(sn(8), young_subgroup_perm(alpha))

    st = time.time()
    cyc_irrs = all_cyc_irreps(cos_reps, cyc_irrep_func)
    print('Time to compute all cyc irreps: {:.5f}s'.format(time.time() - st))

    # Per-row timings for each stage, averaged after the loop.
    sp_times = []
    sp_mult_times = []
    sp_results = np.zeros(len(df), dtype=np.complex128)
    coo_times = []
    th_sp_times = []
    times = []
    mult_times = []
    z3_irreps = []
    results = np.zeros(len(df), dtype=np.complex128)
    fhat_t_ravel = fhat.T.ravel()

    loop_start = time.time()
    for idx in range(len(df)):
        row = df.loc[idx]
        otup = tuple(int(i) for i in row[0])
        perm_tup = tuple(int(i) for i in row[1])

        # compute wreath rep
        st = time.time()
        wmat = wreath_rep(otup, perm_tup, irrep_dict, cos_reps, cyc_irrep_func)
        reg_time = time.time() - st

        # compute wreath rep multiply
        st = time.time()
        wmat_inv = wmat.conj().T
        feval = np.dot(fhat_t_ravel, wmat_inv.ravel())
        reg_mult_time = time.time() - st
        results[idx] = feval

        # compute sparse wreath rep
        st = time.time()
        wmat_sp = wreath_rep_sp(otup, perm_tup, sp_irrep_dict, cos_reps,
                                cyc_irrep_func, cyc_irrs)
        sp_time = time.time() - st

        if not np.allclose(wmat, wmat_sp.todense()):
            print('unequal! | idx = {}'.format(idx))
            # Deliberate breakpoint so the mismatching row can be inspected.
            pdb.set_trace()

        # compute sparse wreath rep multiply
        st = time.time()
        wmat_inv_sp = wmat_sp.conj().T
        feval_sp = (wmat_inv_sp.multiply(fhat.T)).sum()
        sp_mult_time = time.time() - st
        sp_results[idx] = feval_sp

        times.append(reg_time)
        sp_times.append(sp_time)
        mult_times.append(reg_mult_time)
        sp_mult_times.append(sp_mult_time)

        # time the conversion to COO format
        st = time.time()
        coo = wmat_sp.tocoo()
        end = time.time()
        coo_times.append(end - st)

        # time building the torch sparse tensors; the tensors themselves are
        # not used afterwards
        st = time.time()
        ix = torch.LongTensor([coo.row, coo.col])
        th_sp_re = torch.sparse.FloatTensor(ix,
                                            torch.FloatTensor(coo.data.real),
                                            torch.Size(coo.shape))
        th_sp_cplx = torch.sparse.FloatTensor(ix,
                                              torch.FloatTensor(coo.data.imag),
                                              torch.Size(coo.shape))
        end = time.time()
        th_sp_times.append(end - st)

        # time a single block_cyclic_irreps call; the result is not used
        st = time.time()
        block_scalars = block_cyclic_irreps(otup, cos_reps, cyc_irrep_func)
        end = time.time()
        z3_irreps.append(end - st)

        if idx > 200:
            break

    print('Normal time: {:.6f}s | Sparse time: {:.6f}s'.format(
        np.mean(times), np.mean(sp_times)))
    print('Mult time: {:.6f}s | Spmult time: {:.6f}s'.format(
        np.mean(mult_times), np.mean(sp_mult_times)))
    print('To coo time: {:.6f}s | Torchsptime: {:.6f}s'.format(
        np.mean(coo_times), np.mean(th_sp_times)))
    print('irrep time: {:.6f}s'.format(np.mean(z3_irreps)))
    print('Loop time: {:.2f}s'.format(time.time() - loop_start))
    print('Total time: {:.2f}s'.format(time.time() - _start))