Example #1
0
def make_image(redband, greenband, blueband, rows, cols, enhance):
    """Assemble three bands into an RGB image scaled to [0, 1].

    Each band is flattened, passed through the contrast stretch selected by
    *enhance* and stored in one column of a ``(rows * cols, 3)`` uint8
    buffer, which is finally reshaped and divided by 255.

    Args:
        redband, greenband, blueband: arrays with ``rows * cols`` elements
            (anything offering ``ravel()``).
        rows, cols: dimensions of the returned image.
        enhance: 'linear255', 'linear', 'linear2pc', 'equalization' or
            'logarithmic'.  Any other value applies no stretch, so the
            buffer keeps its initial value of 1 (i.e. 1/255 after scaling).

    Returns:
        Float array of shape ``(rows, cols, 3)``.
    """
    # The lambdas keep every ``auxil`` lookup lazy: only the stretch that is
    # actually selected is resolved and called.
    stretches = {
        'linear255': lambda band: auxil.bytestr(band, [0, 255]),
        'linear': lambda band: auxil.linstr(band),
        'linear2pc': lambda band: auxil.lin2pcstr(band),
        'equalization': lambda band: auxil.histeqstr(band),
        # log scaling followed by the 2% linear stretch
        'logarithmic': lambda band: auxil.lin2pcstr(_log_scale(band)),
    }
    X = np.ones((rows * cols, 3), dtype=np.uint8)
    stretch = stretches.get(enhance)
    if stretch is not None:
        for i, band in enumerate((redband, greenband, blueband)):
            X[:, i] = stretch(band.ravel())
    return np.reshape(X, (rows, cols, 3)) / 255.


def _log_scale(band):
    """Return the natural log of *band*, rescaled to the range [0, 255].

    Works on a float copy: ``ravel()`` may hand back a view of the caller's
    array, and writing the mean / log values into an integer array would
    both modify the caller's data and truncate the values.
    """
    vals = band.astype(float).ravel()
    lowest = np.min(vals)
    if lowest < 0:
        vals = vals - lowest
    # get rid of black edges: zero pixels take the mean value
    vals[np.where(vals == 0)] = np.mean(vals)
    positive = np.where(vals > 0)
    vals[positive] = np.log(vals[positive])
    lowest = np.min(vals)
    highest = np.max(vals)
    if highest - lowest > 0:
        vals = (vals - lowest) * 255.0 / (highest - lowest)
    vals = np.where(vals < 0, 0, vals)
    vals = np.where(vals > 255, 255, vals)
    return vals
Example #2
0
def main():
    """Command-line entry point: ATWT panchromatic sharpening.

    Reads a multispectral (MS) image and a single-band panchromatic (PAN)
    image with GDAL, registers the MS bands to the degraded PAN image,
    injects each MS band into an ``auxil.ATWTArray`` built from the PAN
    image and writes the sharpened bands as Float32 to
    ``<msroot>_pan_atwt<ext>`` in the directory of the MS file.

    Options are read from ``sys.argv`` (see the usage text below).  The
    process exits with status 1 on a wrong argument count, an unreadable
    image, a multi-band PAN image or missing georeferencing.
    """
    usage = '''
Usage: 
--------------------------------------

Panchromatic sharpening with the a trous wavelet transform

python %s [OPTIONS] msfilename panfilename 

Options:
  -h            this help
  -p  <list>    RGB band positions to be sharpened (default all)
                               e.g. -p [1,2,3]
  -d  <list>    spatial subset [x,y,width,height] of ms image
                               e.g. -d [0,0,200,200]
  -r  <int>     resolution ratio ms:pan (default 4)
  -b  <int>     ms band for co-registration 
  
  -------------------------------------'''%sys.argv[0]

    options, args = getopt.getopt(sys.argv[1:], 'hd:p:r:b:')
    ratio = 4
    dims1 = None
    pos1 = None
    k1 = 1
    # NOTE(review): eval() runs arbitrary expressions taken from the command
    # line; acceptable for a locally run script, but consider
    # ast.literal_eval if the arguments can come from an untrusted source.
    for option, value in options:
        if option == '-h':
            print(usage)
            return
        elif option == '-r':
            ratio = eval(value)
        elif option == '-d':
            dims1 = eval(value)
        elif option == '-p':
            pos1 = eval(value)
        elif option == '-b':
            k1 = eval(value)
    if len(args) != 2:
        print('Incorrect number of arguments')
        print(usage)
        sys.exit(1)
    gdal.AllRegister()
    file1 = args[0]
    file2 = args[1]
    path = os.path.dirname(file1)
    basename1 = os.path.basename(file1)
    root1, ext1 = os.path.splitext(basename1)
    # os.path.join keeps the output in the current directory when the MS
    # file name has no directory part (plain '/' joining gave '/name').
    outfile = os.path.join(path, '%s_pan_atwt%s' % (root1, ext1))
    # MS image
    inDataset1 = gdal.Open(file1, GA_ReadOnly)
    try:
        cols = inDataset1.RasterXSize
        rows = inDataset1.RasterYSize
        bands = inDataset1.RasterCount
    except Exception as e:
        # %s, not %e: the exception object is not a number
        print('Error: %s --Image could not be read' % e)
        sys.exit(1)
    if pos1 is None:
        pos1 = range(1, bands + 1)
    num_bands = len(pos1)
    if dims1 is None:
        dims1 = [0, 0, cols, rows]
    x10, y10, cols1, rows1 = dims1
    # PAN image
    inDataset2 = gdal.Open(file2, GA_ReadOnly)
    try:
        bands = inDataset2.RasterCount
    except Exception as e:
        print('Error: %s --Image could not be read' % e)
        sys.exit(1)
    if bands > 1:
        print('PAN image must be a single band')
        sys.exit(1)
    geotransform1 = inDataset1.GetGeoTransform()
    geotransform2 = inDataset2.GetGeoTransform()
    if (geotransform1 is None) or (geotransform2 is None):
        print('Image not georeferenced, aborting')
        sys.exit(1)
    print('=========================')
    print('   ATWT Pansharpening')
    print('=========================')
    print(time.asctime())
    print('MS  file: ' + file1)
    print('PAN file: ' + file2)
    # read in MS image; a 1x1 read of band 1 supplies the array dtype
    band = inDataset1.GetRasterBand(1)
    tmp = band.ReadAsArray(0, 0, 1, 1)
    dt = tmp.dtype
    MS = np.zeros((num_bands, rows1, cols1), dtype=dt)
    for k, b in enumerate(pos1):
        band = inDataset1.GetRasterBand(b)
        MS[k, :, :] = band.ReadAsArray(x10, y10, cols1, rows1)
    # if integer assume 11-bit quantization, otherwise must be byte
    if MS.dtype == np.int16:
        fact = 8.0
        MS = auxil.bytestr(MS, (0, 2**11))
    else:
        fact = 1.0
    # read in corresponding spatial subset of PAN image:
    # map coordinates of the MS subset's upper left corner ...
    gt1 = list(geotransform1)
    gt2 = list(geotransform2)
    ulx1 = gt1[0] + x10 * gt1[1]
    uly1 = gt1[3] + y10 * gt1[5]
    # ... converted to a pixel offset in the PAN image
    x20 = int(round(((ulx1 - gt2[0]) / gt2[1])))
    y20 = int(round(((uly1 - gt2[3]) / gt2[5])))
    cols2 = cols1 * ratio
    rows2 = rows1 * ratio
    band = inDataset2.GetRasterBand(1)
    PAN = band.ReadAsArray(x20, y20, cols2, rows2)
    # if integer assume 11-bit quantization, otherwise must be byte
    if PAN.dtype == np.int16:
        PAN = auxil.bytestr(PAN, (0, 2**11))
    # out array
    sharpened = np.zeros((num_bands, rows2, cols2), dtype=np.float32)
    # compress PAN to resolution of MS image using DWT: one filter pass per
    # halving of the ratio (floor division keeps r an int on Python 3 too)
    panDWT = auxil.DWTArray(PAN, cols2, rows2)
    r = ratio
    while r > 1:
        panDWT.filter()
        r //= 2
    bn0 = panDWT.get_quadrant(0)
    # register (and subset) MS image to compressed PAN image using the
    # selected MS band (-b is 1-based)
    lines0, samples0 = bn0.shape
    bn1 = MS[k1 - 1, :, :]
    (scale, angle, shift) = auxil.similarity(bn0, bn1)
    tmp = np.zeros((num_bands, lines0, samples0))
    for k in range(num_bands):
        bn1 = MS[k, :, :]
        bn2 = ndii.zoom(bn1, 1.0 / scale)
        bn2 = ndii.rotate(bn2, angle)
        bn2 = ndii.shift(bn2, shift)
        tmp[k, :, :] = bn2[0:lines0, 0:samples0]
    MS = tmp
    # random pixel sample used for the regression of wavelet details
    smpl = np.random.randint(cols2 * rows2, size=100000)
    print('Wavelet correlations:')
    # loop over MS bands
    for k in range(num_bands):
        msATWT = auxil.ATWTArray(PAN)
        r = ratio
        while r > 1:
            msATWT.filter()
            r //= 2
        # sample PAN wavelet details
        X = msATWT.get_band(msATWT.num_iter)
        X = X.ravel()[smpl]
        # resize the ms band to scale of the pan image
        ms_band = ndii.zoom(MS[k, :, :], ratio)
        # sample details of MS band
        tmpATWT = auxil.ATWTArray(ms_band)
        r = ratio
        while r > 1:
            tmpATWT.filter()
            r //= 2
        Y = tmpATWT.get_band(msATWT.num_iter)
        Y = Y.ravel()[smpl]
        # get band for injection
        bnd = tmpATWT.get_band(0)
        tmpATWT = None
        aa, bb, R = auxil.orthoregress(X, Y)
        print('Band ' + str(k + 1) + ': %8.3f' % R)
        # inject the filtered MS band
        msATWT.inject(bnd)
        # normalize wavelet components and expand
        msATWT.normalize(aa, bb)
        r = ratio
        while r > 1:
            msATWT.invert()
            r //= 2
        sharpened[k, :, :] = msATWT.get_band(0)
    sharpened *= fact  # rescale dynamic range
    msATWT = None
    # write to disk
    driver = inDataset1.GetDriver()
    outDataset = driver.Create(outfile, cols2, rows2, num_bands, GDT_Float32)
    # Output grid: PAN pixel size and rotation terms, origin at the map
    # position of the MS subset's upper left corner.  The offset has to be
    # in map units (pixel offset times MS pixel size), not a pixel count.
    gt_out = list(gt2)
    gt_out[0] = ulx1
    gt_out[3] = uly1
    outDataset.SetGeoTransform(tuple(gt_out))
    projection1 = inDataset1.GetProjection()
    if projection1 is not None:
        outDataset.SetProjection(projection1)
    for k in range(num_bands):
        outBand = outDataset.GetRasterBand(k + 1)
        outBand.WriteArray(sharpened[k, :, :], 0, 0)
        outBand.FlushCache()
    # dropping the references closes the datasets
    outDataset = None
    print('Result written to %s' % outfile)
    inDataset1 = None
    inDataset2 = None
Example #3
0
def main():
    """Command-line entry point: mean shift segmentation of an image.

    Reads the selected bands (and optional spatial subset) of the input
    image with GDAL, repeatedly calls ``mean_shift`` to find modes and label
    pixels, then writes a GeoTIFF ``<root>_meanshift<ext>`` in the input's
    directory with ``nb + 2`` Float32 bands: the ``nb`` mode-filtered bands,
    the (median filtered) label image and a segment boundary mask (255 on
    boundaries).

    Options are read from ``sys.argv`` (see the usage text below).  Exits
    with status 1 when the number of positional arguments is not one.
    """
    usage = '''
Usage: 
--------------------------------------

Segment a multispectral image with mean shift 

python %s [OPTIONS] filename

Options:
  -h            this help
  -p  <list>    band positions e.g. -p [1,2,3,4,5,7]
  -d  <list>    spatial subset [x,y,width,height] 
                              e.g. -d [0,0,200,200]
  -r  <int>     spectral bandwidth (default 15)
  -s  <int>     spatial bandwidth (default 15)
  -m  <int>     minimum segment size (default 30) 

  -------------------------------------''' % sys.argv[0]

    options, args = getopt.getopt(sys.argv[1:], 'hs:r:m:d:p:')
    dims = None
    pos = None
    hs = 15
    hr = 15
    minseg = 30
    # NOTE(review): eval() runs arbitrary expressions taken from the command
    # line; acceptable for a locally run script, but consider
    # ast.literal_eval if the arguments can come from an untrusted source.
    for option, value in options:
        if option == '-h':
            print(usage)
            return
        elif option == '-d':
            dims = eval(value)
        elif option == '-p':
            pos = eval(value)
        elif option == '-s':
            hs = eval(value)
        elif option == '-r':
            hr = eval(value)
        elif option == '-m':
            minseg = eval(value)
    # a missing file name would otherwise surface as an IndexError below
    if len(args) != 1:
        print('Incorrect number of arguments')
        print(usage)
        sys.exit(1)
    gdal.AllRegister()
    infile = args[0]
    inDataset = gdal.Open(infile, GA_ReadOnly)
    nc = inDataset.RasterXSize
    nr = inDataset.RasterYSize
    nb = inDataset.RasterCount
    if dims:
        x0, y0, nc, nr = dims
    else:
        x0 = 0
        y0 = 0
    if pos is not None:
        nb = len(pos)
    else:
        pos = range(1, nb + 1)
    m = nc * nr
    path = os.path.dirname(infile)
    basename = os.path.basename(infile)
    root, ext = os.path.splitext(basename)
    # os.path.join keeps the output in the current directory when the input
    # name has no directory part (plain '/' joining gave '/name').
    outfile = os.path.join(path, root + '_meanshift' + ext)
    print('=========================')
    print('    mean shift')
    print('=========================')
    print('infile: %s' % infile)
    start = time.time()
    # input image: nb spectral layers plus two layers for pixel coordinates
    # (builtin float/int replace the np.float/np.int aliases, which no
    # longer exist in current NumPy)
    data = np.zeros((nr, nc, nb + 2), dtype=float)
    for k, b in enumerate(pos):
        band = inDataset.GetRasterBand(b)
        data[:, :, k] = auxil.bytestr(band.ReadAsArray(x0, y0, nc, nr))
    # normalize spatial/spectral
    data = data * hs / hr
    ij = np.array(range(nr * nc))
    data[:, :, nb] = np.reshape(ij % nc, (nr, nc))  # x-coord of (i,j) = j
    # floor division: the row index must stay integral on Python 3 as well
    data[:, :, nb + 1] = np.reshape(ij // nc, (nr, nc))  # y-coord of (i,j) = i
    # row 0 of modes is a dummy so that label 0 can mean "not yet labeled"
    modes = [np.zeros(nb + 2)]
    labeled = np.zeros(m, dtype=int)
    idx = 0
    idx_max = 1000
    label = 0
    # loop over all pixels
    print('filtering pixels...')
    while idx < m:
        mode, cpts, cpts_max = mean_shift(data, idx, hs, nc, nr, nb)
        idx_max = max(idx_max, cpts_max)
        # squared distance to nearest neighbor
        dd = np.sum((mode - modes)**2, 1)
        d2 = np.min(dd)
        # label of nearest neighbor (= its row index in modes)
        l_nn = np.argmin(dd)
        # indices of pixels to be labeled
        indices = idx + np.intersect1d(
            np.where(cpts[idx:idx_max] > 0)[0],
            np.where(labeled[idx:idx_max] == 0)[0])
        count = indices.size
        if count > 0:
            # label pixels
            if ((count < minseg) or (d2 < hs**2)) and (l_nn != 0):
                labeled[indices] = l_nn
            else:
                # New segment: its label has to be the row index of the mode
                # just appended, because labels are matched against row
                # indices both above (l_nn) and when the filtered image is
                # built below.  Hence increment before assigning.
                modes = np.append(modes, [mode], axis=0)
                label += 1
                labeled[indices] = label
            # find the next unlabeled pixel
            nxt = idx + np.where(labeled[idx:idx_max] == 0)[0]
            count = nxt.size
            if count > 0:
                idx = np.min(nxt)
            else:
                # done
                idx = m
        else:
            idx += 1

    # write to disk
    driver = gdal.GetDriverByName('GTiff')
    outDataset = driver.Create(outfile, nc, nr, nb + 2, GDT_Float32)
    projection = inDataset.GetProjection()
    geotransform = inDataset.GetGeoTransform()
    if geotransform is not None:
        # shift the origin to the upper left corner of the spatial subset
        gt = list(geotransform)
        gt[0] = gt[0] + x0 * gt[1]
        gt[3] = gt[3] + y0 * gt[5]
        outDataset.SetGeoTransform(tuple(gt))
    if projection is not None:
        outDataset.SetProjection(projection)

    labeled = filters.median_filter(np.reshape(labeled, (nr, nc)), 3)
    boundaries = np.zeros(m)

    # A pixel lies on a boundary when its label differs from the neighbour
    # one row up or one column to the left.  The axis must be given: with a
    # tuple shift and no axis np.roll works on the flattened array, so both
    # differences were the same one-element shift.
    xx = (labeled - np.roll(labeled, 1, axis=0)).ravel()
    yy = (labeled - np.roll(labeled, 1, axis=1)).ravel()
    idx1 = np.where(xx != 0)[0]
    idx2 = np.where(yy != 0)[0]
    idx = np.union1d(idx1, idx2)
    boundaries[idx] = 255

    labeled = np.reshape(labeled, m)

    # replace every pixel by the spectral part of its segment's mode
    filtered = np.zeros((m, nb))
    labels = modes.shape[0]
    for lbl in range(labels):
        indices = np.where(labeled == lbl)[0]
        filtered[indices, :] = modes[lbl, :nb]

    for k in range(nb):
        outBand = outDataset.GetRasterBand(k + 1)
        outBand.WriteArray(np.reshape(filtered[:, k], (nr, nc)), 0, 0)
        outBand.FlushCache()
    outBand = outDataset.GetRasterBand(nb + 1)
    outBand.WriteArray(np.reshape(labeled, (nr, nc)), 0, 0)
    outBand.FlushCache()
    outBand = outDataset.GetRasterBand(nb + 2)
    outBand.WriteArray(np.reshape(boundaries, (nr, nc)), 0, 0)
    outBand.FlushCache()

    # dropping the references closes the datasets
    outDataset = None
    inDataset = None
    print('result written to: ' + outfile)
    print('elapsed time: ' + str(time.time() - start))
Example #4
0
def main():
    """Command-line entry point: DWT panchromatic sharpening.

    Reads a multispectral (MS) image and a single-band panchromatic (PAN)
    image with GDAL, registers the MS bands to the degraded PAN image,
    substitutes each MS band into quadrant 0 of the PAN ``auxil.DWTArray``,
    normalizes the wavelet coefficients, inverts the transform and writes
    the sharpened bands as Float32 to ``<msroot>_pan_dwt<ext>`` in the
    directory of the MS file.

    Options are read from ``sys.argv`` (see the usage text below).  The
    process exits with status 1 on a wrong argument count, an unreadable
    image, a multi-band PAN image or missing georeferencing.
    """
    usage = '''
Usage: 
--------------------------------------

Panchromatic sharpening with the a discrete transform

python %s [OPTIONS] msfilename panfilename 

Options:
  -h            this help
  -p  <list>    RGB band positions to be sharpened (default all)
                               e.g. -p [1,2,3]
  -d  <list>    spatial subset [x,y,width,height] of ms image
                               e.g. -d [0,0,200,200]
  -r  <int>     resolution ratio ms:pan (default 4)
  -b  <int>     ms band for co-registration 
  
  -------------------------------------'''%sys.argv[0]

    options, args = getopt.getopt(sys.argv[1:], 'hd:p:r:b:')
    ratio = 4
    dims1 = None
    pos1 = None
    k1 = 0
    # NOTE(review): eval() runs arbitrary expressions taken from the command
    # line; acceptable for a locally run script, but consider
    # ast.literal_eval if the arguments can come from an untrusted source.
    for option, value in options:
        if option == '-h':
            print(usage)
            return
        elif option == '-r':
            ratio = eval(value)
        elif option == '-d':
            dims1 = eval(value)
        elif option == '-p':
            pos1 = eval(value)
        elif option == '-b':
            # -b is 1-based, k1 is an array index
            k1 = eval(value) - 1
    if len(args) != 2:
        print('Incorrect number of arguments')
        print(usage)
        sys.exit(1)
    gdal.AllRegister()
    file1 = args[0]
    file2 = args[1]
    path = os.path.dirname(file1)
    basename1 = os.path.basename(file1)
    root1, ext1 = os.path.splitext(basename1)
    # os.path.join keeps the output in the current directory when the MS
    # file name has no directory part (plain '/' joining gave '/name').
    outfile = os.path.join(path, '%s_pan_dwt%s' % (root1, ext1))
    # MS image
    inDataset1 = gdal.Open(file1, GA_ReadOnly)
    try:
        cols = inDataset1.RasterXSize
        rows = inDataset1.RasterYSize
        bands = inDataset1.RasterCount
    except Exception as e:
        # %s, not %e: the exception object is not a number
        print('Error: %s --Image could not be read' % e)
        sys.exit(1)
    if pos1 is None:
        pos1 = range(1, bands + 1)
    num_bands = len(pos1)
    if dims1 is None:
        dims1 = [0, 0, cols, rows]
    x10, y10, cols1, rows1 = dims1
    # PAN image
    inDataset2 = gdal.Open(file2, GA_ReadOnly)
    try:
        bands = inDataset2.RasterCount
    except Exception as e:
        print('Error: %s --Image could not be read' % e)
        sys.exit(1)
    if bands > 1:
        print('PAN image must be a single band')
        sys.exit(1)
    geotransform1 = inDataset1.GetGeoTransform()
    geotransform2 = inDataset2.GetGeoTransform()
    if (geotransform1 is None) or (geotransform2 is None):
        print('Image not georeferenced, aborting')
        sys.exit(1)
    print('=========================')
    print('   DWT Pansharpening')
    print('=========================')
    print(time.asctime())
    print('MS  file: ' + file1)
    print('PAN file: ' + file2)
    # image arrays; a 1x1 read of band 1 supplies the array dtype
    band = inDataset1.GetRasterBand(1)
    tmp = band.ReadAsArray(0, 0, 1, 1)
    dt = tmp.dtype
    MS = np.zeros((num_bands, rows1, cols1), dtype=dt)
    for k, b in enumerate(pos1):
        band = inDataset1.GetRasterBand(b)
        MS[k, :, :] = band.ReadAsArray(x10, y10, cols1, rows1)
    # if integer assume 11bit quantization otherwise must be byte
    if MS.dtype == np.int16:
        fact = 8.0
        MS = auxil.bytestr(MS, (0, 2**11))
    else:
        fact = 1.0
    # read in corresponding spatial subset of PAN image:
    # map coordinates of the MS subset's upper left corner ...
    gt1 = list(geotransform1)
    gt2 = list(geotransform2)
    ulx1 = gt1[0] + x10 * gt1[1]
    uly1 = gt1[3] + y10 * gt1[5]
    # ... converted to the upper left corner pixel in PAN
    x20 = int(round(((ulx1 - gt2[0]) / gt2[1])))
    y20 = int(round(((uly1 - gt2[3]) / gt2[5])))
    cols2 = cols1 * ratio
    rows2 = rows1 * ratio
    band = inDataset2.GetRasterBand(1)
    PAN = band.ReadAsArray(x20, y20, cols2, rows2)
    # if integer assume 11-bit quantization, otherwise must be byte
    if PAN.dtype == np.int16:
        PAN = auxil.bytestr(PAN, (0, 2**11))
    # compress PAN to resolution of MS image: one filter pass per halving
    # of the ratio (floor division keeps r an int on Python 3 too)
    panDWT = auxil.DWTArray(PAN, cols2, rows2)
    r = ratio
    while r > 1:
        panDWT.filter()
        r //= 2
    bn0 = panDWT.get_quadrant(0)
    lines0, samples0 = bn0.shape
    bn1 = MS[k1, :, :]
    # register (and subset) MS image to compressed PAN image
    (scale, angle, shift) = auxil.similarity(bn0, bn1)
    tmp = np.zeros((num_bands, lines0, samples0))
    for k in range(num_bands):
        bn1 = MS[k, :, :]
        bn2 = ndii.zoom(bn1, 1.0 / scale)
        bn2 = ndii.rotate(bn2, angle)
        bn2 = ndii.shift(bn2, shift)
        tmp[k, :, :] = bn2[0:lines0, 0:samples0]
    MS = tmp
    # compress pan once more, extract wavelet quadrants, and restore
    panDWT.filter()
    fgpan = panDWT.get_quadrant(1)
    gfpan = panDWT.get_quadrant(2)
    ggpan = panDWT.get_quadrant(3)
    panDWT.invert()
    # output array
    sharpened = np.zeros((num_bands, rows2, cols2), dtype=np.float32)
    # regression coefficients, one pair per detail quadrant (1, 2, 3)
    aa = np.zeros(3)
    bb = np.zeros(3)
    print('Wavelet correlations:')
    for i in range(num_bands):
        # make copy of panDWT and inject ith ms band
        msDWT = copy.deepcopy(panDWT)
        msDWT.put_quadrant(MS[i, :, :], 0)
        # compress once more
        msDWT.filter()
        # determine wavelet normalization coefficients
        ms = msDWT.get_quadrant(1)
        aa[0], bb[0], R = auxil.orthoregress(fgpan.ravel(), ms.ravel())
        Rs = 'Band ' + str(i + 1) + ': %8.3f' % R
        ms = msDWT.get_quadrant(2)
        aa[1], bb[1], R = auxil.orthoregress(gfpan.ravel(), ms.ravel())
        Rs += '%8.3f' % R
        ms = msDWT.get_quadrant(3)
        aa[2], bb[2], R = auxil.orthoregress(ggpan.ravel(), ms.ravel())
        Rs += '%8.3f' % R
        print(Rs)
        # restore once and normalize wavelet coefficients
        msDWT.invert()
        msDWT.normalize(aa, bb)
        # restore completely and collect result
        r = 1
        while r < ratio:
            msDWT.invert()
            r *= 2
        sharpened[i, :, :] = msDWT.get_quadrant(0)
    sharpened *= fact  # rescale dynamic range
    # write to disk
    driver = inDataset1.GetDriver()
    outDataset = driver.Create(outfile, cols2, rows2, num_bands, GDT_Float32)
    projection1 = inDataset1.GetProjection()
    if projection1 is not None:
        outDataset.SetProjection(projection1)
    # Output grid: PAN pixel size and rotation terms, origin at the map
    # position of the MS subset's upper left corner.  The offset has to be
    # in map units (pixel offset times MS pixel size), not a pixel count.
    gt_out = list(gt2)
    gt_out[0] = ulx1
    gt_out[3] = uly1
    outDataset.SetGeoTransform(tuple(gt_out))
    for k in range(num_bands):
        outBand = outDataset.GetRasterBand(k + 1)
        outBand.WriteArray(sharpened[k, :, :], 0, 0)
        outBand.FlushCache()
    # dropping the references closes the datasets
    outDataset = None
    print('Result written to %s' % outfile)
    inDataset1 = None
    inDataset2 = None