예제 #1
0
def main(ifile, geo_file, path_out):
    '''
    @description: Run the SeaDAS ``l2gen`` atmospheric correction on a MODIS
                  L1B granule.
    @ifile {str} L1B file. The base name is split on '.', and the fields are
                 read as: [1] a one-character prefix followed by a date that
                 is handed to ``date_teanslator.jd_to_cale``, [2] the
                 acquisition time as HHMM, [3] a tag copied into the output
                 file name.
    @geo_file {str} geolocation file passed to l2gen as ``geofile``
    @path_out {str} output directory
    @return: {str} path that was given to l2gen as ``ofile``. The exit status
             of l2gen is not checked, so the file is not guaranteed to exist.
    '''
    import datetime
    import shlex

    name_fields = os.path.split(ifile)[1].split('.')
    cale = date_teanslator.jd_to_cale(name_fields[1][1:]).split('.')
    time_str = name_fields[2]
    year, month, day = int(cale[0]), int(cale[1]), int(cale[2])
    hour = int(time_str[0:2])
    minute = int(time_str[2:])

    # The stamp is shifted by +8 h (presumably UTC -> UTC+8; confirm with the
    # consumers of these file names). Going through datetime rolls the
    # day/month/year over instead of emitting an impossible hour such as 26.
    shifted = (datetime.datetime(year, month, day, hour, minute) +
               datetime.timedelta(hours=8))
    date_str = shifted.strftime('%Y%m%d%H%M%S')

    nrow = name_fields[3]
    # NOTE(review): the description mentions AQUA while the output name is
    # hard-coded to TERRA; kept unchanged because callers may rely on it.
    out_name = 'TERRA_MODIS_1000_L2_%s_%s_00.hdf' % (date_str, nrow)
    raster_fn_out = os.path.join(path_out, out_name)

    # Quote the paths so that spaces or shell metacharacters in them cannot
    # split or alter the command line.
    cmd = 'l2gen ifile=%s geofile=%s cloud_thresh=0.05 ofile=%s' % \
        (shlex.quote(ifile), shlex.quote(geo_file), shlex.quote(raster_fn_out))
    os.system(cmd)
    return (raster_fn_out)
예제 #2
0
def main(ifile,
         shp_file,
         center_lonlat,
         salz,
         sala,
         satellite,
         aerotype=1,
         altitude=0.01,
         visibility=15,
         band_need=['all'],
         path_out_radi=None,
         path_out_6s=None):
    '''
    @description: Main routine: reproject a MODIS L1B swath with the HEG
                  tools, clip it to the study area, optionally write the
                  clipped digital numbers, and write the atmospherically
                  corrected result.
    @ifile {str}: L1B swath file ('.hdf')
    @shp_file {str}: vector file of the study area
    @center_lonlat {[lon, lat]}: centre longitude / latitude
    @salz: passed through to the correction as 'salz'
    @sala: passed through to the correction as 'sala'
    @satellite {str}: 'AQUA' or 'TERRA' (selects the output file prefix)
    @aerotype {int}: aerosol type (default: continental)
    @altitude {float}: altitude (km)
    @visibility {float}: visibility (km)
    @band_need {list}: selected bands ('B1', 'B2', ...); 'all' processes
                       every band. The list is only read, never modified.
    @path_out_radi {str}: output directory for the clipped, uncorrected
                          bands; skipped when None or not existing
    @path_out_6s {str}: output directory for the atmospheric-correction
                        result (required in practice; None fails in
                        os.path.join)
    @return: path of the corrected GeoTIFF on success; '' when the HEG
             header or the reprojected reference band is missing; 0 when the
             header is incomplete or the satellite is unknown
    '''
    import datetime

    # Environment variables required by the HEG command-line tools.
    os.environ['MRTDATADIR'] = global_config['MRTDATADIR']
    os.environ['PGSHOME'] = global_config['PGSHOME']
    os.environ['MRTBINDIR'] = global_config['MRTBINDIR']
    run_path = os.path.split(ifile)[0]

    # step 1: read the radiance scale/offset of both SDS groups; the
    # EV_250_Aggr500_RefSB entries come first, EV_500_RefSB entries follow.
    SD_file = SD(ifile)
    sds_info = SD_file.select('EV_250_Aggr500_RefSB').attributes()
    scales = list(sds_info['radiance_scales'])
    offsets = list(sds_info['radiance_offsets'])
    sds_info = SD_file.select('EV_500_RefSB').attributes()
    scales.extend(sds_info['radiance_scales'])
    offsets.extend(sds_info['radiance_offsets'])
    SD_file.end()
    nbands = len(scales)

    # Acquisition date/time are encoded in the '.'-separated file name.
    name_fields = os.path.split(ifile)[1].split('.')
    time_str = name_fields[2]
    cale = date_teanslator.jd_to_cale(name_fields[1][1:]).split('.')
    year, month, day = int(cale[0]), int(cale[1]), int(cale[2])
    hour = int(time_str[0:2])
    minute = int(time_str[2:])
    date = '%d/%d/%d %d:%d:00' % (year, month, day, hour, minute)
    sola_position = calc_sola_position.main(center_lonlat[0], center_lonlat[1],
                                            date)
    solz = sola_position[0]
    sola = sola_position[1]

    process_all = 'all' in band_need
    # Zero-based indices of the bands that were requested.
    wanted = [
        i_band for i_band in range(nbands)
        if process_all or ('B%d' % (i_band + 1)) in band_need
    ]

    def _reproj_name(band):
        # Path of the reprojected GeoTIFF for one band label ('B1', ...).
        return ifile.replace('.hdf', '_reproj_' + band + '.tif')

    # Ask hegtool for the swath header (written to HegHdr.hdr in run_path).
    heg_bin = os.path.join(global_config['MRTBINDIR'], 'hegtool')
    os.system('cd %s && %s -h %s > heg.log' % (run_path, heg_bin, ifile))
    info_file = os.path.join(run_path, 'HegHdr.hdr')
    if not os.path.exists(info_file):
        print('获取文件信息出错:%s' % ifile)
        return ('')

    header_keys = ('SWATH_LAT_MIN', 'SWATH_LAT_MAX', 'SWATH_LON_MIN',
                   'SWATH_LON_MAX', 'SWATH_X_PIXEL_RES_DEGREES',
                   'SWATH_Y_PIXEL_RES_DEGREES')
    header = dict.fromkeys(header_keys)
    with open(info_file, 'r') as fp:
        for line in fp:
            for key in header_keys:
                if key in line:
                    # Keep a leading minus sign: southern latitudes and
                    # western longitudes are negative.
                    header[key] = float(re.findall(r'-?[\d.]+', line)[0])
                    break
    # Compare against None explicitly; a coordinate of exactly 0.0 is valid.
    if None in header.values():
        print('[Error] 未获得全部的所需信息')
        return (0)
    lat_min = header['SWATH_LAT_MIN']
    lat_max = header['SWATH_LAT_MAX']
    lon_min = header['SWATH_LON_MIN']
    lon_max = header['SWATH_LON_MAX']
    pixel_x = header['SWATH_X_PIXEL_RES_DEGREES']
    pixel_y = header['SWATH_Y_PIXEL_RES_DEGREES']

    # Reprojection: one swtif run per requested band.
    prm_file = os.path.join(run_path, 'HegSwath.prm')
    swtif_bin = os.path.join(global_config['MRTBINDIR'], 'swtif')
    for i_band in wanted:
        modis_band = 'B%d' % (i_band + 1)
        print('reprojection %s ...' % modis_band)
        out_file = _reproj_name(modis_band)
        if '.hdf' in ifile:
            if os.path.exists(out_file):
                os.remove(out_file)
        else:
            # NOTE(review): processing continues after this message, and
            # out_file then equals ifile; consider aborting here instead.
            print('无法识别的文件格式:%s' % os.path.split(ifile)[1])
        # The first two bands live in EV_250_Aggr500_RefSB, the rest in
        # EV_500_RefSB; BAND_NUMBER is 1-based within its own field.
        if 0 <= i_band <= 1:
            field_name = 'EV_250_Aggr500_RefSB|'
            band_number = i_band + 1
        else:
            field_name = 'EV_500_RefSB|'
            band_number = i_band - 1
        prm_text = ''.join([
            '\nNUM_RUNS = 1\n\n',
            'BEGIN\n',
            'INPUT_FILENAME = %s\n' % ifile,
            'OBJECT_NAME = MODIS_SWATH_Type_L1B\n',
            'FIELD_NAME = %s\n' % field_name,
            'BAND_NUMBER = %d\n' % band_number,
            'OUTPUT_PIXEL_SIZE_X = %f\n' % pixel_x,
            'OUTPUT_PIXEL_SIZE_Y = %f\n' % pixel_y,
            'SPATIAL_SUBSET_UL_CORNER = ( %f %f )\n' % (lat_max, lon_min),
            'SPATIAL_SUBSET_LR_CORNER = ( %f %f )\n' % (lat_min, lon_max),
            'OUTPUT_PROJECTION_TYPE = GEO\n',
            'OUTPUT_FILENAME = %s\n' % out_file,
            'OUTPUT_TYPE = GEO\n',
            'END\n\n',
        ])
        with open(prm_file, 'wb') as fp:
            fp.write(prm_text.encode('utf-8'))
        os.system('cd %s && %s -P HegSwath.prm > heg.log' %
                  (run_path, swtif_bin))

    # Clip: the first requested band defines the pixel window of the study
    # area, which is then applied to every requested band.
    if process_all:
        raster_file = _reproj_name('B1')
    else:
        raster_file = _reproj_name(band_need[0])
    if not os.path.exists(raster_file):
        return ('')
    out_file = raster_file.replace('.tif', '_cut.tif')
    cut_range = img_cut.main(raster_file,
                             shp_file=shp_file,
                             out_file=out_file)
    xsize = cut_range[1] - cut_range[0]
    ysize = cut_range[3] - cut_range[2]
    print('裁切...')
    for i_band in wanted:
        raster_file = _reproj_name('B%d' % (i_band + 1))
        out_file = raster_file.replace('.tif', '_cut.tif')
        img_cut.main(raster_file, sub_lim=cut_range, out_file=out_file)

    # Radiometric step
    print('辐射定标...')
    # NOTE(review): this opens the *unclipped* raster of the last processed
    # band, so its geotransform is copied onto the clipped outputs below;
    # confirm whether the '_cut.tif' file was intended.
    raster = gdal.Open(raster_file)
    # The stamp is shifted by +8 h (presumably UTC -> UTC+8). datetime rolls
    # the date over instead of producing an hour value above 23.
    shifted = (datetime.datetime(year, month, day, hour, minute) +
               datetime.timedelta(hours=8))
    date_str = shifted.strftime('%Y%m%d%H%M%S')
    nrow = name_fields[3]
    if satellite == 'AQUA':
        out_name = 'AQUA_MODIS_500_L2_%s_%s_00.tif' % (date_str, nrow)
    elif satellite == 'TERRA':
        out_name = 'TERRA_MODIS_500_L2_%s_%s_00.tif' % (date_str, nrow)
    else:
        print('[Error] 无法识别卫星类型')
        return (0)

    driver = gdal.GetDriverByName('GTiff')
    # os.path.exists(None) raises TypeError, so test for None/'' first.
    if path_out_radi and os.path.exists(path_out_radi):
        raster_fn_out_radi = os.path.join(path_out_radi, out_name)
        target_ds = driver.Create(raster_fn_out_radi, xsize, ysize, nbands,
                                  gdal.GDT_Int16)
        target_ds.SetGeoTransform(raster.GetGeoTransform())
        target_ds.SetProjection(raster.GetProjectionRef())
        nan_mask = None
        for i_band in wanted:
            raster_file = _reproj_name('B%d' % (i_band + 1))
            out_file = raster_file.replace('.tif', '_cut.tif')
            raster_cut = gdal.Open(out_file)
            # The removed alias np.int was the builtin int; use it directly.
            raster_data = (
                raster_cut.GetRasterBand(1).ReadAsArray()).astype(int)
            # The fill mask comes from the first processed band and is
            # reused for all following bands.
            if nan_mask is None:
                nan_mask = raster_data > 65530
            raster_data[nan_mask] = -9999
            out_band = target_ds.GetRasterBand(i_band + 1)
            out_band.WriteArray(raster_data)
            out_band.SetNoDataValue(-9999)
            raster_cut = None
        target_ds = None  # dropping the reference lets GDAL flush/close
    else:
        print('[Warning] without radiation-correction output!')

    # Atmospheric correction
    print('大气校正 ...')
    nan_mask = None
    raster_fn_out_6s = os.path.join(path_out_6s, out_name)
    target_ds = driver.Create(raster_fn_out_6s, xsize, ysize, nbands,
                              gdal.GDT_Int16)
    target_ds.SetGeoTransform(raster.GetGeoTransform())
    target_ds.SetProjection(raster.GetProjectionRef())
    mtl_coef = {
        'altitude': altitude,
        'visibility': visibility,
        'aero_type': aerotype,
        'location': center_lonlat,
        'month': month,
        'day': day,
        'solz': solz,
        'sola': sola,
        'salz': salz,
        'sala': sala
    }
    for i_band in wanted:
        raster_file = _reproj_name('B%d' % (i_band + 1))
        out_file = raster_file.replace('.tif', '_cut.tif')
        raster_cut = gdal.Open(out_file)
        raster_data = raster_cut.GetRasterBand(1).ReadAsArray()
        if nan_mask is None:
            nan_mask = raster_data > 65530
        # Digital number -> radiance.
        raster_data = raster_data.astype(float)
        raster_data = scales[i_band] * (raster_data - offsets[i_band])
        wave_index = i_band + 42
        # A fresh dict per call, as before, in case arms_corr modifies it.
        corrected = arms_corr(raster_data, dict(mtl_coef), wave_index)
        data_tmp = (corrected * 10000).astype(int)
        data_tmp[nan_mask] = -9999
        out_band = target_ds.GetRasterBand(i_band + 1)
        out_band.WriteArray(data_tmp)
        out_band.SetNoDataValue(-9999)
        raster_cut = None
    target_ds = None
    raster = None

    # Remove intermediate files.
    # modis_band_list is not defined in this function; it is expected to be
    # provided at module level.
    for item in modis_band_list:
        file_name = ifile.replace('.hdf', '_reproj_' + item + '.tif')
        file_name_met = file_name.replace('.tif', '.tif.met')
        file_name_cut = ifile.replace('.hdf',
                                      '_reproj_' + item + '_cut.tif')
        if os.path.exists(file_name) and os.path.exists(file_name_met):
            os.remove(file_name)
            os.remove(file_name_met)
        if os.path.exists(file_name_cut):
            os.remove(file_name_cut)
    log_list = [
        'heg.log', 'swtif.log', 'HegSwath.prm', 'HegHdr.hdr', 'hegtool.log'
    ]
    for item in log_list:
        file_name = os.path.join(run_path, item)
        if os.path.exists(file_name):
            os.remove(file_name)
    return (raster_fn_out_6s)
예제 #3
0
def main0(file_ref,
          file_info,
          subrange,
          aerotype=1,
          altitude=0.01,
          visibility=15,
          path_out=None):
    '''
    @description: Main routine: calibrate and atmospherically correct the
                  four 'Radiance_I*' datasets, resample them onto a regular
                  lon/lat grid and write a 4-band GeoTIFF.
    @file_ref {str}: DN data file ('.hdf'); date, time and a tag are read
                     from the '.'-separated fields of its base name
    @file_info {str}: image information file (longitude/latitude, sensor
                      and solar angles)
    @subrange {lon_min, lon_max, lat_min, lat_max}: target window; None uses
              the full lon/lat extent of the data and the fast nearest-cell
              fill instead of griddata
    @aerotype {int}: passed to the correction as 'aero_type'
    @altitude {float}: passed to the correction as 'altitude'
    @visibility {float}: passed to the correction as 'visibility'
    @path_out {str}: output directory (None fails in os.path.join)
    @return: path of the written GeoTIFF, or 0 when file_ref does not
             contain '.hdf'
    '''
    import datetime

    # Geolocation
    file_sd = SD(file_info)
    lon = file_sd.select('Longitude').get()
    lat = file_sd.select('Latitude').get()

    # Build the target lon/lat grid
    if subrange is None:
        griddata_key = False  # skip griddata interpolation (efficiency)
        lon_min = np.min(lon)
        lon_max = np.max(lon)
        lat_min = np.min(lat)
        lat_max = np.max(lat)
        subrange = [lon_min, lon_max, lat_min, lat_max]
        xsize = lon.shape[1]
        ysize = lon.shape[0]
        xstep = (lon_max - lon_min) / xsize
        ystep = (lat_min - lat_max) / ysize
        cut_index = [0, ysize, 0, xsize]
        xsize0 = xsize
        ysize0 = ysize
        lonlat = None  # only griddata needs the scattered coordinates
    else:
        griddata_key = True
        xstep = 0.375 / 111
        ystep = -0.375 / 111
        # np.linspace needs an integer sample count; truncation matches what
        # older NumPy releases did with the float that used to be passed.
        xsize = int((subrange[1] - subrange[0]) * 111 / 0.375)
        ysize = int((subrange[3] - subrange[2]) * 111 / 0.375)
        cut_index = cut_data(subrange, lon, lat)
        xsize0 = cut_index[3] - cut_index[2]  # size after clipping
        ysize0 = cut_index[1] - cut_index[0]
        lon_cut = lon[cut_index[0]:cut_index[1], cut_index[2]:cut_index[3]]
        lat_cut = lat[cut_index[0]:cut_index[1], cut_index[2]:cut_index[3]]
        lon_1d = np.reshape(lon_cut, lon_cut.shape[0] * lon_cut.shape[1])
        lat_1d = np.reshape(lat_cut, lat_cut.shape[0] * lat_cut.shape[1])
        lonlat = np.vstack(([lon_1d], [lat_1d])).T
        lon_cut = lat_cut = lon_1d = lat_1d = None
    x1d = np.linspace(subrange[0], subrange[1], xsize)
    y1d = np.linspace(subrange[3], subrange[2], ysize)
    [xx, yy] = np.meshgrid(x1d, y1d)

    def _mean_angle(name):
        # Mean of one angle dataset over the clipped window.
        values = file_sd.select(name).get()
        return np.mean(values[cut_index[0]:cut_index[1],
                              cut_index[2]:cut_index[3]])

    # Satellite azimuth (shifted by 360 when the mean is negative)
    sala = _mean_angle('SatelliteAzimuthAngle')
    if sala < 0:
        sala = sala + 360
    # Satellite zenith
    salz = _mean_angle('SatelliteZenithAngle')
    # Solar zenith
    solz = _mean_angle('SolarZenithAngle')
    # Solar azimuth
    sola = _mean_angle('SolarAzimuthAngle')
    file_sd.end()

    # Month and day come from the file name and do not change per band.
    name_fields = os.path.split(file_ref)[1].split('.')
    cale = date_teanslator.jd_to_cale(name_fields[1][1:]).split('.')
    month = int(cale[1])
    day = int(cale[2])
    center_lonlat = [(subrange[0] + subrange[1]) / 2,
                     (subrange[2] + subrange[3]) / 2]

    if not griddata_key:
        # Target cell of every source pixel; identical for all bands.
        x_index = np.round((lon - lon_min) / xstep)
        y_index = ysize - 1 - np.round((lat - lat_min) / abs(ystep))
        usefall_key = np.logical_and(x_index < xsize, x_index >= 0)
        usefall_key = np.logical_and(usefall_key, y_index < ysize)
        usefall_key = np.logical_and(usefall_key, y_index >= 0)
        rows = y_index[usefall_key].astype(int)
        cols = x_index[usefall_key].astype(int)

    file_sd = SD(file_ref)
    atms_corr_resample = np.zeros((xx.shape[0], xx.shape[1], 4))
    for i in range(4):
        obj = file_sd.select('Radiance_I' + str(i + 1))
        data = obj.get()
        data_cut = data[cut_index[0]:cut_index[1], cut_index[2]:cut_index[3]]
        info = obj.attributes()
        radi_cali = data_cut * info['Scale'] + info['Offset']
        mtl_coef = {
            'altitude': altitude,
            'visibility': visibility,
            'aero_type': aerotype,
            'location': center_lonlat,
            'month': month,
            'day': day,
            'solz': solz,
            'sola': sola,
            'salz': salz,
            'sala': sala
        }
        atms_corr = arms_corr(radi_cali, mtl_coef, i + 161)
        # Resampling
        print('I%d 大气校正 ...' % (i + 1))
        print('I%d 重采样 ...' % (i + 1))
        if griddata_key:
            atms_corr_1d = np.reshape(atms_corr, xsize0 * ysize0)
            atms_corr_resample[:, :, i] = griddata(lonlat,
                                                   atms_corr_1d, (xx, yy),
                                                   method='nearest')
        else:
            # Nearest-cell fill: fast, but cells that no source pixel maps
            # to stay NaN. An explicit loop keeps "last write wins" when
            # several pixels hit the same cell.
            data_resize = np.full(xx.shape, np.nan)
            for row, col, value in zip(rows, cols, atms_corr[usefall_key]):
                data_resize[row, col] = value
            # Store the filled grid; without this the band stays all-zero.
            atms_corr_resample[:, :, i] = data_resize
    file_sd.end()

    driver = gdal.GetDriverByName('GTiff')
    xsize = np.shape(atms_corr_resample)[1]
    ysize = np.shape(atms_corr_resample)[0]
    nbands = 4
    if '.hdf' in file_ref:
        # Output file name
        time_str = name_fields[2]
        year = int(cale[0])
        hour = int(time_str[0:2])
        minute = int(time_str[2:])
        # The stamp is shifted by +8 h (presumably UTC -> UTC+8). datetime
        # rolls the date over instead of producing an hour above 23.
        shifted = (datetime.datetime(year, month, day, hour, minute) +
                   datetime.timedelta(hours=8))
        date_str = shifted.strftime('%Y%m%d%H%M%S')
        nrow = name_fields[3]
        out_name = 'NPP_VIIRS_375_L2_%s_%s_00.tif' % (date_str, nrow)
        raster_fn_out = os.path.join(path_out, out_name)
    else:
        print('无法识别的文件类型: %s' % os.path.split(file_ref)[1])
        return (0)
    target_ds = driver.Create(raster_fn_out, xsize, ysize, nbands,
                              gdal.GDT_UInt16)
    raster_srs = osr.SpatialReference()
    raster_srs.ImportFromEPSG(4326)
    geo_trans = (subrange[0], xstep, 0, subrange[3], 0, ystep)
    target_ds.SetGeoTransform(geo_trans)
    target_ds.SetProjection(raster_srs.ExportToWkt())
    for i in range(nbands):
        data_tmp = atms_corr_resample[:, :, i]
        mask = np.logical_or(data_tmp >= 65530, np.isnan(data_tmp))
        # The removed alias np.int was the builtin int; use it directly.
        data_tmp = (data_tmp * 10000).astype(int)
        data_tmp[mask] = 65530
        # Write the data and the NoData value to the same band.
        band = target_ds.GetRasterBand(i + 1)
        band.WriteArray(data_tmp)
        band.SetNoDataValue(65530)
    target_ds = None  # dropping the reference lets GDAL flush/close
    return (raster_fn_out)
예제 #4
0
def main(file_ref,
         file_info,
         subrange,
         aerotype=1,
         altitude=0.01,
         visibility=15,
         path_out=None):
    '''
    @description: Main routine: merge the four 'Radiance_I*' datasets with
                  their lon/lat into one HDF, warp every band to EPSG:4326
                  with gdalwarp, fill small gaps inside a fixed target
                  window, then calibrate, atmospherically correct and write
                  a 4-band GeoTIFF.
    @file_ref {str}: DN data file ('.hdf'); date, time and a tag are read
                     from the '.'-separated fields of its base name
    @file_info {str}: image information file (longitude/latitude, sensor
                      and solar angles)
    @subrange {lon_min, lon_max, lat_min, lat_max}: window used for the mean
              viewing/solar angles and the centre location
    @aerotype {int}: passed to the correction as 'aero_type'
    @altitude {float}: passed to the correction as 'altitude'
    @visibility {float}: passed to the correction as 'visibility'
    @path_out {str}: output directory (None fails in os.path.join)
    @return: path of the written GeoTIFF
    '''
    import datetime

    nbands = 4
    mask_value = 65533

    def _band_tif(index):
        # Path of the warped GeoTIFF for a zero-based band index.
        return file_ref.replace('.hdf', '_band%s_reproj.tif' % str(index + 1))

    # hdf -> geotiff
    print('reconstruction...')
    hdf_merge = file_ref.replace('.hdf', '_merge.hdf')
    if os.path.exists(hdf_merge):
        os.remove(hdf_merge)
    file_sd_out = SD(hdf_merge, SDC.CREATE | SDC.WRITE)
    file_sd1 = SD(file_ref)
    file_sd2 = SD(file_info)
    lon = file_sd2.select('Longitude').get()
    lat = file_sd2.select('Latitude').get()
    fields = ['Radiance_I%d' % (i + 1) for i in range(nbands)]
    size = None
    for field in fields:
        obji = file_sd1.select(field)
        # All datasets are created with the shape of the first band.
        if size is None:
            size = obji[:].shape
        objo = file_sd_out.create(field, SDC.UINT16, size)
        objo.set(obji[:])
    obji = file_sd2.select('Longitude')
    objo = file_sd_out.create('Longitude', SDC.FLOAT32, size)
    objo.set(obji[:])
    obji = file_sd2.select('Latitude')
    objo = file_sd_out.create('Latitude', SDC.FLOAT32, size)
    objo.set(obji[:])
    obji.endaccess()
    objo.endaccess()
    file_sd_out.end()
    for i in range(nbands):
        ofile_tif = _band_tif(i)
        cmd = '%s -geoloc -t_srs EPSG:4326 -srcnodata %s HDF4_SDS:UNKNOWN:"%s":%s %s' % (
            global_config['path_gdalwarp'], mask_value, hdf_merge, str(i),
            ofile_tif)
        os.system(cmd)

    # Viewing geometry: mean of each angle over the subrange window.
    cut_index = cut_data(subrange, lon, lat)

    def _mean_angle(name):
        # Mean of one angle dataset over the clipped window.
        values = file_sd2.select(name).get()
        return np.mean(values[cut_index[0]:cut_index[1],
                              cut_index[2]:cut_index[3]])

    solz = _mean_angle('SolarZenithAngle')  # solar zenith
    sola = _mean_angle('SolarAzimuthAngle')  # solar azimuth
    salz = _mean_angle('SatelliteZenithAngle')  # satellite zenith
    sala = _mean_angle('SatelliteAzimuthAngle')  # satellite azimuth
    center_lonlat = [(subrange[0] + subrange[1]) / 2,
                     (subrange[2] + subrange[3]) / 2]

    raster = gdal.Open(_band_tif(0))
    xsize = raster.RasterXSize
    ysize = raster.RasterYSize
    geo_trans = raster.GetGeoTransform()
    proj_ref = raster.GetProjectionRef()
    raster = None

    # Pixel window of the fixed target area, clamped to the raster extent.
    target_lon_min = 116.28
    target_lon_max = 125.0
    target_lat_min = 30.0
    target_lat_max = 37.83
    colm_s = int(round((target_lon_min - geo_trans[0]) / geo_trans[1]))
    colm_e = int(round((target_lon_max - geo_trans[0]) / geo_trans[1]))
    line_s = int(round((target_lat_max - geo_trans[3]) / geo_trans[5]))
    line_e = int(round((target_lat_min - geo_trans[3]) / geo_trans[5]))
    if colm_s < 0:
        colm_s = 0
    if line_s < 0:
        line_s = 0
    if colm_e >= xsize:
        colm_e = xsize - 1
    if line_e >= ysize:
        line_e = ysize - 1
    x_1d = np.array([geo_trans[0] + i * geo_trans[1] for i in range(xsize)])
    y_1d = np.array([geo_trans[3] + i * geo_trans[5] for i in range(ysize)])
    xx, yy = np.meshgrid(x_1d, y_1d)
    xx_sub = xx[line_s:line_e, colm_s:colm_e]
    yy_sub = yy[line_s:line_e, colm_s:colm_e]

    # Information needed for the output file
    name_fields = os.path.split(file_ref)[1].split('.')
    time_str = name_fields[2]
    cale = date_teanslator.jd_to_cale(name_fields[1][1:]).split('.')
    year, month, day = int(cale[0]), int(cale[1]), int(cale[2])
    hour = int(time_str[0:2])
    minute = int(time_str[2:])
    # The stamp is shifted by +8 h (presumably UTC -> UTC+8). datetime rolls
    # the date over instead of producing an hour value above 23.
    shifted = (datetime.datetime(year, month, day, hour, minute) +
               datetime.timedelta(hours=8))
    date_str = shifted.strftime('%Y%m%d%H%M%S')
    nrow = name_fields[3]
    out_name = 'NPP_VIIRS_375_L2_%s_%s_00.tif' % (date_str, nrow)
    raster_fn_out = os.path.join(path_out, out_name)
    driver = gdal.GetDriverByName('GTiff')
    target_ds = driver.Create(raster_fn_out, xsize, ysize, nbands,
                              gdal.GDT_UInt16)
    target_ds.SetGeoTransform(geo_trans)
    target_ds.SetProjection(proj_ref)

    # Atmospheric correction
    for i in range(nbands):
        obj = file_sd1.select('Radiance_I' + str(i + 1))
        raster = gdal.Open(_band_tif(i))
        data = raster.GetRasterBand(1).ReadAsArray()
        raster = None  # release the dataset before the file is deleted
        print('重采样:Band %s' % (i + 1))
        data_sub = data[line_s:line_e, colm_s:colm_e]
        blank_key = data_sub == mask_value
        # Morphology with OpenCV: never fill the window border, and leave
        # connected gaps larger than 1e5 pixels untouched.
        blank_key[:, 0] = 0
        blank_key[:, -1] = 0
        blank_key[0, :] = 0
        blank_key[-1, :] = 0
        labels_struct = cv2.connectedComponentsWithStats(
            blank_key.astype(np.uint8), connectivity=4)
        for i_label in range(1, labels_struct[0]):
            # Column 4 of the stats array is read as the component size.
            if labels_struct[2][i_label][4] > 1e5:
                blank_key[labels_struct[1] == i_label] = 0
        lon_blank = xx_sub[blank_key]
        lat_blank = yy_sub[blank_key]
        valid_key = np.logical_not(blank_key)
        lonlat = np.vstack((xx_sub[valid_key], yy_sub[valid_key])).T
        data_valid = data_sub[valid_key]
        # Fill the remaining gaps from their nearest valid neighbour.
        data_blank = griddata(lonlat,
                              data_valid, (lon_blank, lat_blank),
                              method='nearest')
        data_sub[blank_key] = data_blank
        data[line_s:line_e, colm_s:colm_e] = data_sub
        mask = data == mask_value
        print('I%d 辐射定标和大气校正 ...' % (i + 1))
        info = obj.attributes()
        radi_cali = data.astype(float) * info['Scale'] + info['Offset']
        mtl_coef = {
            'altitude': altitude,
            'visibility': visibility,
            'aero_type': aerotype,
            'location': center_lonlat,
            'month': month,
            'day': day,
            'solz': solz,
            'sola': sola,
            'salz': salz,
            'sala': sala
        }
        atms_corr = arms_corr(radi_cali, mtl_coef, i + 161)
        # save (the removed alias np.int was the builtin int)
        data_tmp = (atms_corr * 10000).astype(int)
        data_tmp[mask] = mask_value
        band = target_ds.GetRasterBand(i + 1)
        band.WriteArray(data_tmp)
        band.SetNoDataValue(mask_value)
    target_ds = None  # dropping the reference lets GDAL flush/close
    file_sd1.end()
    file_sd2.end()

    # Remove intermediate files; a missing file is skipped, matching the
    # best-effort behaviour of the previous shell 'rm'.
    for tmp_file in [hdf_merge] + [_band_tif(i) for i in range(nbands)]:
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
    return (raster_fn_out)
예제 #5
0
# Signed decimal number (optionally with an exponent) as found in HegHdr.hdr.
_HEG_NUMBER = re.compile(r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?')


def _run_heg_tool(args, cwd, log_name='heg.log'):
    """Run a HEG command-line tool in *cwd*, sending its stdout to a log file.

    An argument list is used instead of a shell string so that paths
    containing spaces or shell metacharacters are passed through intact.
    Returns the tool's exit code, or -1 if it could not be started.
    """
    import subprocess

    log_path = os.path.join(cwd, log_name)
    try:
        with open(log_path, 'w') as log_fp:
            # An empty cwd means "stay in the current directory".
            return subprocess.run(args, cwd=cwd or None,
                                  stdout=log_fp).returncode
    except OSError as err:
        # Callers detect failure through the missing output file, exactly
        # as they did when the command was run through the shell.
        print('[Error] %s: %s' % (args[0], err))
        return -1


def _read_heg_header(info_file):
    """Extract the swath extent and pixel size from a HegHdr.hdr file.

    Returns a dict with keys lat_min, lat_max, lon_min, lon_max, pixel_x and
    pixel_y; an entry is None when the header has no usable value for it.
    """
    wanted = (
        ('SWATH_LAT_MIN', 'lat_min'),
        ('SWATH_LAT_MAX', 'lat_max'),
        ('SWATH_LON_MIN', 'lon_min'),
        ('SWATH_LON_MAX', 'lon_max'),
        ('SWATH_X_PIXEL_RES_DEGREES', 'pixel_x'),
        ('SWATH_Y_PIXEL_RES_DEGREES', 'pixel_y'),
    )
    values = dict.fromkeys(name for _, name in wanted)
    with open(info_file, 'r') as fp:
        for line in fp:
            for key, name in wanted:
                if key in line:
                    found = _HEG_NUMBER.search(line)
                    if found:
                        # The sign is kept so southern / western bounds
                        # are not silently mirrored.
                        values[name] = float(found.group())
                    break
    return values


def _write_heg_swath_prm(prm_file, ifile, field_name, pixel_x, pixel_y,
                         lat_min, lat_max, lon_min, lon_max, out_file):
    """Write a single-run HEG swath parameter file for one field."""
    lines = [
        '\nNUM_RUNS = 1\n\n',
        'BEGIN\n',
        'INPUT_FILENAME = %s\n' % ifile,
        'OBJECT_NAME = VIIRS_EV_750M_SDR\n',
        'FIELD_NAME = %s|\n' % field_name,
        'BAND_NUMBER = 1\n',
        'OUTPUT_PIXEL_SIZE_X = %f\n' % pixel_x,
        'OUTPUT_PIXEL_SIZE_Y = %f\n' % pixel_y,
        'SPATIAL_SUBSET_UL_CORNER = ( %f %f )\n' % (lat_max, lon_min),
        'SPATIAL_SUBSET_LR_CORNER = ( %f %f )\n' % (lat_min, lon_max),
        'OUTPUT_PROJECTION_TYPE = GEO\n',
        'OUTPUT_FILENAME = %s\n' % out_file,
        'OUTPUT_TYPE = GEO\n',
        'END\n\n',
    ]
    # Binary mode keeps the '\n' line endings untouched on every platform.
    with open(prm_file, 'wb') as fp:
        fp.write(''.join(lines).encode('utf-8'))


def main(ifile,
         shp_file,
         center_lonlat,
         salz,
         sala,
         cut_range=None,
         aerotype=1,
         altitude=0.01,
         visibility=15,
         band_need=['all'],
         path_out=None):
    '''
    @description: Reproject the VIIRS M bands of an L1B swath file with the
        HEG tools, cut them to the study area, apply radiometric scaling and
        atmospheric correction (arms_corr) and write a 16-band UInt16 GeoTIFF.
    @ifile {str}: L1B swath file (the name must contain '.hdf')
    @shp_file {str}: vector file of the study area, used for cutting
    @center_lonlat {[lon, lat]}: centre longitude / latitude
    @salz: forwarded to arms_corr as 'salz' (presumably sensor zenith angle)
    @sala: forwarded to arms_corr as 'sala' (presumably sensor azimuth angle)
    @cut_range {[lon_min, lon_max, lat_min, lat_max]}: optional extent that
        overrides the swath extent read from the HEG header
    @aerotype {int}: aerosol type (default: continental)
    @altitude {float}: altitude (km)
    @visibility {float}: visibility (km)
    @band_need {list}: band names ('M1' .. 'M16') to process; 'all' selects
        every band. The list is only read, never modified.
    @path_out {str}: output directory; defaults to the directory of ifile
    @return: path of the output GeoTIFF, or 0 on failure
    '''
    from datetime import datetime, timedelta

    # Environment variables required by the HEG tools.
    os.environ['MRTDATADIR'] = global_config['MRTDATADIR']
    os.environ['PGSHOME'] = global_config['PGSHOME']
    os.environ['MRTBINDIR'] = global_config['MRTBINDIR']
    run_path = os.path.split(ifile)[0]
    # Every intermediate file name is derived by replacing '.hdf'. Without
    # that suffix the "output" name would equal the input file itself, so
    # stop before anything is handed to the HEG tools.
    if '.hdf' not in ifile:
        print('无法识别的文件格式:%s' % os.path.split(ifile)[1])
        return 0

    nbands = 16
    band_names = ['M%d' % (idx + 1) for idx in range(nbands)]
    process_all = 'all' in band_need
    selected = [(idx, name) for idx, name in enumerate(band_names)
                if process_all or name in band_need]

    # step 1: read the scale / offset of every band from the HDF file.
    # NOTE(review): scale/offset for M12-M14 come from
    # 'BrightnessTemperature_M*', while the reprojection below uses
    # 'Radiance_M*' for those bands, and M12 is atmospherically corrected but
    # written with the x100 factor. Kept as in the original - confirm intent.
    scales = []
    offsets = []
    SD_file = SD(ifile)
    try:
        for band_no in range(1, nbands + 1):
            if band_no == 9:
                sds_name = 'Reflectance_M%d' % band_no
            elif band_no >= 12:
                sds_name = 'BrightnessTemperature_M%d' % band_no
            else:
                sds_name = 'Radiance_M%d' % band_no
            sds_info = SD_file.select(sds_name).attributes()
            scales.append(sds_info.get('Scale', 1))
            offsets.append(sds_info.get('Offset', 0))
    finally:
        # Always release the HDF handle, even if an attribute read fails.
        SD_file.end()

    # Acquisition date / time are encoded in the file name.
    name_parts = os.path.split(ifile)[1].split('.')
    time_str = name_parts[2]
    cal_parts = date_teanslator.jd_to_cale(name_parts[1][1:]).split('.')
    year = int(cal_parts[0])
    month = int(cal_parts[1])
    day = int(cal_parts[2])
    hour = int(time_str[0:2])
    minute = int(time_str[2:])
    date = '%d/%d/%d %d:%d:00' % (year, month, day, hour, minute)
    sola_position = calc_sola_position.main(center_lonlat[0], center_lonlat[1],
                                            date)
    solz = sola_position[0]
    sola = sola_position[1]

    # Ask hegtool for the swath header (extent and pixel size).
    heg_bin = os.path.join(global_config['MRTBINDIR'], 'hegtool')
    _run_heg_tool([heg_bin, '-h', ifile], run_path)
    info_file = os.path.join(run_path, 'HegHdr.hdr')
    if not os.path.exists(info_file):
        print('获取文件信息出错:%s' % ifile)
        return 0
    header = _read_heg_header(info_file)
    lat_min = header['lat_min']
    lat_max = header['lat_max']
    lon_min = header['lon_min']
    lon_max = header['lon_max']
    pixel_x = header['pixel_x']
    pixel_y = header['pixel_y']
    if cut_range is not None:
        lat_min = cut_range[2]
        lat_max = cut_range[3]
        lon_min = cut_range[0]
        lon_max = cut_range[1]
    # Compare against None: 0.0 is a valid latitude / longitude bound.
    required = (lat_min, lat_max, lon_min, lon_max, pixel_x, pixel_y)
    if any(value is None for value in required):
        print('[Error] 未获得全部的所需信息')
        return 0

    # Reproject every requested band to a geographic GeoTIFF.
    prm_file = os.path.join(run_path, 'HegSwath.prm')
    swtif_bin = os.path.join(global_config['MRTBINDIR'], 'swtif')
    for i_band, viirs_band in selected:
        print('reprojection %s ...' % viirs_band)
        out_file = ifile.replace('.hdf', '_reproj_' + viirs_band + '.tif')
        if os.path.exists(out_file):
            os.remove(out_file)
        if viirs_band == 'M9':
            field_name = 'Reflectance_M%d' % (i_band + 1)
        elif viirs_band == 'M15' or viirs_band == 'M16':
            field_name = 'BrightnessTemperature_M%d' % (i_band + 1)
        else:
            field_name = 'Radiance_M%d' % (i_band + 1)
        _write_heg_swath_prm(prm_file, ifile, field_name, pixel_x, pixel_y,
                             lat_min, lat_max, lon_min, lon_max, out_file)
        _run_heg_tool([swtif_bin, '-P', 'HegSwath.prm'], run_path)

    # Cut and merge: the first requested band defines the pixel window that
    # is then applied to every other band.
    first_band = 'M1' if process_all else band_need[0]
    raster_file = ifile.replace('.hdf', '_reproj_' + first_band + '.tif')
    out_file = raster_file.replace('.tif', '_cut.tif')
    pixel_window = img_cut.main(raster_file,
                                shp_file=shp_file,
                                out_file=out_file)
    xsize = pixel_window[1] - pixel_window[0]
    ysize = pixel_window[3] - pixel_window[2]
    data_join = np.zeros([ysize, xsize, nbands])
    print('大气校正 ...')
    mtl_coef = {
        'altitude': altitude,
        'visibility': visibility,
        'aero_type': aerotype,
        'location': center_lonlat,
        'month': month,
        'day': day,
        'solz': solz,
        'sola': sola,
        'salz': salz,
        'sala': sala
    }
    for i_band, viirs_band in selected:
        raster_file = ifile.replace('.hdf', '_reproj_' + viirs_band + '.tif')
        out_file = raster_file.replace('.tif', '_cut.tif')
        img_cut.main(raster_file, sub_lim=pixel_window, out_file=out_file)
        # Radiometric calibration and atmospheric correction.
        raster = gdal.Open(out_file)
        raster_data = raster.GetRasterBand(1).ReadAsArray().astype(float)
        raster = None
        mask = raster_data >= 65530
        raster_data = scales[i_band] * raster_data + offsets[i_band]
        if i_band <= 11 and viirs_band != 'M9':
            # A fresh dict per call, as before, in case arms_corr edits it.
            corrected = arms_corr(raster_data, dict(mtl_coef), i_band + 149)
        else:
            # M9 and the bands from M13 upwards are stored uncorrected.
            corrected = raster_data
        corrected[mask] = 65535
        data_join[:, :, i_band] = corrected

    driver = gdal.GetDriverByName('GTiff')
    raster_file = ifile.replace('.hdf', '_reproj_' + first_band + '_cut.tif')
    raster = gdal.Open(raster_file)
    # Output file name: the acquisition time is shifted by +8 hours
    # (presumably UTC -> UTC+8 local time). datetime arithmetic rolls the
    # date over instead of producing hours above 23.
    local_time = datetime(year, month, day, hour, minute) + timedelta(hours=8)
    date_str = local_time.strftime('%Y%m%d%H%M%S')
    nrow = name_parts[3]
    out_name = 'NPP_VIIRS_750_L2_%s_%s_00.tif' % (date_str, nrow)
    if path_out is None:
        raster_fn_out = os.path.join(run_path, out_name)
    else:
        raster_fn_out = os.path.join(path_out, out_name)
    target_ds = driver.Create(raster_fn_out, xsize, ysize, nbands,
                              gdal.GDT_UInt16)
    target_ds.SetGeoTransform(raster.GetGeoTransform())
    target_ds.SetProjection(raster.GetProjectionRef())
    for i in range(nbands):
        # M1-M11 are stored x10000, the remaining bands x100.
        factor = 10000 if i < 11 else 100
        band = target_ds.GetRasterBand(i + 1)
        # Builtin int: the np.int alias no longer exists in current NumPy.
        band.WriteArray((data_join[:, :, i] * factor).astype(int))
        band.SetNoDataValue(65535)
    # Dropping the references closes the datasets and flushes the output.
    band = None
    target_ds = None
    raster = None

    # Remove intermediate files produced for the bands handled above.
    for item in band_names:
        file_name = ifile.replace('.hdf', '_reproj_' + item + '.tif')
        file_name_met = file_name.replace('.tif', '.tif.met')
        file_name_cut = ifile.replace('.hdf', '_reproj_' + item + '_cut.tif')
        if os.path.exists(file_name) and os.path.exists(file_name_met):
            os.remove(file_name)
            os.remove(file_name_met)
        if os.path.exists(file_name_cut):
            os.remove(file_name_cut)
    log_list = [
        'heg.log', 'swtif.log', 'HegSwath.prm', 'HegHdr.hdr', 'hegtool.log'
    ]
    for item in log_list:
        file_name = os.path.join(run_path, item)
        if os.path.exists(file_name):
            os.remove(file_name)
    return raster_fn_out