示例#1
0
def main():
    """Command-line entry point for ATWT pansharpening.

    Parses ``sys.argv`` (options ``-h``, ``-d``, ``-p``, ``-r``, ``-b`` followed
    by the MS and PAN file names), reads the requested spatial/spectral subset
    of the multispectral (MS) image and the corresponding window of the
    single-band PAN image, registers the MS bands to the degraded PAN image,
    injects each MS band into the ATWT of the PAN image and writes the result
    as a Float32 raster named ``<msfilename>_pan_atwt<ext>`` next to the MS
    file, using the MS file's driver.

    Returns None. Calls ``sys.exit(1)`` on a wrong argument count, on
    unreadable images, on a multi-band PAN image, on missing georeferencing
    and when the PAN window cannot be read.
    """
    usage = '''
Usage:
-----------------------------------------------------------------------
python %s [-d spatialDimensions] [-p bandPositions] [-r resolution ratio]
[-b registration band]  msfilename panfilename
-----------------------------------------------------------------------
bandPositions and spatialDimensions are lists,
e.g., -p [1,2,3] -d [0,0,400,400]

Outfile name is msfilename_pan_atwt with same format as msfilename

Note: PAN image must completely overlap MS image subset chosen
-----------------------------------------------------''' % sys.argv[0]
    options, args = getopt.getopt(sys.argv[1:], 'hd:p:r:b:')
    ratio = 4       # PAN/MS resolution ratio
    dims1 = None    # [x0, y0, cols, rows] subset of the MS image
    pos1 = None     # 1-based MS band positions
    k1 = 1          # 1-based index (within pos1) of the registration band
    # NOTE(review): eval() executes whatever expression is given on the
    # command line; acceptable for a locally run script, but do not feed it
    # untrusted input (ast.literal_eval would be the safe alternative).
    for option, value in options:
        if option == '-h':
            print(usage)
            return
        elif option == '-r':
            ratio = eval(value)
        elif option == '-d':
            dims1 = eval(value)
        elif option == '-p':
            pos1 = eval(value)
        elif option == '-b':
            k1 = eval(value)
    if len(args) != 2:
        print('Incorrect number of arguments')
        print(usage)
        sys.exit(1)
    gdal.AllRegister()
    file1, file2 = args
    path = os.path.dirname(file1)
    root1, ext1 = os.path.splitext(os.path.basename(file1))
    # os.path.join keeps the output beside the MS file even when file1 has no
    # directory component (plain '%s/%s' would point at the filesystem root)
    outfile = os.path.join(path, '%s_pan_atwt%s' % (root1, ext1))
#  MS image
    inDataset1 = gdal.Open(file1, GA_ReadOnly)
    try:
        cols = inDataset1.RasterXSize
        rows = inDataset1.RasterYSize
        bands = inDataset1.RasterCount
    except Exception as e:
        # '%s', not '%e': the argument is an exception, not a float
        print('Error: %s --Image could not be read' % e)
        sys.exit(1)
    if pos1 is None:
        pos1 = range(1, bands+1)
    num_bands = len(pos1)
    if dims1 is None:
        dims1 = [0, 0, cols, rows]
    x10, y10, cols1, rows1 = dims1
#  PAN image
    inDataset2 = gdal.Open(file2, GA_ReadOnly)
    try:
        pan_bands = inDataset2.RasterCount
    except Exception as e:
        print('Error: %s --Image could not be read' % e)
        sys.exit(1)
    if pan_bands > 1:
        print('PAN image must be a single band')
        sys.exit(1)
    geotransform1 = inDataset1.GetGeoTransform()
    geotransform2 = inDataset2.GetGeoTransform()
    if (geotransform1 is None) or (geotransform2 is None):
        print('Image not georeferenced, aborting')
        sys.exit(1)
    print('=========================')
    print('   ATWT Pansharpening')
    print('=========================')
    print(time.asctime())
    print('MS  file: '+file1)
    print('PAN file: '+file2)
#  read in MS image, keeping the data type of the first band
    dt = inDataset1.GetRasterBand(1).ReadAsArray(0, 0, 1, 1).dtype
    MS = np.zeros((num_bands, rows1, cols1), dtype=dt)
    for k, b in enumerate(pos1):
        band = inDataset1.GetRasterBand(b)
        MS[k, :, :] = band.ReadAsArray(x10, y10, cols1, rows1)
#  if integer assume 11-bit quantization, otherwise must be byte
    if MS.dtype == np.int16:
        fact = 8.0
        MS = auxil.byteStretch(MS, (0, 2**11))
    else:
        fact = 1.0
#  read in corresponding spatial subset of PAN image
    gt1 = list(geotransform1)
    gt2 = list(geotransform2)
    # map coordinates of the upper left corner of the MS subset
    # (rotation terms gt[2], gt[4] are ignored here, as in the PAN lookup)
    ulx1 = gt1[0] + x10*gt1[1]
    uly1 = gt1[3] + y10*gt1[5]
    # the same corner expressed in PAN pixel coordinates
    x20 = int(round((ulx1 - gt2[0])/gt2[1]))
    y20 = int(round((uly1 - gt2[3])/gt2[5]))
    cols2 = cols1*ratio
    rows2 = rows1*ratio
    PAN = inDataset2.GetRasterBand(1).ReadAsArray(x20, y20, cols2, rows2)
    if PAN is None:
        # the read fails when the window is not fully inside the PAN raster
        print('PAN image does not cover the chosen MS image subset')
        sys.exit(1)
#  if integer assume 11-bit quantization, otherwise must be byte
    if PAN.dtype == np.int16:
        PAN = auxil.byteStretch(PAN, (0, 2**11))
#  out array
    sharpened = np.zeros((num_bands, rows2, cols2), dtype=np.float32)
#  compress PAN to resolution of MS image using DWT
    panDWT = auxil.DWTArray(PAN, cols2, rows2)
    r = ratio
    while r > 1:
        panDWT.filter()
        r //= 2     # floor division: same level count under Python 2 and 3
    bn0 = panDWT.get_quadrant(0)
#  register (and subset) MS image to compressed PAN image using selected MS band
    lines0, samples0 = bn0.shape
    bn1 = MS[k1-1, :, :]
    (scale, angle, shift) = auxil.similarity(bn0, bn1)
    tmp = np.zeros((num_bands, lines0, samples0))
    for k in range(num_bands):
        bn2 = ndii.zoom(MS[k, :, :], 1.0/scale)
        bn2 = ndii.rotate(bn2, angle)
        bn2 = ndii.shift(bn2, shift)
        tmp[k, :, :] = bn2[0:lines0, 0:samples0]
    MS = tmp
    # random pixel sample shared by all bands for the regression below
    smpl = np.random.randint(cols2*rows2, size=100000)
    print('Wavelet correlations:')
#  loop over MS bands
    for k in range(num_bands):
        # a fresh PAN transform per band: it is modified by inject/normalize
        msATWT = auxil.ATWTArray(PAN)
        r = ratio
        while r > 1:
            msATWT.filter()
            r //= 2
#      sample PAN wavelet details
        X = msATWT.get_band(msATWT.num_iter)
        X = X.ravel()[smpl]
#      resize the ms band to scale of the pan image
        ms_band = ndii.zoom(MS[k, :, :], ratio)
#      sample details of MS band
        tmpATWT = auxil.ATWTArray(ms_band)
        r = ratio
        while r > 1:
            tmpATWT.filter()
            r //= 2
        Y = tmpATWT.get_band(msATWT.num_iter)
        Y = Y.ravel()[smpl]
#      get band for injection
        bnd = tmpATWT.get_band(0)
        tmpATWT = None
        aa, bb, R = auxil.orthoregress(X, Y)
        print('Band '+str(k+1)+': %8.3f' % R)
#      inject the filtered MS band
        msATWT.inject(bnd)
#      normalize wavelet components and expand
        msATWT.normalize(aa, bb)
        r = ratio
        while r > 1:
            msATWT.invert()
            r //= 2
        sharpened[k, :, :] = msATWT.get_band(0)
    sharpened *= fact  # rescale dynamic range
    msATWT = None
#  write to disk
    driver = inDataset1.GetDriver()
    outDataset = driver.Create(outfile, cols2, rows2, num_bands, GDT_Float32)
    # output grid: PAN pixel size/rotation, origin at the MS subset corner.
    # The origin offset must be in map units (x10*gt1[1]); offsetting by
    # x10*ratio is only right when the PAN pixel size is exactly 1 unit.
    gt_out = list(gt2)
    gt_out[0] = ulx1
    gt_out[3] = uly1
    outDataset.SetGeoTransform(tuple(gt_out))
    projection1 = inDataset1.GetProjection()
    if projection1 is not None:
        outDataset.SetProjection(projection1)
    for k in range(num_bands):
        outBand = outDataset.GetRasterBand(k+1)
        outBand.WriteArray(sharpened[k, :, :], 0, 0)
        outBand.FlushCache()
    outDataset = None   # dropping the reference closes the GDAL dataset
    print('Result written to %s' % outfile)
    inDataset1 = None
    inDataset2 = None
示例#2
0
def register(fn1, fn2, warpband, dims1=None, outfile=None):
    """Register (warp) image ``fn2`` to reference image ``fn1``.

    A similarity transform (scale, angle, shift) is estimated from band
    ``warpband`` of both images and then applied to every band of ``fn2``.
    The result is written as a Float32 raster with the driver of ``fn2`` and
    the projection/geotransform of ``fn1``.

    Parameters:
        fn1:      file name of the reference image.
        fn2:      file name of the image to be warped.
        warpband: 1-based band number used in both images to estimate the
                  transform.
        dims1:    optional [x0, y0, cols, rows] subset of the reference image;
                  the whole reference image is used when None.
        outfile:  output file name; defaults to ``<fn2 root>_warp<ext>`` in the
                  directory of ``fn2``.

    Returns:
        The output file name, or None if an exception occurred. Terminates
        the process with ``sys.exit(1)`` if the raster sizes cannot be read.
    """
    gdal.AllRegister()
    print('--------------------------------')
    print('        Register')
    print('---------------------------------')
    print(time.asctime())
    print('reference image: '+fn1)
    print('warp image: '+fn2)
    print('warp band: %i' % warpband)

    start = time.time()
    try:
        if outfile is None:
            path2 = os.path.dirname(fn2)
            root2, ext2 = os.path.splitext(os.path.basename(fn2))
            # os.path.join: with an empty dirname, path2 + '/' + ... would
            # place the output in the filesystem root
            outfile = os.path.join(path2, root2 + '_warp' + ext2)
        inDataset1 = gdal.Open(fn1, GA_ReadOnly)
        inDataset2 = gdal.Open(fn2, GA_ReadOnly)
        try:
            cols1 = inDataset1.RasterXSize
            rows1 = inDataset1.RasterYSize
            cols2 = inDataset2.RasterXSize
            rows2 = inDataset2.RasterYSize
            bands2 = inDataset2.RasterCount
        except Exception as e:
            print('Error %s  --Image could not be read in' % e)
            sys.exit(1)
        if dims1 is None:
            x0 = 0
            y0 = 0
        else:
            x0, y0, cols1, rows1 = dims1

        # the same pixel window is read from both images
        band = inDataset1.GetRasterBand(warpband)
        refband = band.ReadAsArray(x0, y0, cols1, rows1).astype(np.float32)
        band = inDataset2.GetRasterBand(warpband)
        # separate name so the integer ``warpband`` argument is not clobbered
        warp_arr = band.ReadAsArray(x0, y0, cols1, rows1).astype(np.float32)

    #  similarity transform parameters for reference band number
        scale, angle, shift = similarity(refband, warp_arr)

        driver = inDataset2.GetDriver()
        outDataset = driver.Create(outfile, cols1, rows1, bands2, GDT_Float32)
        projection = inDataset1.GetProjection()
        geotransform = inDataset1.GetGeoTransform()
        if geotransform is not None:
            # move the origin to the upper left corner of the subset
            gt = list(geotransform)
            gt[0] = gt[0] + x0*gt[1]
            gt[3] = gt[3] + y0*gt[5]
            outDataset.SetGeoTransform(tuple(gt))
        if projection is not None:
            outDataset.SetProjection(projection)

    #  warp: transform each full band of fn2, then clip to the subset window
        for k in range(bands2):
            inband = inDataset2.GetRasterBand(k+1)
            outBand = outDataset.GetRasterBand(k+1)
            bn1 = inband.ReadAsArray(0, 0, cols2, rows2).astype(np.float32)
            bn2 = ndii.zoom(bn1, 1.0 / scale)
            bn2 = ndii.rotate(bn2, angle)
            bn2 = ndii.shift(bn2, shift)
            outBand.WriteArray(bn2[y0:y0+rows1, x0:x0+cols1])
            outBand.FlushCache()
        # dropping the references closes the GDAL datasets
        inDataset1 = None
        inDataset2 = None
        outDataset = None
        print('Warped image written to: %s' % outfile)
        print('elapsed time: %s' % str(time.time()-start))
        return outfile
    except Exception as e:
        print('registersms failed: %s' % e)
        return None
示例#3
0
def main():
    """Command-line entry point for DWT pansharpening.

    Parses ``sys.argv`` (options ``-h``, ``-d``, ``-p``, ``-r``, ``-b`` followed
    by the MS and PAN file names), reads the requested subset of the
    multispectral (MS) image and the corresponding window of the single-band
    PAN image, registers the MS bands to the DWT-compressed PAN image,
    replaces the low-pass quadrant of the PAN transform by each MS band,
    normalizes the wavelet coefficients and inverts the transform. The result
    is written as a Float32 raster named ``<msfilename>_pan_dwt<ext>`` next to
    the MS file, using the MS file's driver.

    Returns None. Calls ``sys.exit(1)`` on a wrong argument count, on
    unreadable images, on a multi-band PAN image, on missing georeferencing
    and when the PAN window cannot be read.
    """
    usage = '''
Usage:
-----------------------------------------------------------------------
python %s [-d spatialDimensions] [-p bandPositions] [-r resolution ratio]
[-b registration band]  msfilename panfilename
-----------------------------------------------------------------------
bandPositions and spatialDimensions are lists,
e.g., -p [1,2,3] -d [0,0,400,400]

Outfile name is msfilename_pan_dwt with same format as msfilename

Note: PAN image must completely overlap MS image subset chosen
-----------------------------------------------------''' % sys.argv[0]
    options, args = getopt.getopt(sys.argv[1:], 'hd:p:r:b:')
    ratio = 4       # PAN/MS resolution ratio
    dims1 = None    # [x0, y0, cols, rows] subset of the MS image
    pos1 = None     # 1-based MS band positions
    k1 = 0          # 0-based index (within pos1) of the registration band
    # NOTE(review): eval() executes whatever expression is given on the
    # command line; acceptable for a locally run script, but do not feed it
    # untrusted input (ast.literal_eval would be the safe alternative).
    for option, value in options:
        if option == '-h':
            print(usage)
            return
        elif option == '-r':
            ratio = eval(value)
        elif option == '-d':
            dims1 = eval(value)
        elif option == '-p':
            pos1 = eval(value)
        elif option == '-b':
            k1 = eval(value)-1   # the option is 1-based
    if len(args) != 2:
        print('Incorrect number of arguments')
        print(usage)
        sys.exit(1)
    gdal.AllRegister()
    file1, file2 = args
    path = os.path.dirname(file1)
    root1, ext1 = os.path.splitext(os.path.basename(file1))
    # os.path.join keeps the output beside the MS file even when file1 has no
    # directory component (plain '%s/%s' would point at the filesystem root)
    outfile = os.path.join(path, '%s_pan_dwt%s' % (root1, ext1))
#  MS image
    inDataset1 = gdal.Open(file1, GA_ReadOnly)
    try:
        cols = inDataset1.RasterXSize
        rows = inDataset1.RasterYSize
        bands = inDataset1.RasterCount
    except Exception as e:
        # '%s', not '%e': the argument is an exception, not a float
        print('Error: %s --Image could not be read' % e)
        sys.exit(1)
    if pos1 is None:
        pos1 = range(1, bands+1)
    num_bands = len(pos1)
    if dims1 is None:
        dims1 = [0, 0, cols, rows]
    x10, y10, cols1, rows1 = dims1
#  PAN image
    inDataset2 = gdal.Open(file2, GA_ReadOnly)
    try:
        pan_bands = inDataset2.RasterCount
    except Exception as e:
        print('Error: %s --Image could not be read' % e)
        sys.exit(1)
    if pan_bands > 1:
        print('PAN image must be a single band')
        sys.exit(1)
    geotransform1 = inDataset1.GetGeoTransform()
    geotransform2 = inDataset2.GetGeoTransform()
    if (geotransform1 is None) or (geotransform2 is None):
        print('Image not georeferenced, aborting')
        sys.exit(1)
    print('=========================')
    print('   DWT Pansharpening')
    print('=========================')
    print(time.asctime())
    print('MS  file: '+file1)
    print('PAN file: '+file2)
#  image arrays, keeping the data type of the first MS band
    dt = inDataset1.GetRasterBand(1).ReadAsArray(0, 0, 1, 1).dtype
    MS = np.zeros((num_bands, rows1, cols1), dtype=dt)
    for k, b in enumerate(pos1):
        band = inDataset1.GetRasterBand(b)
        MS[k, :, :] = band.ReadAsArray(x10, y10, cols1, rows1)
#  if integer assume 11bit quantization otherwise must be byte
    if MS.dtype == np.int16:
        fact = 8.0
        MS = auxil.byteStretch(MS, (0, 2**11))
    else:
        fact = 1.0
#  read in corresponding spatial subset of PAN image
    gt1 = list(geotransform1)
    gt2 = list(geotransform2)
    # map coordinates of the upper left corner of the MS subset
    # (rotation terms gt[2], gt[4] are ignored here, as in the PAN lookup)
    ulx1 = gt1[0] + x10*gt1[1]
    uly1 = gt1[3] + y10*gt1[5]
#  upper left corner pixel in PAN
    x20 = int(round((ulx1 - gt2[0])/gt2[1]))
    y20 = int(round((uly1 - gt2[3])/gt2[5]))
    cols2 = cols1*ratio
    rows2 = rows1*ratio
    PAN = inDataset2.GetRasterBand(1).ReadAsArray(x20, y20, cols2, rows2)
    if PAN is None:
        # the read fails when the window is not fully inside the PAN raster
        print('PAN image does not cover the chosen MS image subset')
        sys.exit(1)
#  if integer assume 11-bit quantization, otherwise must be byte
    if PAN.dtype == np.int16:
        PAN = auxil.byteStretch(PAN, (0, 2**11))
#  compress PAN to resolution of MS image
    panDWT = auxil.DWTArray(PAN, cols2, rows2)
    r = ratio
    while r > 1:
        panDWT.filter()
        r //= 2     # floor division: same level count under Python 2 and 3
    bn0 = panDWT.get_quadrant(0)
    lines0, samples0 = bn0.shape
    bn1 = MS[k1, :, :]
#  register (and subset) MS image to compressed PAN image
    (scale, angle, shift) = auxil.similarity(bn0, bn1)
    tmp = np.zeros((num_bands, lines0, samples0))
    for k in range(num_bands):
        bn2 = ndii.zoom(MS[k, :, :], 1.0/scale)
        bn2 = ndii.rotate(bn2, angle)
        bn2 = ndii.shift(bn2, shift)
        tmp[k, :, :] = bn2[0:lines0, 0:samples0]
    MS = tmp
#  compress pan once more, extract wavelet quadrants, and restore
    panDWT.filter()
    fgpan = panDWT.get_quadrant(1)
    gfpan = panDWT.get_quadrant(2)
    ggpan = panDWT.get_quadrant(3)
    panDWT.invert()
#  output array
    sharpened = np.zeros((num_bands, rows2, cols2), dtype=np.float32)
    # regression coefficients, one pair per detail quadrant (1, 2, 3)
    aa = np.zeros(3)
    bb = np.zeros(3)
    print('Wavelet correlations:')
    for i in range(num_bands):
#      make copy of panDWT and inject ith ms band
        msDWT = copy.deepcopy(panDWT)
        msDWT.put_quadrant(MS[i, :, :], 0)
#      compress once more
        msDWT.filter()
#      determine wavelet normalization coefficients
        ms = msDWT.get_quadrant(1)
        aa[0], bb[0], R = auxil.orthoregress(fgpan.ravel(), ms.ravel())
        Rs = 'Band '+str(i+1)+': %8.3f' % R
        ms = msDWT.get_quadrant(2)
        aa[1], bb[1], R = auxil.orthoregress(gfpan.ravel(), ms.ravel())
        Rs += '%8.3f' % R
        ms = msDWT.get_quadrant(3)
        aa[2], bb[2], R = auxil.orthoregress(ggpan.ravel(), ms.ravel())
        Rs += '%8.3f' % R
        print(Rs)
#      restore once and normalize wavelet coefficients
        msDWT.invert()
        msDWT.normalize(aa, bb)
#      restore completely and collect result
        r = 1
        while r < ratio:
            msDWT.invert()
            r *= 2
        sharpened[i, :, :] = msDWT.get_quadrant(0)
    sharpened *= fact   # rescale dynamic range
#  write to disk
    driver = inDataset1.GetDriver()
    outDataset = driver.Create(outfile, cols2, rows2, num_bands, GDT_Float32)
    projection1 = inDataset1.GetProjection()
    if projection1 is not None:
        outDataset.SetProjection(projection1)
    # output grid: PAN pixel size/rotation, origin at the MS subset corner.
    # The origin offset must be in map units (x10*gt1[1]); offsetting by
    # x10*ratio is only right when the PAN pixel size is exactly 1 unit.
    gt_out = list(gt2)
    gt_out[0] = ulx1
    gt_out[3] = uly1
    outDataset.SetGeoTransform(tuple(gt_out))
    for k in range(num_bands):
        outBand = outDataset.GetRasterBand(k+1)
        outBand.WriteArray(sharpened[k, :, :], 0, 0)
        outBand.FlushCache()
    outDataset = None   # dropping the reference closes the GDAL dataset
    print('Result written to %s' % outfile)
    inDataset1 = None
    inDataset2 = None
示例#4
0
def register(file0, file1, dims=None, outfile=None):
    """Register (warp) the SAR image ``file1`` to the reference image ``file0``.

    The geotransforms are used to locate the reference subset in the target
    image; a similarity transform is then estimated from the log of "span"
    images (band 1 plus, depending on the band count, further bands of each
    image) and applied to every band of the target. The warped, clipped bands
    are written as a Float32 raster with the target's driver and the
    reference's projection.

    Parameters:
        file0:   file name of the reference image.
        file1:   file name of the target image (must have the same number of
                 bands as the reference).
        dims:    optional [x0, y0, cols, rows] subset of the reference image;
                 the whole image is used when None.
        outfile: output file name; defaults to ``<file1 root>_warp<ext>`` in
                 the directory of ``file1``.

    Returns:
        The output file name on success, 0 if the band counts differ, or
        None if an exception occurred.
    """
    import auxil.auxil1 as auxil
    import os
    import time
    import numpy as np
    from osgeo import gdal
    # zoom/rotate/shift live in scipy.ndimage itself; the
    # scipy.ndimage.interpolation namespace is deprecated
    import scipy.ndimage as ndii
    from osgeo.gdalconst import GA_ReadOnly, GDT_Float32

    # Per supported band count: the 1-based bands added to band 1 to form the
    # span image, and the progress message to print.
    span_extras = {9: (6, 9), 4: (4,), 3: (2, 3), 2: (2,), 1: ()}
    messages = {9: 'warping 9 bands (quad pol)...',
                4: 'warping 4 bands (dual pol)...',
                3: 'warping 3 bands (quad pol diagonal)...',
                2: 'warping 2 bands (dual pol diagonal)...',
                1: 'warping 1 band (single pol)...'}

    print('========================= ')
    print('       Register SAR')
    print('=========================')
    print(time.asctime())
    try:
        if outfile is None:
            dirn = os.path.dirname(os.path.abspath(file1))
            root, ext = os.path.splitext(os.path.basename(file1))
            outfile = os.path.join(dirn, root + '_warp' + ext)
        start = time.time()
        gdal.AllRegister()
        #  reference
        inDataset0 = gdal.Open(file0, GA_ReadOnly)
        cols = inDataset0.RasterXSize
        rows = inDataset0.RasterYSize
        bands = inDataset0.RasterCount
        print('Reference SAR image:\n %s' % file0)
        if dims is None:
            dims = [0, 0, cols, rows]
        x0, y0, cols, rows = dims
        #  target
        inDataset1 = gdal.Open(file1, GA_ReadOnly)
        cols1 = inDataset1.RasterXSize
        rows1 = inDataset1.RasterYSize
        bands1 = inDataset1.RasterCount
        print('Target SAR image:\n %s' % file1)
        if bands != bands1:
            print('Number of bands must be equal')
            return 0
    #  create the output file
        driver = inDataset1.GetDriver()
        outDataset = driver.Create(outfile, cols, rows, bands, GDT_Float32)
        projection0 = inDataset0.GetProjection()
        gt0 = list(inDataset0.GetGeoTransform())
        gt1 = list(inDataset1.GetGeoTransform())
        if projection0 is not None:
            outDataset.SetProjection(projection0)
    #  find the upper left corner (x0,y0) of reference subset in target (x1,y1)
        ulx0 = gt0[0] + x0 * gt0[1] + y0 * gt0[2]
        uly0 = gt0[3] + x0 * gt0[4] + y0 * gt0[5]
        # invert the target's affine geotransform: solve A * [x1, y1] = offset
        # (plain arrays + linalg.solve; np.mat no longer exists in NumPy 2)
        affine = np.array([[gt1[1], gt1[2]], [gt1[4], gt1[5]]])
        offset = np.array([ulx0 - gt1[0], uly0 - gt1[3]])
        px, py = np.linalg.solve(affine, offset)
        x1 = int(round(px))
        y1 = int(round(py))
        #  create output geotransform: reference grid, origin at the subset
        gt_out = list(gt0)
        gt_out[0] = ulx0
        gt_out[3] = uly0
        outDataset.SetGeoTransform(tuple(gt_out))
        #  get matching subsets from geotransform
        span0 = inDataset0.GetRasterBand(1).ReadAsArray(x0, y0, cols, rows)
        span1 = inDataset1.GetRasterBand(1).ReadAsArray(x1, y1, cols, rows)
        # NOTE(review): for any other band count nothing is warped and the
        # output file keeps its initial contents, yet the function still
        # reports success below -- confirm whether that is intended.
        if bands in span_extras:
            #  get warp parameters using span images
            print(messages[bands])
            for b in span_extras[bands]:
                span0 += inDataset0.GetRasterBand(b).ReadAsArray(
                    x0, y0, cols, rows)
            span0 = np.log(np.nan_to_num(span0) + 0.001)
            for b in span_extras[bands]:
                span1 += inDataset1.GetRasterBand(b).ReadAsArray(
                    x1, y1, cols, rows)
            span1 = np.log(np.nan_to_num(span1) + 0.001)
            scale, angle, shift = auxil.similarity(span0, span1)
            #  warp the target to the reference and clip
            for k in range(bands):
                rasterBand = inDataset1.GetRasterBand(k + 1)
                band = rasterBand.ReadAsArray(0, 0, cols1,
                                              rows1).astype(np.float32)
                bn2 = ndii.zoom(np.nan_to_num(band), 1.0 / scale)
                bn2 = ndii.rotate(bn2, angle)
                bn2 = ndii.shift(bn2, shift)
                outBand = outDataset.GetRasterBand(k + 1)
                outBand.WriteArray(bn2[y1:y1 + rows, x1:x1 + cols])
                outBand.FlushCache()
        # dropping the references closes the GDAL datasets
        inDataset0 = None
        inDataset1 = None
        outDataset = None
        print('Warped image written to: %s' % outfile)
        print('elapsed time: ' + str(time.time() - start))
        return outfile
    except Exception as e:
        print('registersar failed: %s' % e)
        return None