Example #1
File: ctl_mv.py Project: DaMSL/ddc
    def backProjection(self, index_list):
      """Perform back projection function for a list of indices. Return a list 
      of high dimensional points (one per index). Check cache for each point and
      condolidate file I/O for all cache misses.
      """

      logging.debug('--------  BACK PROJECTION:  %d POINTS ---', len(index_list))
      bench = microbench('bkproj', self.seqNumFromID())

      source_points = []
      cache_miss = []   # local (type, file, frames) records; distinct from the self.cache_miss counter

      self.trajlist_async = deque()
      
      # DEShaw topology is assumed here
      bench.start()

      # Dereference indices to (file, frame) tuples:
      historical_framelist = []
      generated_framelist = []

      if self.xidreference is None:
        self.xidreference = self.catalog.lrange('xid:reference', 0, -1)

      logging.debug('Select Index List size = %d', len(index_list))
      for idx in index_list:
        # A negative index denotes a historical (DEShaw) frame:
        index = int(idx)
        if index < 0:
          file_index, frame = deshaw.refFromIndex(-idx)
          historical_framelist.append((file_index, frame))
          # logging.debug('[BP] DEShaw:  file #%d,   frame#%d', file_index, frame)
        else:
          generated_framelist.append(self.xidreference[index])

      bench.mark('BP:LD:Redis:xidlist')


      ref = deshaw.topo_prot  # Hard-coded for now

      # Group all historical indices by file number and add them to the frame mask
      logging.debug('Group by file idx (DEShaw)')
      historical_frameMask = {}
      for file_index, frame in historical_framelist:
        if file_index not in historical_frameMask:
          historical_frameMask[file_index] = []
        historical_frameMask[file_index].append(frame)

      for k, v in historical_frameMask.items():
        logging.debug('[BP] Deshaw lookups: %d, %s', k, str(v))


      # Group all generated indices by file index
      logging.debug('Group by file idx (generated data)')
      groupbyFileIdx = {}
      for entry in generated_framelist:
        file_index, frame = eval(entry)  # entries are stored as '(file, frame)' tuple strings
        if file_index not in groupbyFileIdx:
          groupbyFileIdx[file_index] = []
        groupbyFileIdx[file_index].append(frame)

      # Dereference File index to filenames
      logging.debug('Deref fileidx -> file names')
      generated_frameMask = {}
      generated_filemap = {}
      for file_index in groupbyFileIdx.keys():
        filename = self.catalog.lindex('xid:filelist', file_index)
        if filename is None:
          logging.error('File index %s not found in catalog', file_index)
        else:
          key = os.path.splitext(os.path.basename(filename))[0]
          generated_frameMask[key] = groupbyFileIdx[file_index]
          generated_filemap[key] = filename
      bench.mark('BP:GroupBy:Files')

      # Ensure the cache client is alive and connected
      logging.debug('Check Cache client')
      self.cacheclient.connect()

      # Check cache for historical data points
      logging.debug('Checking cache for %d DEShaw files to back-project', len(historical_frameMask.keys()))
      for fileno, frames in historical_frameMask.items():
        # Single-frame lookups use get(); multi-frame lookups use get_many()
        if len(frames) == 1:
          datapt = self.cacheclient.get(fileno, frames[0], 'deshaw')
          dataptlist = [datapt] if datapt is not None else None
        else:
          dataptlist = self.cacheclient.get_many(fileno, frames, 'deshaw')
        if dataptlist is None:
          self.cache_miss += 1
          # logging.debug('[BP] Cache MISS on: %d', fileno)
          cache_miss.append(('deshaw', fileno, frames))
        else:
          self.cache_hit += 1
          # logging.debug('[BP] Cache HIT on: %d', fileno)
          source_points.extend(dataptlist)

      # Check cache for generated data points
      logging.debug('Checking cache for %d generated files to back-project', len(generated_frameMask.keys()))
      for filename, frames in generated_frameMask.items():
        # Single-frame lookups use get(); multi-frame lookups use get_many()
        if len(frames) == 1:
          datapt = self.cacheclient.get(filename, frames[0], 'sim')
          dataptlist = [datapt] if datapt is not None else None
        else:
          dataptlist = self.cacheclient.get_many(filename, frames, 'sim')
        if dataptlist is None:
          self.cache_miss += 1
          # logging.debug('[BP] Cache MISS on: %s', filename)
          cache_miss.append(('sim', generated_filemap[filename], frames))
        else:
          self.cache_hit += 1
          # logging.debug('[BP] Cache HIT on: %s', filename)
          source_points.extend(dataptlist)


      # Package all cached points into one trajectory
      logging.debug('Cache hits: %d points.', len(source_points))
      if len(source_points) > 0:
        source_traj_cached = md.Trajectory(source_points, ref.top)
      else:
        source_traj_cached = None

      # All files were cached. Return back-projected points
      if len(cache_miss) == 0:
        return source_traj_cached
        
      # Load the high-dim points for all cache misses from their source files
      # (grouped by file above to consolidate I/O)
      source_points_uncached = []
      logging.debug('Sequentially loading all uncached trajectories')
      for miss in cache_miss:
        ftype, fileno, framelist = miss
        if ftype == 'deshaw':
          pdb, dcd = deshaw.getHistoricalTrajectory_prot(fileno)
          traj = md.load(dcd, top=pdb)
        elif ftype == 'sim':
          traj = datareduce.load_trajectory(fileno)
        selected_frames = traj.slice(framelist)
        source_points_uncached.extend(selected_frames.xyz)
        bench.mark('BP:LD:File')

      logging.debug('All uncached data collected. Total # points = %d', len(source_points_uncached))
      source_traj_uncached = md.Trajectory(np.array(source_points_uncached), ref.top)
      bench.mark('BP:Build:Traj')
      # bench.show()

      logging.info('--------  Back Projection Complete ---------------')
      if source_traj_cached is None:
        return source_traj_uncached
      else:
        return source_traj_cached.join(source_traj_uncached)
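
A minimal usage sketch (hypothetical caller; `ctl` stands in for a controller instance whose catalog and cache client are already connected):

# Negative indices dereference into historical DEShaw frames;
# non-negative indices are offsets into the 'xid:reference' list.
index_list = [-1024, -2048, 37, 512]
traj = ctl.backProjection(index_list)
if traj is not None:
  print('Back-projected %d frames of %d atoms' % (traj.n_frames, traj.n_atoms))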
Example #2
def seedJob_Uniform(catalog, num=1, exact=None):
  """
  Seeds jobs into the JCQueue, pulled from DEShaw data.
  Selects `num` random starting frames from each transition bin
  to seed as job candidates.
  """
  logging.info('Seeding %d jobs per transition bin', num)
  settings = systemsettings()
  numLabels = int(catalog.get('numLabels'))
  binlist = [(A, B) for A in range(numLabels) for B in range(numLabels)]

  dcdfreq = int(catalog.get('dcdfreq'))
  runtime = int(catalog.get('runtime'))
  sim_step_size = int(catalog.get('sim_step_size'))

  if catalog.exists('label:deshaw'):
    rmslabel = [eval(x) for x in catalog.lrange('label:deshaw', 0, -1)]  # labels are '(A, B)' tuple strings
  elif os.path.exists(DESHAW_LABEL_FILE):
    logging.info('Loading DEShaw Points From File....')
    with open(DESHAW_LABEL_FILE) as lfile:
      rmslabel = [eval(label) for label in lfile.read().strip().split('\n')]
    logging.info('Loaded DEShaw %d Labels from file, %s', len(rmslabel), DESHAW_LABEL_FILE)
    pipe = catalog.pipeline()
    for rms in rmslabel:
      pipe.rpush('label:deshaw', rms)
    pipe.execute()
    logging.info('DEShaw Labels stored in the catalog.')
  else:
    rmslabel = labelDEShaw_rmsd(store_to_disk=True)
    pipe = catalog.pipeline()
    for rms in rmslabel:
      pipe.rpush('label:deshaw', rms)
    pipe.execute()
    logging.info('DEShaw Labels stored in the catalog.')
  logging.info('Grouping all pre-labeled data:')
  groupby = {b:[] for b in binlist}

  for i, b in enumerate(rmslabel):
    groupby[b].append(i)

  for k in sorted(groupby.keys()):
    v = groupby[k]
    logging.info('%s %7d %4.1f', str(k), len(v), (100*len(v)/len(rmslabel)))

  if exact is None:
    source_list = sorted(groupby.keys())
  else:
    bin_list = list(groupby.keys())
    # Sample bins without replacement whenever enough distinct bins exist
    if exact <= len(bin_list):
      idx_list = np.random.choice(len(bin_list), exact, replace=False)
    else:
      idx_list = np.random.choice(len(bin_list), exact, replace=True)
    source_list = [bin_list[i] for i in idx_list]

  for binlabel in source_list:
    clist = groupby[binlabel]
    A, B = binlabel

    # No candidates
    if len(clist) == 0:
      logging.info('NO Candidates for %s', str(binlabel))
      if binlabel == (1, 3):
        logging.info('Swapping (1,2) for (1,3)')
        clist = groupby[(1,2)]
        B = 2
      elif binlabel == (3, 1):
        logging.info('Swapping (3,0) for (3,1)')
        clist = groupby[(3,0)]
        B = 0
      else:
        logging.info('Not sampling this bin')
        continue

    for k in range(num):
      logging.debug('\nSeeding Job #%d for bin (%d,%d) ', k, A, B)
      index = np.random.choice(clist)
      src, frame = deshaw.refFromIndex(index)
      logging.debug("   Selected: BPTI %s, frame: %s", src, frame)
      pdbfile, dcdfile = deshaw.getHistoricalTrajectory_prot(int(src))
      traj = md.load(dcdfile, top=pdbfile, frame=int(frame))

      # Generate new set of params/coords
      jcID, params = generateNewJC(traj)

      # Update Additional JC Params and Decision History, as needed
      config = dict(params,
          name    = jcID,
          runtime = runtime,
          dcdfreq = dcdfreq,
          interval = dcdfreq * sim_step_size,                       
          temp    = 310,
          timestep = 0,
          gc      = 1,
          origin  = 'deshaw',
          src_index = index,
          src_bin  = (A, B),
          src_hcube = 'D',
          application   = settings.APPL_LABEL)
      logging.info("New Simulation Job Created: %s", jcID)
      for k, v in config.items():
        logging.debug("   %s:  %s", k, str(v))
      catalog.rpush('jcqueue', jcID)
      catalog.hmset(wrapKey('jc', jcID), config)
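
A minimal usage sketch (hypothetical setup; assumes a connected catalog in which 'numLabels', 'dcdfreq', 'runtime', and 'sim_step_size' are already set):

# get_catalog() is a hypothetical placeholder for however the project
# obtains its Redis-backed catalog handle.
catalog = get_catalog()
seedJob_Uniform(catalog, num=2)            # seed 2 jobs per transition bin
seedJob_Uniform(catalog, num=1, exact=10)  # sample 10 bins at random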