def backProjection(self, index_list):
    """Perform the back projection function for a list of indices.

    Returns the high-dimensional points (one per index) packaged as an
    md.Trajectory. Checks the point cache for each index and consolidates
    file I/O for all cache misses.

    Args:
        index_list: iterable of int-like indices. Negative values denote
            historical DEShaw frames (dereferenced via deshaw.refFromIndex);
            non-negative values index into the 'xid:reference' catalog list
            of generated frames.

    Returns:
        md.Trajectory of the back-projected points, or None when the
        index list resolves to no points at all.
    """
    import ast  # local import: safe literal parsing without touching file imports

    logging.debug('-------- BACK PROJECTION: %d POINTS ---', len(index_list))
    bench = microbench('bkproj', self.seqNumFromID())
    source_points = []
    cache_miss = []
    self.trajlist_async = deque()

    # DEShaw topology is assumed here
    bench.start()

    # Dereference indices to (file, frame) tuples:
    historical_framelist = []
    generated_framelist = []

    # Lazily load (and memoize) the generated-frame reference list.
    if self.xidreference is None:
        self.xidreference = self.catalog.lrange('xid:reference', 0, -1)

    logging.debug('Select Index List size = %d', len(index_list))
    for idx in index_list:
        index = int(idx)
        if index < 0:
            # Negation indicates a historical (DEShaw) index.
            # FIX: negate the int-converted value (was `-idx`), so string
            # inputs -- which int(idx) implies are possible -- work too.
            file_index, frame = deshaw.refFromIndex(-index)
            historical_framelist.append((file_index, frame))
        else:
            generated_framelist.append(self.xidreference[index])

    bench.mark('BP:LD:Redis:xidlist')
    ref = deshaw.topo_prot  # Hard coded for now

    # Group all historical indices by file number and add to frame mask.
    logging.debug('Group By file idx (DEshaw)')
    historical_frameMask = {}
    for file_index, frame in historical_framelist:
        if file_index not in historical_frameMask:
            historical_frameMask[file_index] = []
        historical_frameMask[file_index].append(frame)

    for k, v in historical_frameMask.items():
        logging.debug('[BP] Deshaw lookups: %d, %s', k, str(v))

    # Group all generated indices by file index.
    logging.debug('Group By file idx (Gen data)')
    groupbyFileIdx = {}
    for idx in generated_framelist:
        # Entries are stored as stringified "(file_index, frame)" tuples;
        # ast.literal_eval is a safe drop-in for eval() on that format.
        file_index, frame = ast.literal_eval(idx)
        if file_index not in groupbyFileIdx:
            groupbyFileIdx[file_index] = []
        groupbyFileIdx[file_index].append(frame)

    # Dereference each file index to its filename.
    logging.debug('Deref fileidx -> file names')
    generated_frameMask = {}
    generated_filemap = {}
    for file_index in groupbyFileIdx.keys():
        filename = self.catalog.lindex('xid:filelist', file_index)
        if filename is None:
            logging.error('Error file not found in catalog: %s', filename)
        else:
            # Cache keys use the bare file name (no directory, no extension).
            key = os.path.splitext(os.path.basename(filename))[0]
            generated_frameMask[key] = groupbyFileIdx[file_index]
            generated_filemap[key] = filename
    bench.mark('BP:GroupBy:Files')

    # Ensure the cache client is alive and connected.
    logging.debug('Check Cache client')
    self.cacheclient.connect()

    # Check cache for historical (DEShaw) data points.
    logging.debug('Checking cache for %d DEShaw points to back-project',
        len(historical_frameMask.keys()))
    for fileno, frames in historical_frameMask.items():
        # Handle the 1-frame case separately (allows follow-on multi-frame,
        # mixed cache hit/miss handling).
        if len(frames) == 1:
            datapt = self.cacheclient.get(fileno, frames[0], 'deshaw')
            dataptlist = [datapt] if datapt is not None else None
        else:
            dataptlist = self.cacheclient.get_many(fileno, frames, 'deshaw')
        if dataptlist is None:
            self.cache_miss += 1
            cache_miss.append(('deshaw', fileno, frames))
        else:
            self.cache_hit += 1
            source_points.extend(dataptlist)

    # Check cache for generated data points.
    logging.debug('Checking cache for %d Generated points to back-project',
        len(generated_frameMask.keys()))
    for filename, frames in generated_frameMask.items():
        if len(frames) == 1:
            datapt = self.cacheclient.get(filename, frames[0], 'sim')
            dataptlist = [datapt] if datapt is not None else None
        else:
            dataptlist = self.cacheclient.get_many(filename, frames, 'sim')
        if dataptlist is None:
            self.cache_miss += 1
            cache_miss.append(('sim', generated_filemap[filename], frames))
        else:
            self.cache_hit += 1
            source_points.extend(dataptlist)

    # Package all cached points into one trajectory.
    logging.debug('Cache hits: %d points.', len(source_points))
    if len(source_points) > 0:
        source_traj_cached = md.Trajectory(source_points, ref.top)
    else:
        source_traj_cached = None

    # All points were cached: return the back-projected points directly.
    if len(cache_miss) == 0:
        return source_traj_cached

    # Add high-dim points for every cache miss by loading the source files.
    source_points_uncached = []
    logging.debug('Sequentially Loading all trajectories')
    for ftype, fileno, framelist in cache_miss:
        if ftype == 'deshaw':
            pdb, dcd = deshaw.getHistoricalTrajectory_prot(fileno)
            traj = md.load(dcd, top=pdb)
        elif ftype == 'sim':
            traj = datareduce.load_trajectory(fileno)
        selected_frames = traj.slice(framelist)
        source_points_uncached.extend(selected_frames.xyz)
        bench.mark('BP:LD:File')

    logging.debug('All Uncached Data collected Total # points = %d',
        len(source_points_uncached))
    source_traj_uncached = md.Trajectory(np.array(source_points_uncached), ref.top)
    bench.mark('BP:Build:Traj')

    logging.info('-------- Back Projection Complete ---------------')
    if source_traj_cached is None:
        return source_traj_uncached
    else:
        return source_traj_cached.join(source_traj_uncached)
def seedJob_Uniform(catalog, num=1, exact=None):
    """Seed jobs into the JCQueue -- pulled from DEShaw.

    Selects an equal `num` of randomly chosen start frames from each
    transition bin to seed as job candidates.

    Args:
        catalog: catalog (redis-like) client holding settings, labels,
            the 'jcqueue' list, and per-job config hashes.
        num: number of jobs to seed per selected bin.
        exact: if None, seed every bin; otherwise seed `exact` randomly
            chosen bins (without replacement when exact <= 25, with
            replacement otherwise).
    """
    import ast  # local import: safe literal parsing without touching file imports

    logging.info('Seeding %d jobs per transition bin', num)  # FIX: typo "transtion"
    settings = systemsettings()
    numLabels = int(catalog.get('numLabels'))
    binlist = [(A, B) for A in range(numLabels) for B in range(numLabels)]
    dcdfreq = int(catalog.get('dcdfreq'))
    runtime = int(catalog.get('runtime'))
    sim_step_size = int(catalog.get('sim_step_size'))

    # Load DEShaw RMSD labels: prefer the catalog, then the label file,
    # else compute them from scratch (storing back to catalog either way).
    if catalog.exists('label:deshaw'):
        # Labels are stored as stringified tuples; literal_eval is a safe
        # drop-in for eval() on that format.
        rmslabel = [ast.literal_eval(x) for x in catalog.lrange('label:deshaw', 0, -1)]
    elif os.path.exists(DESHAW_LABEL_FILE):
        logging.info('Loading DEShaw Points From File....')
        with open(DESHAW_LABEL_FILE) as lfile:
            rmslabel = [ast.literal_eval(label) for label in lfile.read().strip().split('\n')]
        logging.info('Loaded DEShaw %d Labels from file, %s', len(rmslabel), DESHAW_LABEL_FILE)
        pipe = catalog.pipeline()
        for rms in rmslabel:
            pipe.rpush('label:deshaw', rms)
        pipe.execute()
        logging.info('DEShaw Labels stored in the catalog.')
    else:
        rmslabel = labelDEShaw_rmsd(store_to_disk=True)
        pipe = catalog.pipeline()
        for rms in rmslabel:
            pipe.rpush('label:deshaw', rms)
        pipe.execute()
        logging.info('DEShaw Labels stored in the catalog.')

    # Group all prelabeled frame indices by their (A, B) transition bin.
    logging.info('Grouping all prelabeled Data:')
    groupby = {b: [] for b in binlist}
    for i, b in enumerate(rmslabel):
        groupby[b].append(i)

    for k in sorted(groupby.keys()):
        v = groupby[k]
        logging.info('%s %7d %4.1f', str(k), len(v), (100*len(v)/len(rmslabel)))

    # Choose which bins to seed.
    if exact is None:
        source_list = sorted(groupby.keys())
    else:
        bin_list = list(groupby.keys())
        if exact <= 25:
            idx_list = np.random.choice(len(bin_list), exact, replace=False)
        else:
            idx_list = np.random.choice(len(bin_list), exact, replace=True)
        source_list = [bin_list[i] for i in idx_list]

    for binlabel in source_list:
        clist = groupby[binlabel]
        A, B = binlabel

        # No candidates in this bin: fall back to a neighboring bin for the
        # two known-sparse transitions, otherwise skip the bin entirely.
        if len(clist) == 0:
            logging.info('NO Candidates for %s', str(binlabel))
            if binlabel == (1, 3):
                logging.info('Swapping (1,2) for (1,3)')
                clist = groupby[(1,2)]
                B = 2
            elif binlabel == (3, 1):
                logging.info('Swapping (3,0) for (3,1)')
                clist = groupby[(3,0)]
                B = 0
            else:
                logging.info('Not sampling this bin')
                continue

        # FIX: inner loop variable renamed from `k` -- the original reused
        # `k` for config.items() below, shadowing the job counter.
        for jobnum in range(num):
            logging.debug('\nSeeding Job #%d for bin (%d,%d) ', jobnum, A, B)
            index = np.random.choice(clist)
            src, frame = deshaw.refFromIndex(index)
            logging.debug(" Selected: BPTI %s, frame: %s", src, frame)
            pdbfile, dcdfile = deshaw.getHistoricalTrajectory_prot(int(src))
            traj = md.load(dcdfile, top=pdbfile, frame=int(frame))

            # Generate new set of params/coords.
            jcID, params = generateNewJC(traj)

            # Update additional JC params and decision history, as needed.
            config = dict(params,
                name = jcID,
                runtime = runtime,
                dcdfreq = dcdfreq,
                interval = dcdfreq * sim_step_size,
                temp = 310,
                timestep = 0,
                gc = 1,
                origin = 'deshaw',
                src_index = index,
                src_bin = (A, B),
                src_hcube = 'D',
                application = settings.APPL_LABEL)
            logging.info("New Simulation Job Created: %s", jcID)
            for key, val in config.items():
                logging.debug(" %s: %s", key, str(val))
            catalog.rpush('jcqueue', jcID)
            catalog.hmset(wrapKey('jc', jcID), config)