예제 #1
0
파일: MSBWTGen.py 프로젝트: Rinoahu/msbwt
def mergeNewMSBWT(mergedDir, inputBwtDirs, numProcs, logger):
    '''
    This function will take a list of input BWTs (compressed or not) and merge them into a single BWT.

    The merge is solved iteratively.  An "origin" array (one '<u1' value per merged position, holding the index of
    the input BWT that position's symbol comes from) is refined bin by bin via mergeNewMSBWTPoolCall until an
    iteration reports no changes.  The solved origin array is then used to copy the symbols of every input BWT into
    the merged BWT.
    @param mergedDir - destination directory; the merged BWT is written to mergedDir+'/msbwt.npy' and the scratch
                       origin arrays temp.0.npy / temp.1.npy are created there (this function does not remove them)
    @param inputBwtDirs - list of input BWT directories, each loaded with MultiStringBWT.loadBWT; between 1 and 256
    @param numProcs - number of processes we're allowed to use; with numProcs <= 1 every bin runs in this process
    @param logger - output goes here
    @raise ValueError - if inputBwtDirs holds fewer than 1 or more than 256 entries
    '''
    st = time.time()
    iterst = time.time()
    vcLen = 6
    
    #TODO: take advantage of these to skip an iteration or two perhaps
    numInputs = len(inputBwtDirs)
    #the origin arrays store one '<u1' input index per position, so only 1..256 inputs can be represented
    if not 0 < numInputs <= 256:
        raise ValueError('mergeNewMSBWT needs between 1 and 256 input BWTs, got '+str(numInputs))
    
    msbwts = [None]*numInputs
    mergedLength = 0
    for i, dirName in enumerate(inputBwtDirs):
        #NOTE: in practice, since we're allowing for multiprocessing, we construct the FM-index for each input BWT
        #simply because in the long run, this allows us to figure out how to start processing chunks separately.
        #Without this, we would need to track extra information that really just represent the FM-index.
        msbwts[i] = MultiStringBWT.loadBWT(dirName, logger)
        mergedLength += msbwts[i].totalSize
    
    #binSize = 2**1#small bin debugging
    #binSize = 2**15#this one is just for test purposes, makes easy to debug things
    #binSize = 2**25#diff in 22-23 is not that much, 23-24 was 8 seconds of difference, so REALLY no diff
    binSize = 2**28
    
    #allocate the mergedBWT space
    logger.info('Allocating space on disk...')
    mergedBWT = np.lib.format.open_memmap(mergedDir+'/msbwt.npy', 'w+', '<u1', (mergedLength,))
    totalLength = mergedBWT.shape[0]
    
    #number of bins; floor division keeps this an int on both Python 2 and Python 3
    #TODO: x//binSize + 1 makes one too many bins if it's exactly divisible by binSize, ex: 4 length BWT with binSize 2
    numBins = totalLength//binSize + 1
    
    #this one will create the array using bits
    logger.info('Initializing iterations...')
    placeArray = np.lib.format.open_memmap(mergedDir+'/temp.0.npy', 'w+', '<u1', (totalLength,))
    copiedPlaceArray = np.lib.format.open_memmap(mergedDir+'/temp.1.npy', 'w+', '<u1', (totalLength,))
    
    #fill out the initial array with 0s, 1s, 2s, etc. as our initial condition; the run belonging to input 0 is
    #never written here and relies on the freshly created memmap files reading back as zeros
    start = msbwts[0].totalSize
    end = 0
    for i, msbwt in enumerate(msbwts):
        end += msbwt.getTotalSize()
        placeArray[start:end].fill(i)
        copiedPlaceArray[start:end].fill(i)
        start = end
    
    #create something to track the offsets
    nextBinHasChanged = np.ones(dtype='b', shape=(numBins,))
    prevOffsetCounts = np.zeros(dtype='<u8', shape=(numBins, numInputs))
    currOffsetCounts = np.zeros(dtype='<u8', shape=(numBins, numInputs))
    nextOffsetCounts = np.zeros(dtype='<u8', shape=(numBins, numInputs))
    #one independent dict per bin; [{}]*numBins would alias a single dict across every bin
    binUpdates = [{} for _ in range(numBins)]
    
    #starting offset of every bin inside each input, matching the back-to-back initial layout filled in above
    bwtInd = 0
    offsets = [0]*numInputs
    for x in range(0, numBins):
        #set, then change for next iter
        nextOffsetCounts[x] = offsets
        remaining = binSize
        while remaining > 0 and bwtInd < numInputs:
            available = msbwts[bwtInd].totalSize-offsets[bwtInd]
            if remaining > available:
                remaining -= available
                offsets[bwtInd] = msbwts[bwtInd].totalSize
                bwtInd += 1
            else:
                offsets[bwtInd] += remaining
                remaining = 0
    
    ignored = 0
    
    #original
    sys.stdout.write('\rcp ')
    sys.stdout.flush()
    
    del copiedPlaceArray
    needsMoreIterations = True
    
    i = 0
    while needsMoreIterations:
        prevOffsetCounts = currOffsetCounts
        currOffsetCounts = nextOffsetCounts
        nextOffsetCounts = np.zeros(dtype='<u8', shape=(numBins, numInputs))
        needsMoreIterations = False
        
        #this method uses a condensed byte and will ignore regions that are already finished
        sys.stdout.write('\rld ')
        sys.stdout.flush()
        ignored = 0
        
        #reports the time since the previous report (the previous iteration, or the setup before iter 0)
        iteret = time.time()
        sys.stdout.write('\r')
        logger.info('Finished iter '+str(i)+' in '+str(iteret-iterst)+'seconds')
        iterst = time.time()
        i += 1
        
        sys.stdout.write('\rld')
        sys.stdout.flush()
        
        #track which bins are actually different
        binHasChanged = nextBinHasChanged
        nextBinHasChanged = np.zeros(dtype='b', shape=(numBins,))
        
        #the two origin files swap roles every iteration (read one, write the other), so no file copy is needed
        if i % 2 == 0:
            srcFN, dstFN = mergedDir+'/temp.0.npy', mergedDir+'/temp.1.npy'
        else:
            srcFN, dstFN = mergedDir+'/temp.1.npy', mergedDir+'/temp.0.npy'
        
        tups = []
        for x in range(0, numBins):
            #check if the current offset matches the previous iteration offset
            sameOffset = np.array_equal(currOffsetCounts[x], prevOffsetCounts[x])
            
            #TODO: the below False is there because this only works if you do a full file copy right now.  It's
            #because unless we copy, then the appropriate parts of the nextPlaceArray isn't properly updated.  It's
            #unclear whether one of these is better than the other in terms of performance.  File copying is slow,
            #but if only a couple sequences are similar then skipping is good.  I think in general, we only skip at
            #the beginning for real data though, so I'm going with the no-skip, no-copy form until I can resolve
            #the problem (if there's a resolution).
            if False and not binHasChanged[x] and sameOffset:
                for key in binUpdates[x]:
                    nextOffsetCounts[key] += np.asarray(binUpdates[x][key], dtype='<u8')
                ignored += 1
            else:
                tups.append((x, binSize, vcLen, currOffsetCounts[x], srcFN, dstFN, inputBwtDirs))
        
        myPool = None
        try:
            if numProcs > 1:
                #TODO: tinker with chunksize, it might matter
                myPool = multiprocessing.Pool(numProcs)
                rets = myPool.imap(mergeNewMSBWTPoolCall, tups, chunksize=10)
            else:
                rets = [mergeNewMSBWTPoolCall(tup) for tup in tups]
            
            progressCounter = ignored
            sys.stdout.write('\r'+str(100*progressCounter*binSize//totalLength)+'%')
            sys.stdout.flush()
            
            #iterate through the returns so we can figure out information necessary for continuation
            for (x, nBHC, nOC, nMI) in rets:
                binUpdates[x] = nOC
                for k in nBHC:
                    nextBinHasChanged[k] |= nBHC[k]
                for b in nOC:
                    #cast explicitly: an in-place uint64 += int64 is rejected under numpy's same_kind casting
                    nextOffsetCounts[b] += np.asarray(nOC[b], dtype='<u8')
                needsMoreIterations |= nMI
                
                progressCounter += 1
                sys.stdout.write('\r'+str(min(100*progressCounter*binSize//totalLength, 100))+'%')
                sys.stdout.flush()
        finally:
            #always reap the worker processes, even if a chunk raised
            if myPool is not None:
                myPool.terminate()
                myPool.join()
        
        #turn per-bin counts into the starting offset of every bin within each input
        nextOffsetCounts = np.cumsum(nextOffsetCounts, axis=0)-nextOffsetCounts
    
    sys.stdout.write('\r')
    sys.stdout.flush()
    logger.info('Order solved, saving final array...')
    
    #TODO: make this better
    #the loop only stops once no worker saw its written origins differ from the ones it read, so placeArray
    #(temp.0) is used as the solved origin array
    offsets = np.zeros(dtype='<u8', shape=(numInputs,))
    for i in range(0, numBins):
        #numpy clips the slice at the end of the array, so the last (partial or empty) bin needs no special case
        ind = placeArray[i*binSize:(i+1)*binSize]
        bc = np.bincount(ind, minlength=numInputs)
        
        for x in range(0, numInputs):
            dest = np.where(ind == x)[0] + i*binSize
            mergedBWT[dest] = msbwts[x].getBWTRange(int(offsets[x]), int(offsets[x])+int(bc[x]))
        offsets += np.asarray(bc, dtype='<u8')
    
    mergedBWT.flush()
    et = time.time()
    
    logger.info('Finished all merge iterations in '+str(et-st)+' seconds.')
예제 #2
0
파일: MSBWTGen.py 프로젝트: Rinoahu/msbwt
def mergeNewMSBWTPoolCall(tup):
    '''
    This is a single process call of a chunk of the data to merge BWTs.  It performs one refinement step for one
    bin.  It reads the bin's origin values from placeArrayFN and interleaves the matching symbols of every input
    BWT.  Then, for each symbol c, it writes the origins of that bin's c-symbols (in bin order) into
    nextPlaceArrayFN, starting at the summed FM-index count for c.
    @param tup - a tuple (bID, binSize, vcLen, currOffsetCounts, placeArrayFN, nextPlaceArrayFN, bwtDirs):
        bID - the block ID this process is processing
        binSize - the size of a bin/block
        vcLen - number of distinct symbol values counted (6, hardcoded upstream)
        currOffsetCounts - the starting position in each input BWT for this process chunk, useful for FM extraction
        placeArrayFN - the filename for the input origin array (opened read-only)
        nextPlaceArrayFN - the filename for the output origin array (opened 'r+' and written in place)
        bwtDirs - the actual input BWT directories to merge
    @return - (bID, nextBinHasChanged, nextOffsetCounts, needsMoreIterations):
        nextBinHasChanged - dict of bin index -> True for bins where a written value differed from the input array
        nextOffsetCounts - dict of bin index -> '<u8' array with, per input, how many origins were written there
        needsMoreIterations - True if any written value differed from the input array
    '''
    (bID, binSize, vcLen, currOffsetCounts, placeArrayFN, nextPlaceArrayFN, bwtDirs) = tup
    
    #load things to run
    placeArray = np.load(placeArrayFN, 'r')
    nextPlaceArray = np.load(nextPlaceArrayFN, 'r+')
    
    numInputs = len(bwtDirs)
    msbwts = [MultiStringBWT.loadBWT(bwtDir) for bwtDir in bwtDirs]
    
    #state info we need to pass back
    nextBinHasChanged = {}
    nextOffsetCounts = {}
    needsMoreIterations = False
    #offset counts are returned as '<u8' so they can be added in place into uint64 tables
    noOrigins = np.zeros(dtype='<u8', shape=(numInputs,))
    
    #get the region and count the number of 0s, 1s, 2s, etc.
    region = placeArray[bID*binSize:(bID+1)*binSize]
    srcCounts = np.bincount(region, minlength=numInputs)
    
    #starting position of this bin inside each input BWT, as plain Python ints
    inputIndices = [int(v) for v in currOffsetCounts]
    
    #first extract the matching subregion from each bwt
    chunks = [msbwts[x].getBWTRange(inputIndices[x], inputIndices[x]+int(srcCounts[x])) for x in range(numInputs)]
    
    #count the number of characters of each
    bcs = [np.bincount(chunk, minlength=vcLen) for chunk in chunks]
    
    #interleave these character based on the region
    cArray = np.zeros(dtype='<u1', shape=(region.shape[0],))
    for x in range(numInputs):
        cArray[region == x] = chunks[x]
    
    #calculate curr using the MSBWT searches
    curr = np.zeros(dtype='<u8', shape=(vcLen,))
    for x in range(numInputs):
        curr += np.asarray(msbwts[x].getFullFMAtIndex(inputIndices[x]), dtype='<u8')
    
    #this is the equivalent of bin sorting just this small chunk
    for c in range(vcLen):
        totalC = int(sum(bc[c] for bc in bcs))
        if totalC == 0:
            continue
        
        #extract the origins for this character, in their order within the bin
        packed = region[cArray == c]
        
        #use plain Python ints for positions: numpy uint64 mixed with signed integers can promote to float64,
        #which is not a valid slice index
        start = int(curr[c])
        end = start + totalC
        
        #totalC <= region size <= binSize, so [start, end) touches at most two consecutive bins
        b1 = start // binSize
        b2 = end // binSize
        if b1 == b2:
            if not np.array_equal(placeArray[start:end], packed):
                nextBinHasChanged[b1] = True
                needsMoreIterations = True
            
            origins = np.bincount(packed, minlength=numInputs).astype('<u8')
            nextOffsetCounts[b1] = origins+nextOffsetCounts.get(b1, noOrigins)
            
        else:
            #b1 and b2 are different bins; split the comparison and counts at the bin boundary
            split = b2*binSize
            delta = split-start
            if not np.array_equal(placeArray[start:split], packed[:delta]):
                nextBinHasChanged[b1] = True
                needsMoreIterations = True
            
            if not np.array_equal(placeArray[split:end], packed[delta:]):
                nextBinHasChanged[b2] = True
                needsMoreIterations = True
            
            origins1 = np.bincount(packed[:delta], minlength=numInputs).astype('<u8')
            nextOffsetCounts[b1] = origins1+nextOffsetCounts.get(b1, noOrigins)
            
            origins2 = np.bincount(packed[delta:], minlength=numInputs).astype('<u8')
            nextOffsetCounts[b2] = origins2+nextOffsetCounts.get(b2, noOrigins)
        
        #this is where we actually do the updating
        nextPlaceArray[start:end] = packed
    
    #cleanup time
    del srcCounts
    del region
    del chunks
    del cArray
    gc.collect()
    
    return (bID, nextBinHasChanged, nextOffsetCounts, needsMoreIterations)