def main():
    '''
    Perform in-memory trait estimation.

    Loads an image (``.h5`` via ``ht.openHDF``, otherwise ``ht.openENVI``)
    fully into memory. It optionally applies:

    - a topographic correction, using the ``c`` coefficients from the
      ``--topo`` JSON file;
    - an NDVI-binned BRDF correction, using the ``fVol``/``fGeo``/``fIso``
      coefficients from ``<--brdf>_brdf_coeffs_<n>.json``. There is one file
      per NDVI bin; the bins are delimited by ``--mask_threshold``.

    It then resamples (or subsets) the corrected spectra to the wavelengths
    of every ``*.json`` trait model found in ``-coeffs``. For each model it
    writes a two-band GeoTIFF (band 1: trait mean, band 2: trait std).

    Optional extra outputs:

    - an RGBI+mask GeoTIFF (``--rgbim``);
    - the full corrected image in ENVI format (``--out`` is appended to the
      output base name).

    All products are written to ``-od``.
    '''

    parser = argparse.ArgumentParser(
        description="In memory trait mapping tool.")
    parser.add_argument("-img",
                        help="Input image pathname",
                        required=True,
                        type=str)
    parser.add_argument("--obs",
                        help="Input observables pathname",
                        required=False,
                        type=str)
    parser.add_argument("--out",
                        help="Output full corrected image",
                        required=False,
                        type=str)
    parser.add_argument("-od",
                        help="Output directory for all resulting products",
                        required=True,
                        type=str)
    parser.add_argument("--brdf",
                        help="Perform BRDF correction",
                        type=str,
                        default='')
    parser.add_argument("--topo",
                        help="Perform topographic correction",
                        type=str,
                        default='')
    parser.add_argument("--mask",
                        help="Image mask type to use",
                        action='store_true')
    parser.add_argument("--mask_threshold",
                        help="Mask threshold value",
                        nargs='*',
                        type=float)
    parser.add_argument("--rgbim",
                        help="Export RGBI +Mask image.",
                        action='store_true')
    parser.add_argument("-coeffs",
                        help="Trait coefficients directory",
                        required=False,
                        type=str)
    args = parser.parse_args()

    if not args.od.endswith("/"):
        args.od += "/"

    # NDVI thresholds separating the BRDF bins. When the option is omitted
    # (None), fall back to a single bin instead of crashing on len(None).
    mask_threshold = list(args.mask_threshold or [])

    # Without a coefficient directory there are no trait models to apply
    # (avoids globbing the literal path "None/*.json").
    traits = glob.glob("%s/*.json" % args.coeffs) if args.coeffs else []

    def wave_scaler(trait_model):
        '''Return the factor converting a trait model's wavelengths to the
        image's wavelength units: 1000 for 'micrometers', otherwise 1.'''
        if trait_model['wavelength_units'] == 'micrometers':
            return 10**3
        return 1

    # Load data objects into memory
    if args.img.endswith(".h5"):
        hyObj = ht.openHDF(args.img, load_obs=True)
    else:
        hyObj = ht.openENVI(args.img)
    if args.topo or args.brdf:
        hyObj.load_obs(args.obs)
    hyObj.create_bad_bands([[300, 400], [1330, 1430], [1800, 1960],
                            [2450, 2600]])

    # no data  / ignored values varies by product
    hyObj.no_data = -0.9999

    hyObj.load_data()

    # NDVI drives both the vegetation mask and the BRDF bin assignment, so
    # compute it whenever either is requested. Previously --brdf without
    # --mask failed with a NameError on `ndvi`.
    if args.mask or args.brdf:
        ir = hyObj.get_wave(850)
        red = hyObj.get_wave(665)
        ndvi = (ir - red) / (ir + red)
        if args.mask:
            hyObj.mask = (ndvi > 0.05) & (ir != hyObj.no_data)
        del ir, red
    if not args.mask:
        hyObj.mask = np.ones((hyObj.lines, hyObj.columns)).astype(bool)
        print("Warning no mask specified, results may be unreliable!")

    # Generate cosine i and c1/c2 images for topographic correction
    if args.topo:
        with open(args.topo) as json_file:
            topo_coeffs = json.load(json_file)
        topo_coeffs['c'] = np.array(topo_coeffs['c'])

        cos_i = calc_cosine_i(hyObj.solar_zn, hyObj.solar_az, hyObj.azimuth,
                              hyObj.slope)
        c1 = np.cos(hyObj.solar_zn)
        c2 = np.cos(hyObj.slope)

        # Only correct sufficiently illuminated, sloped pixels
        # (slope > 0.087, i.e. ~5 degrees assuming radians).
        topomask = hyObj.mask & (cos_i > 0.12) & (hyObj.slope > 0.087)

    # Load per-bin BRDF coefficients and generate scattering kernel images
    if args.brdf:
        total_bin = len(mask_threshold) + 1
        ndvi_thres = [0.05] + mask_threshold + [1.0]
        brdf_coeffs_List = []
        brdfmask = np.ones(
            (total_bin, hyObj.lines, hyObj.columns)).astype(bool)

        for ibin in range(total_bin):
            coeff_file = (args.brdf + '_brdf_coeffs_' + str(ibin + 1) +
                          '.json')
            with open(coeff_file) as json_file:
                brdf_coeffs = json.load(json_file)
            for key in ('fVol', 'fGeo', 'fIso'):
                brdf_coeffs[key] = np.array(brdf_coeffs[key])
            brdf_coeffs_List.append(brdf_coeffs)

            # Pixels whose NDVI lies in (lower, upper] belong to this bin.
            brdfmask[ibin, :, :] = hyObj.mask & (ndvi > ndvi_thres[ibin]) & (
                ndvi <= ndvi_thres[ibin + 1])

        # Kernel types are taken from the first bin's coefficient file.
        ross = brdf_coeffs_List[0]['ross']
        li = brdf_coeffs_List[0]['li']
        k_vol = generate_volume_kernel(hyObj.solar_az,
                                       hyObj.solar_zn,
                                       hyObj.sensor_az,
                                       hyObj.sensor_zn,
                                       ross=ross)
        k_geom = generate_geom_kernel(hyObj.solar_az,
                                      hyObj.solar_zn,
                                      hyObj.sensor_az,
                                      hyObj.sensor_zn,
                                      li=li)
        # Same geometry with a sensor zenith of 0 (nadir view).
        k_vol_nadir = generate_volume_kernel(hyObj.solar_az,
                                             hyObj.solar_zn,
                                             hyObj.sensor_az,
                                             0,
                                             ross=ross)
        k_geom_nadir = generate_geom_kernel(hyObj.solar_az,
                                            hyObj.solar_zn,
                                            hyObj.sensor_az,
                                            0,
                                            li=li)

    # Build a single resampler covering the wavelengths of all trait models
    match_flag = False
    check_wave_match_result = None
    resampling_coeffs = None
    trait_waves_fwhm = []
    if traits:
        print("Calculating values for %s traits....." % len(traits))

        trait_waves_all = []
        trait_fwhm_all = []
        for trait in traits:
            with open(trait) as json_file:
                trait_model = json.load(json_file)
            scaler = wave_scaler(trait_model)

            # Vector-normalisation wavelengths, when present, include the
            # model wavelengths, so they are what must be available.
            if len(trait_model['vector_norm_wavelengths']) == 0:
                waves = trait_model['model_wavelengths']
            else:
                waves = trait_model['vector_norm_wavelengths']
            trait_waves_all += list(np.array(waves) * scaler)
            trait_fwhm_all += list(np.array(trait_model['fwhm']) * scaler)

        # All unique (wavelength, fwhm) pairs, ordered by wavelength
        trait_waves_fwhm = sorted(set(zip(trait_waves_all, trait_fwhm_all)),
                                  key=lambda pair: pair[0])
        target_waves = [wave for wave, _ in trait_waves_fwhm]
        target_fwhm = [fwhm for _, fwhm in trait_waves_fwhm]

        # If the image already has the target wavelengths, subset instead of
        # resampling.
        check_wave_match_result = check_wave_match(hyObj, target_waves)
        match_flag = bool(check_wave_match_result['flag'])
        if not match_flag:
            resampling_coeffs = est_transform_matrix(
                hyObj.wavelengths[hyObj.bad_bands], target_waves,
                hyObj.fwhm[hyObj.bad_bands], target_fwhm, 1)

    hyObj.wavelengths = hyObj.wavelengths[hyObj.bad_bands]

    base_name = os.path.splitext(os.path.basename(args.img))[0]
    trait_dict = {}
    iterator = hyObj.iterate(by='chunk', chunk_size=(32, hyObj.columns))

    while not iterator.complete:
        chunk = iterator.read_next()
        # No-data pixels are detected on band index 50 of the raw chunk
        chunk_nodata_mask = chunk[:, :, 50] == hyObj.no_data

        # Chunk array indices
        line_start = iterator.current_line
        line_end = line_start + chunk.shape[0]
        col_start = iterator.current_column
        col_end = col_start + chunk.shape[1]
        window = (slice(line_start, line_end), slice(col_start, col_end))
        first_chunk = line_start + col_start == 0

        # Apply TOPO correction (and drop bad bands)
        if args.topo:
            cos_i_chunk = cos_i[window][:, :, np.newaxis]
            c1_chunk = c1[window][:, :, np.newaxis]
            c2_chunk = c2[window][:, :, np.newaxis]
            topomask_chunk = topomask[window][:, :, np.newaxis]
            correctionFactor = (c2_chunk + topo_coeffs['c'] / c1_chunk) / (
                cos_i_chunk + topo_coeffs['c'])
            # Pixels outside the topographic mask get a factor of 1.
            correctionFactor = correctionFactor * topomask_chunk + 1.0 * (
                1 - topomask_chunk)
            chunk = chunk[:, :, hyObj.bad_bands] * correctionFactor
        else:
            chunk = chunk[:, :, hyObj.bad_bands] * 1

        # Apply BRDF correction
        if args.brdf:
            k_vol_nadir_chunk = k_vol_nadir[window]
            k_geom_nadir_chunk = k_geom_nadir[window]
            k_vol_chunk = k_vol[window]
            k_geom_chunk = k_geom[window]

            n_wavelength = brdf_coeffs_List[0]['fVol'].shape[0]
            coeff_shape = (chunk.shape[0], chunk.shape[1], n_wavelength)
            new_k_vol = np.zeros(coeff_shape)
            new_k_geom = np.zeros(coeff_shape)
            new_k_iso = np.zeros(coeff_shape)

            # Each pixel takes the coefficients of the NDVI bin it falls in;
            # pixels in no bin keep all-zero coefficients.
            for ibin in range(total_bin):
                veg_mask = brdfmask[ibin][window][:, :, np.newaxis]
                new_k_vol += brdf_coeffs_List[ibin]['fVol'] * veg_mask
                new_k_geom += brdf_coeffs_List[ibin]['fGeo'] * veg_mask
                new_k_iso += brdf_coeffs_List[ibin]['fIso'] * veg_mask

            # eq 5. Weyermann et al. IEEE-TGARS 2015
            brdf = np.einsum(
                'ijk,ij-> ijk', new_k_vol, k_vol_chunk) + np.einsum(
                    'ijk,ij-> ijk', new_k_geom, k_geom_chunk) + new_k_iso
            brdf_nadir = np.einsum(
                'ijk,ij-> ijk', new_k_vol, k_vol_nadir_chunk) + np.einsum(
                    'ijk,ij-> ijk', new_k_geom, k_geom_nadir_chunk) + new_k_iso
            # Zero-BRDF pixels (e.g. in no NDVI bin) produce 0/0 here; the
            # warning is expected and those pixels are reset to 1 below.
            with np.errstate(divide='ignore', invalid='ignore'):
                correctionFactor = brdf_nadir / brdf
            correctionFactor[brdf == 0.0] = 1.0
            chunk = chunk * correctionFactor

        # Reassign no data values
        chunk[chunk_nodata_mask, :] = 0

        # Resample (or subset) the chunk to the shared trait wavelengths
        if traits:
            if match_flag:
                chunk_r = chunk[:, :, check_wave_match_result['index']]
            else:
                chunk_r = np.dot(chunk, resampling_coeffs)

        # Export RGBIM image
        if args.rgbim:
            rgbim_file = args.od + base_name + '_rgbim.tif'
            if first_chunk:
                driver = gdal.GetDriverByName("GTIFF")
                # 7 spectral bands + 1 mask band
                tiff = driver.Create(rgbim_file, hyObj.columns, hyObj.lines,
                                     8, gdal.GDT_Float32)
                tiff.SetGeoTransform(hyObj.transform)
                tiff.SetProjection(hyObj.projection)
                for band in range(1, 9):
                    tiff.GetRasterBand(band).SetNoDataValue(0)
                tiff.GetRasterBand(8).WriteArray(hyObj.mask)
                del tiff, driver

            rgbi_geotiff = gdal.Open(rgbim_file, gdal.GA_Update)
            for i, wave in enumerate([480, 560, 660, 850, 976, 1650, 2217],
                                     start=1):
                band = hyObj.wave_to_band(wave)
                rgbi_geotiff.GetRasterBand(i).WriteArray(
                    chunk[:, :, band], col_start, line_start)
            rgbi_geotiff = None  # flush and close

        # Export BRDF and topo corrected image
        if args.out:
            if first_chunk:
                output_name = args.od + base_name + args.out
                header_dict = hyObj.header_dict
                # Update header to describe only the retained bands
                header_dict['wavelength'] = header_dict['wavelength'][
                    hyObj.bad_bands]
                header_dict['fwhm'] = header_dict['fwhm'][hyObj.bad_bands]
                header_dict['bands'] = int(hyObj.bad_bands.sum())
                # Clean per-band ENVI header fields that no longer apply
                for key in ('band names', 'correction factors', 'bbl',
                            'smoothing factors'):
                    header_dict.pop(key, None)
                writer = writeENVI(output_name, header_dict)
            writer.write_chunk(chunk, line_start, col_start)
            if iterator.complete:
                writer.close()

        for i, trait in enumerate(traits):
            trait_file = (args.od + base_name + "_" +
                          os.path.splitext(os.path.basename(trait))[0] +
                          ".tif")

            # Trait estimation preparation (once, on the first chunk)
            if first_chunk:
                with open(trait) as json_file:
                    trait_model = json.load(json_file)
                # Use THIS model's units. The original reused the scaler of
                # the last model read above, which is wrong when models
                # differ in wavelength units.
                scaler = wave_scaler(trait_model)

                intercept = np.array(trait_model['intercept'])
                coefficients = np.array(trait_model['coefficients'])
                transform = trait_model['transform']

                if len(trait_model['vector_norm_wavelengths']) == 0:
                    dst_waves = np.array(
                        trait_model['model_wavelengths']) * scaler
                else:
                    dst_waves = np.array(
                        trait_model['vector_norm_wavelengths']) * scaler
                dst_fwhm = np.array(trait_model['fwhm']) * scaler
                model_waves = np.array(
                    trait_model['model_wavelengths']) * scaler
                wave_to_fwhm = dict(zip(dst_waves, dst_fwhm))
                model_fwhm = [wave_to_fwhm[x] for x in model_waves]

                dst_pairs = list(zip(dst_waves, dst_fwhm))
                model_pairs = list(zip(model_waves, model_fwhm))
                # Must be numpy bool arrays: a plain list of True/False
                # would index bands 1/0 instead of acting as a mask.
                vnorm_band_mask = np.array(
                    [x in dst_pairs for x in trait_waves_fwhm])
                model_band_mask = np.array(
                    [x in model_pairs for x in trait_waves_fwhm])

                if trait_model['vector_norm']:
                    vnorm_scaler = trait_model["vector_scaler"]
                else:
                    vnorm_scaler = None

                trait_dict[i] = [
                    coefficients, intercept, trait_model['vector_norm'],
                    vnorm_scaler, vnorm_band_mask, model_band_mask, transform
                ]

                # Create the two-band (mean, std) output GeoTIFF
                driver = gdal.GetDriverByName("GTIFF")
                tiff = driver.Create(trait_file, hyObj.columns, hyObj.lines,
                                     2, gdal.GDT_Float32)
                tiff.SetGeoTransform(hyObj.transform)
                tiff.SetProjection(hyObj.projection)
                tiff.GetRasterBand(1).SetNoDataValue(0)
                tiff.GetRasterBand(2).SetNoDataValue(0)
                del tiff, driver

            (coefficients, intercept, vnorm, vnorm_scaler, vnorm_band_mask,
             model_band_mask, transform) = trait_dict[i]

            chunk_t = np.copy(chunk_r)

            if vnorm:
                chunk_t[:, :, vnorm_band_mask] = vector_normalize_chunk(
                    chunk_t[:, :, vnorm_band_mask], vnorm_scaler)

            if transform == "log(1/R)":
                chunk_t[:, :, model_band_mask] = np.log(
                    1 / chunk_t[:, :, model_band_mask])

            trait_mean, trait_std = apply_plsr_chunk(
                chunk_t[:, :, model_band_mask], coefficients, intercept)

            # Change no data pixel values
            trait_mean[chunk_nodata_mask] = 0
            trait_std[chunk_nodata_mask] = 0

            # Write trait estimate to file
            trait_geotiff = gdal.Open(trait_file, gdal.GA_Update)
            trait_geotiff.GetRasterBand(1).WriteArray(trait_mean, col_start,
                                                      line_start)
            trait_geotiff.GetRasterBand(2).WriteArray(trait_std, col_start,
                                                      line_start)
            trait_geotiff = None  # flush and close
def main():
    '''
    Generate topographic and BRDF correction coefficients for one image.

    Writes ``<od><pref>_topo_coeffs.json`` (``--topo``) and one
    ``<od><pref>_brdf_coeffs_<n>.json`` per NDVI bin (``--brdf``). The bins
    are delimited by ``--mask_threshold``; with no thresholds there is a
    single bin.

    Only single-image estimation is implemented. Passing several images, or
    none, is rejected with a usage error instead of failing later with a
    NameError.
    '''
    parser = argparse.ArgumentParser(
        description="Generate topographic and BRDF correction coefficients.")
    parser.add_argument("--img",
                        help="Input image/directory pathname",
                        required=True,
                        nargs='*',
                        type=str)
    parser.add_argument("--obs",
                        help="Input observables pathname",
                        required=False,
                        nargs='*',
                        type=str)
    parser.add_argument("--od",
                        help="Output directory",
                        required=True,
                        type=str)
    parser.add_argument("--pref",
                        help="Coefficient filename prefix",
                        required=True,
                        type=str)
    parser.add_argument("--brdf",
                        help="Perform BRDF correction",
                        action='store_true')
    parser.add_argument("--kernels",
                        help="Li and Ross kernel types",
                        nargs=2,
                        type=str)
    parser.add_argument("--topo",
                        help="Perform topographic correction",
                        action='store_true')
    parser.add_argument("--mask",
                        help="Image mask type to use",
                        action='store_true')
    parser.add_argument("--mask_threshold",
                        help="Mask threshold value",
                        nargs='*',
                        type=float)
    parser.add_argument("--samp_perc",
                        help="Percent of unmasked pixels to sample",
                        type=float,
                        default=1.0)

    args = parser.parse_args()

    # Validate option combinations up front. Previously each of these led to
    # an unrelated TypeError/ValueError/NameError deep in processing.
    if len(args.img) != 1:
        # The multi-scene estimation path is not implemented.
        parser.error("exactly one image must be given with --img")
    image = args.img[0]
    if not image.endswith(".h5") and not args.obs:
        parser.error("--obs is required for non-HDF5 images")
    if args.brdf and not args.kernels:
        parser.error("--kernels LI ROSS is required with --brdf")

    if not args.od.endswith("/"):
        args.od += "/"

    # NDVI thresholds separating BRDF bins; None (option omitted) -> one bin.
    mask_threshold = list(args.mask_threshold or [])

    # Load data objects into memory
    if image.endswith(".h5"):
        hyObj = ht.openHDF(image, load_obs=True)
    else:
        hyObj = ht.openENVI(image)
        hyObj.load_obs(args.obs[0])
    hyObj.create_bad_bands([[300, 400], [1330, 1430], [1800, 1960],
                            [2450, 2600]])

    # no data  / ignored values varies by product
    hyObj.no_data = -0.9999

    hyObj.load_data()

    # NDVI is needed for the mask and for the BRDF bins. Compute it for
    # either, so that --brdf without --mask no longer raises a NameError.
    if args.mask or args.brdf:
        ir = hyObj.get_wave(850)
        red = hyObj.get_wave(665)
        ndvi = (ir - red) / (ir + red)
        if args.mask:
            hyObj.mask = (ndvi > 0.01) & (ir != hyObj.no_data)
        del ir, red
    if not args.mask:
        hyObj.mask = np.ones((hyObj.lines, hyObj.columns)).astype(bool)
        print("Warning no mask specified, results may be unreliable!")

    # Generate cosine i and c1/c2 images for topographic correction
    if args.topo:
        topo_coeffs = {}
        topo_coeffs['wavelengths'] = hyObj.wavelengths[
            hyObj.bad_bands].tolist()
        topo_coeffs['c'] = []
        cos_i = calc_cosine_i(hyObj.solar_zn, hyObj.solar_az,
                              hyObj.azimuth, hyObj.slope)
        c1 = np.cos(hyObj.solar_zn)
        c2 = np.cos(hyObj.slope)

        # Only fit on sufficiently illuminated, sloped pixels
        # (slope > 0.087, i.e. ~5 degrees assuming radians).
        topomask = hyObj.mask & (cos_i > 0.12) & (hyObj.slope > 0.087)

    # Generate scattering kernel images for brdf correction
    if args.brdf:
        ndvi_thres = [0.005] + mask_threshold + [1.0]
        total_bin = len(mask_threshold) + 1
        brdfmask = np.ones(
            (total_bin, hyObj.lines, hyObj.columns)).astype(bool)

        # Pixels with NDVI in (lower, upper] and sensor zenith > 2 degrees
        for ibin in range(total_bin):
            brdfmask[ibin, :, :] = hyObj.mask & (
                ndvi > ndvi_thres[ibin]) & (ndvi <= ndvi_thres[
                    ibin + 1]) & (hyObj.sensor_zn > np.radians(2))

        li, ross = args.kernels

        # One coefficient record per NDVI bin
        brdf_coeffs_List = []
        for ibin in range(total_bin):
            brdf_coeffs_List.append({
                'li': li,
                'ross': ross,
                'ndvi_lower_bound': ndvi_thres[ibin],
                'ndvi_upper_bound': ndvi_thres[ibin + 1],
                'wavelengths': hyObj.wavelengths[hyObj.bad_bands].tolist(),
                'fVol': [],
                'fGeo': [],
                'fIso': [],
            })

        k_vol = generate_volume_kernel(hyObj.solar_az,
                                       hyObj.solar_zn,
                                       hyObj.sensor_az,
                                       hyObj.sensor_zn,
                                       ross=ross)
        k_geom = generate_geom_kernel(hyObj.solar_az,
                                      hyObj.solar_zn,
                                      hyObj.sensor_az,
                                      hyObj.sensor_zn,
                                      li=li)
        k_finite = np.isfinite(k_vol) & np.isfinite(k_geom)

    # Cycle through the bands and calculate the topographic and BRDF
    # correction coefficients
    print("Calculating image correction coefficients.....")
    iterator = hyObj.iterate(by='band')

    if args.topo or args.brdf:
        while not iterator.complete:
            band = iterator.read_next()
            mask_finite = band > 0.0001
            progbar(iterator.current_band + 1, len(hyObj.wavelengths), 100)
            # Skip bad bands (bad_bands is True for bands to keep)
            if not hyObj.bad_bands[iterator.current_band]:
                continue

            # Generate topo correction coefficients
            if args.topo:
                topo_coeff = generate_topo_coeff_band(
                    band, topomask & mask_finite, cos_i)
                topo_coeffs['c'].append(topo_coeff)

            # Generate BRDF correction coefficients
            if args.brdf:
                if args.topo:
                    # Apply the topo correction to the current band first.
                    # SCS+C normalizes reflectance to nadir, but the input
                    # to the BRDF fit should be un-normalized.
                    correctionFactor = (c2 + topo_coeff / c1) / (
                        cos_i + topo_coeff)
                    # Only apply to orographic (masked) area
                    correctionFactor = correctionFactor * topomask + 1.0 * (
                        1 - topomask)
                    band = band * correctionFactor

                for ibin in range(total_bin):
                    fVol, fGeo, fIso = generate_brdf_coeff_band(
                        band,
                        brdfmask[ibin, :, :] & mask_finite & k_finite,
                        k_vol, k_geom)
                    brdf_coeffs_List[ibin]['fVol'].append(fVol)
                    brdf_coeffs_List[ibin]['fGeo'].append(fGeo)
                    brdf_coeffs_List[ibin]['fIso'].append(fIso)
        print()

    # Export coefficients to JSON
    if args.topo:
        topo_json = "%s%s_topo_coeffs.json" % (args.od, args.pref)
        with open(topo_json, 'w') as outfile:
            json.dump(topo_coeffs, outfile)
    if args.brdf:
        for ibin in range(total_bin):
            brdf_json = "%s%s_brdf_coeffs_%s.json" % (args.od, args.pref,
                                                      str(ibin + 1))
            with open(brdf_json, 'w') as outfile:
                json.dump(brdf_coeffs_List[ibin], outfile)
Beispiel #3
0
def main():
    '''
    Convert a NEON AOP HDF5 reflectance file to ENVI format.

    Command-line arguments:
        --img  Input HDF5 image pathname (required).
        --out  Output directory (required); a trailing "/" is appended
               when missing.
        --obs  Also export an observables image ("<name>_obs_ort") built
               from the HDF5 file's <site>/Reflectance/Metadata group.

    The reflectance cube is streamed chunk by chunk into
    "<out>/<input basename without extension>".
    '''

    parser = argparse.ArgumentParser(
        description="Convert NEON AOP H5 to ENVI format")
    parser.add_argument("--img",
                        help="Input image pathname",
                        required=True,
                        type=str)
    parser.add_argument("--out",
                        help="Output directory",
                        required=True,
                        type=str)
    parser.add_argument("--obs",
                        help="Ouput observables",
                        required=False,
                        action='store_true')

    args = parser.parse_args()

    if not args.out.endswith("/"):
        args.out += "/"

    # Load data objects memory
    hyObj = ht.openHDF(args.img)
    hyObj.load_data()

    iterator = hyObj.iterate(by='chunk')
    pixels_processed = 0
    total_pixels = hyObj.columns * hyObj.lines  # loop invariant

    output_name = args.out + os.path.basename(os.path.splitext(args.img)[0])
    writer = writeENVI(output_name, ENVI_header_from_hdf(hyObj))

    while not iterator.complete:
        chunk = iterator.read_next()
        pixels_processed += chunk.shape[0] * chunk.shape[1]
        progbar(pixels_processed, total_pixels, 100)

        writer.write_chunk(chunk, iterator.current_line,
                           iterator.current_column)

    # Close after the loop (rather than on the last iteration) so the writer
    # is released even when the iterator reports completion up front.
    writer.close()

    if args.obs:
        print("Exporting observables....")
        obs_header = ENVI_header_from_hdf(hyObj)
        obs_band_names = [
            'path length', 'to-sensor azimuth', 'to-sensor zenith',
            'to-sun azimuth', 'to-sun zenith', 'phase', 'slope', 'aspect',
            'cosine i', 'UTC time'
        ]
        # Derive the band count from the names so the ENVI header stays
        # self-consistent (a hard-coded 11 declared one unnamed band).
        obs_header['bands'] = len(obs_band_names)
        obs_header['band_names'] = obs_band_names
        obs_header['wavelength units'] = np.nan
        obs_header['data type'] = 4
        writer = writeENVI(output_name + "_obs_ort", obs_header)

        # The context manager closes the HDF5 handle even if a write fails.
        with h5py.File(args.img, 'r') as hdfObj:
            base_key = list(hdfObj.keys())[0]
            metadata = hdfObj[base_key]["Reflectance"]["Metadata"]
            ancillary = metadata['Ancillary_Imagery']
            image_shape = (hyObj.lines, hyObj.columns)

            writer.write_band(ancillary['Path_Length'][:, :], 0)
            writer.write_band(metadata['to-sensor_Azimuth_Angle'][:, :], 1)
            writer.write_band(metadata['to-sensor_Zenith_Angle'][:, :], 2)
            # Solar angles are stored once per flight line; broadcast them to
            # a full image. `Dataset.value` was removed in h5py 3.0, and
            # `[()]` reads the whole dataset on every h5py version.
            writer.write_band(
                np.ones(image_shape) *
                metadata['Logs']['Solar_Azimuth_Angle'][()], 3)
            writer.write_band(
                np.ones(image_shape) *
                metadata['Logs']['Solar_Zenith_Angle'][()], 4)
            # NOTE(review): bands 5 (phase), 8 (cosine i) and 9 (UTC time)
            # are never written here; their contents depend on how
            # writeENVI initialises the file -- confirm.
            writer.write_band(ancillary['Slope'][:, :], 6)
            writer.write_band(ancillary['Aspect'][:, :], 7)

            writer.close()
def main():
    '''
    Generate topographic and BRDF correction coefficients. Corrections can be calculated on individual images
    or groups of images.
    '''
    parser = argparse.ArgumentParser(description="In memory trait mapping tool.")
    parser.add_argument("--img", help="Input image/directory pathname", required=True, nargs='*', type=str)
    parser.add_argument("--obs", help="Input observables pathname", required=False, nargs='*', type=str)
    parser.add_argument("--od", help="Ouput directory", required=True, type=str)
    parser.add_argument("--pref", help="Coefficient filename prefix", required=True, type=str)
    parser.add_argument("--brdf", help="Perform BRDF correction", action='store_true')
    parser.add_argument("--kernels", help="Li and Ross kernel types", nargs=2, type=str)
    parser.add_argument("--topo", help="Perform topographic correction", action='store_true')
    parser.add_argument("--mask", help="Image mask type to use", action='store_true')
    parser.add_argument("--mask_threshold", help="Mask threshold value", nargs='*', type=float)
    parser.add_argument("--samp_perc", help="Percent of unmasked pixels to sample", type=float, default=1.0)
    parser.add_argument("--agmask", help="ag / urban mask file", required=False, type=str)
    parser.add_argument("--topo_sep", help="In multiple image mode, perform topographic correction in a image-based fasion", action='store_true')
    parser.add_argument("--mass_center", help="Use mass center to be the center coordinate, default is geometric center. It is only used in BRDF correction", action='store_true')
    parser.add_argument("--check_flight", help="Check abnormal flight lines if group mode BRDF correction is performed", action='store_true')

    parser.add_argument("--dynamicbin", help="Total Number of dynamic NDVI bins, with higher priority than mask_threshold", type=int, required=False)
    parser.add_argument("--buffer_neon", help="neon buffer", action='store_true')

    args = parser.parse_args()

    if not args.od.endswith("/"):
        args.od += "/"

    if len(args.img) == 1:
        image = args.img[0]

        # Load data objects memory
        if image.endswith(".h5"):
            hyObj = ht.openHDF(image, load_obs=True)
            smoothing_factor = np.ones(hyObj.bands)
        else:
            hyObj = ht.openENVI(image)
            hyObj.load_obs(args.obs[0])

            smoothing_factor = hyObj.header_dict[NAME_FIELD_SMOOTH]
            # CORR product has smoothing factor, and all bands are converted back to uncorrected / unsmoothed version by dividing the corr/smooth factors
            if isinstance(smoothing_factor, (list, tuple, np.ndarray)):
                smoothing_factor = np.array(smoothing_factor)
            # REFL version
            else:
                smoothing_factor = np.ones(hyObj.bands)

        hyObj.create_bad_bands(BAD_RANGE)

        # no data  / ignored values varies by product
        # hyObj.no_data = NO_DATA_VALUE

        hyObj.load_data()

        # Generate mask
        if args.mask:
            ir = hyObj.get_wave(BAND_IR_NM)
            red = hyObj.get_wave(BAND_RED_NM)
            ndvi = (1.0 * ir - red) / (1.0 * ir + red)

            ag_mask = np.array([0])
            if args.agmask:
                ag_mask = np.fromfile(args.agmask, dtype=np.uint8).reshape((hyObj.lines, hyObj.columns))

            if args.buffer_neon:
                buffer_edge = sig.convolve2d(ir <= 0.5 * hyObj.no_data, make_disc_for_buffer(30), mode='same', fillvalue=1)
                ag_mask = ag_mask or (buffer_edge > 0)

            hyObj.mask = (ndvi > NDVI_MIN_THRESHOLD) & (ndvi < NDVI_MAX_THRESHOLD) & (ir != hyObj.no_data) & (ag_mask == 0)

            del ir, red  # ,ndvi
        else:
            hyObj.mask = np.ones((hyObj.lines, hyObj.columns)).astype(bool)
            print("Warning no mask specified, results may be unreliable!")

        # Generate cosine i and c1 image for topographic correction

        if args.topo:

            topo_coeffs = {}
            topo_coeffs['wavelengths'] = hyObj.wavelengths[hyObj.bad_bands].tolist()
            topo_coeffs['c'] = []
            topo_coeffs['slope'] = []
            topo_coeffs['intercept'] = []
            cos_i = calc_cosine_i(hyObj.solar_zn, hyObj.solar_az, hyObj.aspect, hyObj.slope)
            c1 = np.cos(hyObj.solar_zn)
            c2 = np.cos(hyObj.slope)

            terrain_msk = (cos_i > COSINE_I_MIN_THRESHOLD) & (hyObj.slope > SLOPE_MIN_THRESHOLD)
            topomask = hyObj.mask  # & (cos_i > 0.12)  & (hyObj.slope > 0.087)

        # Generate scattering kernel images for brdf correction
        if args.brdf:

            if args.dynamicbin:
                total_bin = args.dynamicbin
                perc_range = DYN_NDVI_BIN_HIGH_PERC - DYN_NDVI_BIN_LOW_PERC + 1
                ndvi_break_dyn_bin = np.percentile(ndvi[ndvi > 0], np.arange(DYN_NDVI_BIN_LOW_PERC, DYN_NDVI_BIN_HIGH_PERC + 1, perc_range / (total_bin - 1)))
                ndvi_thres = sorted([NDVI_BIN_MIN_THRESHOLD] + ndvi_break_dyn_bin.tolist() + [NDVI_BIN_MAX_THRESHOLD])

            else:
                if args.mask_threshold:
                    ndvi_thres = [NDVI_BIN_MIN_THRESHOLD] + args.mask_threshold + [NDVI_BIN_MAX_THRESHOLD]
                else:
                    ndvi_thres = [NDVI_BIN_MIN_THRESHOLD, NDVI_BIN_MAX_THRESHOLD]

            ndvi_thres = sorted(list(set(ndvi_thres)))
            total_bin = len(ndvi_thres) - 1
            brdfmask = np.ones((total_bin, hyObj.lines, hyObj.columns)).astype(bool)

            for ibin in range(total_bin):
                brdfmask[ibin, :, :] = hyObj.mask & (ndvi > ndvi_thres[ibin]) & (ndvi <= ndvi_thres[ibin + 1]) & (hyObj.sensor_zn > np.radians(SENSOR_ZENITH_MIN_DEG))

            li, ross = args.kernels
            # Initialize BRDF dictionary

            brdf_coeffs_List = []  # initialize
            brdf_mask_stat = np.zeros(total_bin)

            for ibin in range(total_bin):
                brdf_mask_stat[ibin] = np.count_nonzero(brdfmask[ibin, :, :])

                brdf_coeffs = {}
                brdf_coeffs['li'] = li
                brdf_coeffs['ross'] = ross
                brdf_coeffs['ndvi_lower_bound'] = ndvi_thres[ibin]
                brdf_coeffs['ndvi_upper_bound'] = ndvi_thres[ibin + 1]
                brdf_coeffs['wavelengths'] = hyObj.wavelengths[hyObj.bad_bands].tolist()
                brdf_coeffs['fVol'] = []
                brdf_coeffs['fGeo'] = []
                brdf_coeffs['fIso'] = []
                brdf_coeffs_List.append(brdf_coeffs)

            k_vol = generate_volume_kernel(hyObj.solar_az, hyObj.solar_zn, hyObj.sensor_az, hyObj.sensor_zn, ross=ross)
            k_geom = generate_geom_kernel(hyObj.solar_az, hyObj.solar_zn, hyObj.sensor_az, hyObj.sensor_zn, li=li)
            k_finite = np.isfinite(k_vol) & np.isfinite(k_geom)

            if args.mass_center:
                coord_center_method = 2
            else:
                coord_center_method = 1

            img_center_info = get_center_info(hyObj, coord_center_method)
            print(img_center_info)
            # csv_img_center_info = np.array([os.path.basename(args.img[0]).split('_')[0]]+list(img_center_info))[:,np.newaxis]
            csv_img_center_info = np.array([os.path.basename(args.img[0])] + [args.pref] * 2 + list(img_center_info))[:, np.newaxis]
            np.savetxt("%s%s_lat_sza.csv" % (args.od, args.pref), csv_img_center_info.T, header="image_name,uniq_name,uniq_name_short,LON,LAT,solar_zn_rad,ssolar_zn_deg", delimiter=',', fmt='%s', comments='')

        # Cycle through the bands and calculate the topographic and BRDF correction coefficients
        print("Calculating image correction coefficients.....")
        iterator = hyObj.iterate(by='band')

        if args.topo or args.brdf:
            while not iterator.complete:

                if hyObj.bad_bands[iterator.current_band + 1]:  # load data to RAM, if it is a goog band
                    band = iterator.read_next()
                    band = band / smoothing_factor[iterator.current_band]
                    band_msk = (band > REFL_MIN_THRESHOLD) & (band < REFL_MAX_THRESHOLD)

                else:  # similar to .read_next(), but do not load data to RAM, if it is a bad band
                    iterator.current_band += 1
                    if iterator.current_band == hyObj.bands - 1:
                        iterator.complete = True

                progbar(iterator.current_band + 1, len(hyObj.wavelengths), 100)
                # Skip bad bands
                if hyObj.bad_bands[iterator.current_band]:
                    # Generate topo correction coefficients
                    if args.topo:

                        topomask_b = topomask & band_msk

                        if np.count_nonzero(topomask_b & terrain_msk) > MIN_SAMPLE_COUNT_TOPO:
                            topo_coeff, reg_slope, reg_intercept = generate_topo_coeff_band(band, topomask_b & terrain_msk, cos_i, non_negative=True)
                            # topo_coeff, reg_slope, reg_intercept= generate_topo_coeff_band(band,topomask_b & terrain_msk,cos_i)
                        else:
                            topo_coeff = FLAT_COEFF_C
                            reg_slope, reg_intercept = (-9999, -9999)
                            print("fill with FLAT_COEFF_C")

                        topo_coeffs['c'].append(topo_coeff)
                        topo_coeffs['slope'].append(reg_slope)
                        topo_coeffs['intercept'].append(reg_intercept)

                    # Gernerate BRDF correction coefficients
                    if args.brdf:
                        if args.topo:
                            # Apply topo correction to current bands
                            correctionFactor = (c2 * c1 + topo_coeff) / (cos_i + topo_coeff)
                            correctionFactor = correctionFactor * topomask_b + 1.0 * (1 - topomask_b)  # only apply to orographic area
                            band = band * correctionFactor

                        for ibin in range(total_bin):

                            if brdf_mask_stat[ibin] < MIN_SAMPLE_COUNT:
                                continue

                            band_msk_new = (band > REFL_MIN_THRESHOLD) & (band < REFL_MAX_THRESHOLD)

                            if np.count_nonzero(brdfmask[ibin, :, :] & band_msk & k_finite & band_msk_new) < MIN_SAMPLE_COUNT:
                                brdf_mask_stat[ibin] = DIAGNO_NAN_OUTPUT
                                continue

                            fVol, fGeo, fIso = generate_brdf_coeff_band(band, brdfmask[ibin, :, :] & band_msk & k_finite & band_msk_new, k_vol, k_geom)
                            brdf_coeffs_List[ibin]['fVol'].append(fVol)
                            brdf_coeffs_List[ibin]['fGeo'].append(fGeo)
                            brdf_coeffs_List[ibin]['fIso'].append(fIso)

    # Compute topographic and BRDF coefficients using data from multiple scenes
    elif len(args.img) > 1:

        image_uniq_name_list, image_uniq_name_list_short = get_uniq_img_name(args.img)

        if args.check_flight:
            args.brdf = True
            args.topo = True
            args.topo_sep = True
            print("Automatically enable TOPO and BRDF mode if 'check_flight' is enabled. ")

        if args.brdf:

            li, ross = args.kernels

            ndvi_thres_complete = False
            if args.dynamicbin:
                total_bin = args.dynamicbin
                perc_range = DYN_NDVI_BIN_HIGH_PERC - DYN_NDVI_BIN_LOW_PERC + 1
                # ndvi_break_dyn_bin= np.percentile(ndvi[ndvi>0], np.arange(DYN_NDVI_BIN_LOW_PERC,DYN_NDVI_BIN_HIGH_PERC+1,perc_range//(total_bin-2)))
                ndvi_thres = [NDVI_BIN_MIN_THRESHOLD] + [None] * (total_bin - 1) + [NDVI_BIN_MAX_THRESHOLD]
                print("NDVI bins:", ndvi_thres)
            else:
                ndvi_thres_complete = True
                if args.mask_threshold:
                    ndvi_thres = sorted([NDVI_BIN_MIN_THRESHOLD] + args.mask_threshold + [NDVI_BIN_MAX_THRESHOLD])
                    total_bin = len(args.mask_threshold) + 1
                else:
                    ndvi_thres = [NDVI_BIN_MIN_THRESHOLD, NDVI_BIN_MAX_THRESHOLD]
                    total_bin = 1

            brdf_coeffs_List = []  # initialize
            brdf_mask_stat = np.zeros(total_bin)

            for ibin in range(total_bin):
                brdf_coeffs = {}
                brdf_coeffs['li'] = li
                brdf_coeffs['ross'] = ross
                brdf_coeffs['ndvi_lower_bound'] = None  # ndvi_thres[ibin]
                brdf_coeffs['ndvi_upper_bound'] = None  # ndvi_thres[ibin+1]
                # brdf_coeffs['wavelengths'] = hyObj.wavelengths[hyObj.bad_bands].tolist()
                brdf_coeffs['fVol'] = []
                brdf_coeffs['fGeo'] = []
                brdf_coeffs['fIso'] = []
                brdf_coeffs['flight_box_avg_sza'] = []
                brdf_coeffs_List.append(brdf_coeffs)

            if args.mass_center:
                coord_center_method = 2
            else:
                coord_center_method = 1
            csv_group_center_info = np.empty((7, 0))

        hyObj_dict = {}
        sample_dict = {}
        idxRand_dict = {}
        sample_k_vol = []
        sample_k_geom = []
        sample_cos_i = []
        sample_c1 = []
        sample_slope = []
        sample_ndvi = []
        sample_index = [0]
        sample_img_tag = []  # record which image that sample is drawn from
        sub_total_sample_size = 0
        ndvi_mask_dict = {}
        image_smooth = []
        band_subset_outlier = []
        hyObj_pointer_dict_list = []

        for i, image in enumerate(args.img):
            # Load data objects memory
            if image.endswith(".h5"):
                hyObj = ht.openHDF(image, load_obs=True)
                smoothing_factor = np.ones(hyObj.bands)
                image_smooth += [smoothing_factor]
            else:
                hyObj = ht.openENVI(image)
                hyObj.load_obs(args.obs[i])

                smoothing_factor = hyObj.header_dict[NAME_FIELD_SMOOTH]
                # CORR product has smoothing factor, and all bands are converted back to uncorrected / unsmoothed version by dividing the corr/smooth factors
                if isinstance(smoothing_factor, (list, tuple, np.ndarray)):
                    smoothing_factor = np.array(smoothing_factor)
                # REFL version
                else:
                    smoothing_factor = np.ones(hyObj.bands)
                image_smooth += [smoothing_factor]

            hyObj.create_bad_bands(BAD_RANGE)
            hyObj.no_data = NO_DATA_VALUE
            hyObj.load_data()
            hyObj_pointer_dict_list = hyObj_pointer_dict_list + [copy.copy(hyObj)]

            if args.brdf:
                img_center_info = get_center_info(hyObj, coord_center_method)
                print(img_center_info)
                csv_img_center_info = np.array([os.path.basename(args.img[i])] + [image_uniq_name_list[i]] + [image_uniq_name_list_short[i]] + list(img_center_info))[:, np.newaxis]
                csv_group_center_info = np.hstack((csv_group_center_info, csv_img_center_info))
            # continues

            # Generate mask
            if args.mask:
                ir = hyObj.get_wave(BAND_IR_NM)
                red = hyObj.get_wave(BAND_RED_NM)
                ndvi = (1.0 * ir - red) / (1.0 * ir + red)

                ag_mask = 0
                if args.buffer_neon:
                    buffer_edge = sig.convolve2d(ir <= 0.5 * hyObj.no_data, make_disc_for_buffer(30), mode='same', fillvalue=1)
                    ag_mask = ag_mask or (buffer_edge > 0)

                hyObj.mask = (ndvi > NDVI_MIN_THRESHOLD) & (ndvi <= NDVI_MAX_THRESHOLD) & (ir != hyObj.no_data) & (ag_mask == 0)
                hyObj.mask.tofile(args.od + '/' + args.pref + str(i) + '_msk.bin')
                del ir, red  # ,ndvi
            else:
                hyObj.mask = np.ones((hyObj.lines, hyObj.columns)).astype(bool)
                print("Warning no mask specified, results may be unreliable!")

            # Generate sampling mask
            sampleArray = np.zeros(hyObj.mask.shape).astype(bool)
            idx = np.array(np.where(hyObj.mask == True)).T

            # np.random.seed(0)  # just for test

            idxRand = idx[np.random.choice(range(len(idx)), int(len(idx) * args.samp_perc), replace=False)].T  # actually used
            sampleArray[idxRand[0], idxRand[1]] = True
            sample_dict[i] = sampleArray
            idxRand_dict[i] = idxRand

            print(idxRand.shape)
            sub_total_sample_size += idxRand.shape[1]
            sample_index = sample_index + [sub_total_sample_size]

            # Initialize and store band iterator
            hyObj_dict[i] = copy.copy(hyObj).iterate(by='band')

            # Generate cosine i and slope samples
            sample_cos_i += calc_cosine_i(hyObj.solar_zn, hyObj.solar_az, hyObj.aspect, hyObj.slope)[sampleArray].tolist()
            sample_slope += (hyObj.slope)[sampleArray].tolist()

            # Generate c1 samples for topographic correction
            if args.topo:
                # Initialize topographic correction dictionary
                topo_coeffs = {}
                topo_coeffs['wavelengths'] = hyObj.wavelengths[hyObj.bad_bands].tolist()
                topo_coeffs['c'] = []
                topo_coeffs['slope'] = []
                topo_coeffs['intercept'] = []
                sample_c1 += (np.cos(hyObj.solar_zn) * np.cos(hyObj.slope))[sampleArray].tolist()

            # Gernerate scattering kernel samples for brdf correction
            if args.brdf or args.check_flight:

                sample_ndvi += (ndvi)[sampleArray].tolist()
                sample_img_tag += [i + 1] * idxRand.shape[1]  # start from 1

                for ibin in range(total_bin):
                    brdf_coeffs_List[ibin]['wavelengths'] = hyObj.wavelengths[hyObj.bad_bands].tolist()

                sample_k_vol += generate_volume_kernel(hyObj.solar_az, hyObj.solar_zn, hyObj.sensor_az, hyObj.sensor_zn, ross=ross)[sampleArray].tolist()
                sample_k_geom += generate_geom_kernel(hyObj.solar_az, hyObj.solar_zn, hyObj.sensor_az, hyObj.sensor_zn, li=li)[sampleArray].tolist()

            # del ndvi, topomask, brdfmask
            if args.mask:
                del ndvi

            # find outliers, initialize band_subset_outlier
            if args.check_flight and i == 0:
                for wave_i in RGBIM_BAND_CHECK_OUTLIERS:
                    band_num = int(hyObj.wave_to_band(wave_i))
                    band_subset_outlier = band_subset_outlier + [band_num]  # zero based

        sample_k_vol = np.array(sample_k_vol)
        sample_k_geom = np.array(sample_k_geom)
        sample_cos_i = np.array(sample_cos_i)
        sample_c1 = np.array(sample_c1)
        sample_slope = np.array(sample_slope)
        sample_ndvi = np.array(sample_ndvi)
        sample_img_tag = np.array(sample_img_tag)

        sample_topo_msk = (sample_cos_i > COSINE_I_MIN_THRESHOLD) & (sample_slope > SLOPE_MIN_THRESHOLD)

        if args.brdf or args.check_flight:
            group_summary_info = [args.pref] * 3 + np.mean(csv_group_center_info[3:, :].astype(np.float), axis=1).tolist()
            csv_group_center_info = np.insert(csv_group_center_info.T, 0, group_summary_info, axis=0)
            np.savetxt("%s%s_lat_sza.csv" % (args.od, args.pref), csv_group_center_info, header="image_name,uniq_name,uniq_name_short,LON,LAT,solar_zn_rad,solar_zn_deg", delimiter=',', fmt='%s', comments='')

            if not ndvi_thres_complete:
                ndvi_break_dyn_bin = np.percentile(sample_ndvi[sample_ndvi > 0], np.arange(DYN_NDVI_BIN_LOW_PERC, DYN_NDVI_BIN_HIGH_PERC + 1, perc_range / (total_bin - 1)))

                ndvi_thres = sorted([NDVI_BIN_MIN_THRESHOLD] + ndvi_break_dyn_bin.tolist() + [NDVI_BIN_MAX_THRESHOLD])
                ndvi_thres = sorted(list(set(ndvi_thres)))  # remove duplicates
                total_bin = len(ndvi_thres) - 1

            for ibin in range(total_bin):
                brdf_coeffs_List[ibin]['flight_box_avg_sza'] = float(csv_group_center_info[0, -1])
                # print('last',csv_group_center_info[0,-1])

                brdf_coeffs_List[ibin]['ndvi_lower_bound'] = ndvi_thres[ibin]
                brdf_coeffs_List[ibin]['ndvi_upper_bound'] = ndvi_thres[ibin + 1]

        if args.topo_sep:
            topo_coeff_list = []
            for _ in range(len(args.img)):
                topo_coeff_list += [{'c': [], 'slope':[], 'intercept':[], 'wavelengths':hyObj.wavelengths[hyObj.bad_bands].tolist()}]

        # initialize BRDF coeffs
        if args.brdf:
            for ibin in range(total_bin):
                ndvi_mask = (sample_ndvi > brdf_coeffs_List[ibin]['ndvi_lower_bound']) & (sample_ndvi <= brdf_coeffs_List[ibin]['ndvi_upper_bound'])
                ndvi_mask_dict[ibin] = ndvi_mask
                count_ndvi = np.count_nonzero(ndvi_mask)
                brdf_mask_stat[ibin] = count_ndvi

            # initialize arrays for BRDF coefficient estimation diagnostic files
            total_image = len(args.img)

            # r-squared array
            r_squared_array = np.ndarray((total_bin * (total_image + 1), 3 + len(hyObj.wavelengths)), dtype=object)  # total + flightline by flightline
            r_squared_array[:] = 0

            r_squared_array[:, 0] = (0.5 * (np.array(ndvi_thres[:-1]) + np.array(ndvi_thres[1:]))).tolist() * (total_image + 1)
            r_squared_array[:total_bin, 1] = 'group'
            r_squared_array[total_bin:, 1] = np.repeat(np.array([os.path.basename(x) for x in args.img]), total_bin, axis=0)

            r2_header = 'NDVI_Bin_Center,Flightline,Sample_Size,' + ','.join('B' + str(wavename) for wavename in hyObj.wavelengths)

            # RMSE array
            rmse_array = np.copy(r_squared_array)

            # BRDF coefficient array, volumetric+geometric+isotopic
            brdf_coeff_array = np.ndarray((total_bin + 8, 3 + 3 * len(hyObj.wavelengths)), dtype=object)
            brdf_coeff_array[:] = 0
            brdf_coeff_array[:total_bin, 0] = 0.5 * (np.array(ndvi_thres[:-1]) + np.array(ndvi_thres[1:]))

            brdf_coeffs_header = 'NDVI_Bin_Center,Sample_Size,' + ','.join('B' + str(wavename) + '_vol' for wavename in hyObj.wavelengths) + ',' + ','.join('B' + str(wavename) + '_geo' for wavename in hyObj.wavelengths) + ',' + ','.join('B' + str(wavename) + '_iso' for wavename in hyObj.wavelengths)
            brdf_coeff_array[total_bin + 1:total_bin + 7, 0] = ['slope', 'intercept', 'r_value', 'p_value', 'std_err', 'rmse']

        # find outliers
        img_tag_used = (sample_img_tag > -1)  # All True, will be changed if there is any abnormal lines.
        if args.check_flight:
            print("Searching for fight line outliers.....")

            outlier_dict = kernel_ndvi_outlier_search(band_subset_outlier, sample_k_vol, sample_k_geom, sample_c1, sample_cos_i, sample_slope, sample_ndvi, sample_topo_msk, sample_img_tag, idxRand_dict, hyObj_pointer_dict_list, image_smooth)
            outlier_json = "%s%s_outliers.json" % (args.od, args.pref)
            with open(outlier_json, 'w') as outfile:
                json.dump(outlier_dict, outfile)

            if outlier_dict["b1"]["outlier_count"] > 0:
                outlier_image_bool = np.array(outlier_dict["b1"]["outlier_image_bool"]).astype(bool)
                img_tag_used = ~outlier_image_bool[np.arange(len(args.img))[sample_img_tag - 1]]
                print(np.unique(sample_img_tag[img_tag_used]))

                if outlier_dict["b1"]["outlier_count"] > len(args.img) / 2:
                    print("More than half of the lines are abnormal lines, please check the information in {}{}_outliers.json. ".format(args.od, args.pref))
                    print("Flight line outliers checking Finishes.")
                    return  # exit the script, halt the procedure of coefficients estimation.

            print("Flight line outliers checking finishes.")

        else:
            print("Use all fight lines.....")
            
            # wave9_samples = np.empty((9,0),float)
            # if
            # singleband_kernel_ndvi_outlier_search(wave_samples, args.od,args.pref)

        # Calculate bandwise correction coefficients
        print("Calculating image correction coefficients.....")
        current_progress = 0

        for w, wave in enumerate(hyObj.wavelengths):
            progbar(current_progress, len(hyObj.wavelengths) * len(args.img), 100)
            wave_samples = []
            for i, image in enumerate(args.img):

                if hyObj.bad_bands[hyObj_dict[i].current_band + 1]:  # load data to RAM, if it is a goog band
                    wave_samples += hyObj_dict[i].read_next()[sample_dict[i]].tolist()

                else:  # similar to .read_next(), but do not load data to RAM, if it is a bad band
                    hyObj_dict[i].current_band += 1
                    if hyObj_dict[i].current_band == hyObj_dict[i].bands - 1:
                        hyObj_dict[i].complete = True

                current_progress += 1

            if hyObj.bad_bands[hyObj_dict[i].current_band]:

                wave_samples = np.array(wave_samples)

                for i_img_tag in range(len(args.img)):
                    img_tag_true = sample_img_tag == i_img_tag + 1
                    wave_samples[img_tag_true] = wave_samples[img_tag_true] / image_smooth[i_img_tag][w]

                # Generate cosine i and c1 image for topographic correction
                if args.topo:
                    if not args.topo_sep:
                        topo_coeff, coeff_slope, coeff_intercept = generate_topo_coeff_band(wave_samples, (wave_samples > REFL_MIN_THRESHOLD) & (wave_samples < REFL_MAX_THRESHOLD) & sample_topo_msk, sample_cos_i, non_negative=True)
                        topo_coeffs['c'].append(topo_coeff)
                        topo_coeffs['slope'].append(coeff_slope)
                        topo_coeffs['intercept'].append(coeff_intercept)
                        correctionFactor = (sample_c1 + topo_coeff) / (sample_cos_i + topo_coeff)
                        correctionFactor = correctionFactor * sample_topo_msk + 1.0 * (1 - sample_topo_msk)
                        wave_samples = wave_samples * correctionFactor
                    else:
                        for i in range(len(args.img)):
                            wave_samples_sub = wave_samples[sample_index[i]:sample_index[i + 1]]
                            sample_cos_i_sub = sample_cos_i[sample_index[i]:sample_index[i + 1]]
                            # sample_slope_sub = sample_slope[sample_index[i]:sample_index[i + 1]]
                            sample_c1_sub = sample_c1[sample_index[i]:sample_index[i + 1]]

                            sample_topo_msk_sub = sample_topo_msk[sample_index[i]:sample_index[i + 1]]

                            if np.count_nonzero(sample_topo_msk_sub) > MIN_SAMPLE_COUNT_TOPO:
                                topo_coeff, coeff_slope, coeff_intercept = generate_topo_coeff_band(wave_samples_sub, (wave_samples_sub > REFL_MIN_THRESHOLD) & (wave_samples_sub < REFL_MAX_THRESHOLD) & sample_topo_msk_sub, sample_cos_i_sub, non_negative=True)
                            else:
                                topo_coeff = FLAT_COEFF_C

                            topo_coeff_list[i]['c'].append(topo_coeff)
                            topo_coeff_list[i]['slope'].append(coeff_slope)
                            topo_coeff_list[i]['intercept'].append(coeff_intercept)

                            correctionFactor = (sample_c1_sub + topo_coeff) / (sample_cos_i_sub + topo_coeff)
                            correctionFactor = correctionFactor * sample_topo_msk_sub + 1.0 * (1 - sample_topo_msk_sub)
                            wave_samples[sample_index[i]:sample_index[i + 1]] = wave_samples_sub * correctionFactor

                # Gernerate scattering kernel images for brdf correction
                if args.brdf:

                    wave_samples = wave_samples[img_tag_used]
                    temp_mask = (wave_samples > REFL_MIN_THRESHOLD) & (wave_samples < REFL_MAX_THRESHOLD) & np.isfinite(sample_k_vol[img_tag_used]) & np.isfinite(sample_k_geom[img_tag_used])
                    temp_mask = temp_mask & (sample_cos_i[img_tag_used] > COSINE_I_MIN_THRESHOLD) & (sample_slope[img_tag_used] > SAMPLE_SLOPE_MIN_THRESHOLD)

                    for ibin in range(total_bin):

                        # skip BINs that has not enough samples in diagnostic output
                        if brdf_mask_stat[ibin] < MIN_SAMPLE_COUNT or np.count_nonzero(temp_mask) < MIN_SAMPLE_COUNT:
                            r_squared_array[range(ibin, total_bin * (total_image + 1), total_bin), w + 3] = DIAGNO_NAN_OUTPUT
                            rmse_array[range(ibin, total_bin * (total_image + 1), total_bin), w + 3] = DIAGNO_NAN_OUTPUT
                            brdf_mask_stat[ibin] = brdf_mask_stat[ibin] + DIAGNO_NAN_OUTPUT
                            continue

                        if np.count_nonzero(temp_mask & ndvi_mask_dict[ibin][img_tag_used]) < MIN_SAMPLE_COUNT:
                            fVol, fGeo, fIso = (0, 0, 1)
                        else:
                            fVol, fGeo, fIso = generate_brdf_coeff_band(wave_samples, temp_mask & ndvi_mask_dict[ibin][img_tag_used], sample_k_vol[img_tag_used], sample_k_geom[img_tag_used])

                        mask_sub = temp_mask & ndvi_mask_dict[ibin][img_tag_used]
                        r_squared_array[ibin, 2] = wave_samples[mask_sub].shape[0]
                        est_r2, sample_nn, rmse_total = cal_r2(wave_samples[mask_sub], sample_k_vol[img_tag_used][mask_sub], sample_k_geom[img_tag_used][mask_sub], [fVol, fGeo, fIso])
                        r_squared_array[ibin, w + 3] = est_r2
                        rmse_array[ibin, w + 3] = rmse_total
                        rmse_array[ibin, 2] = r_squared_array[ibin, 2]

                        brdf_coeff_array[ibin, 1] = wave_samples[mask_sub].shape[0]

                        # update diagnostic information scene by scene
                        for img_order in range(total_image):

                            img_mask_sub = (sample_img_tag[img_tag_used] == (img_order + 1)) & mask_sub

                            est_r2, sample_nn, rmse_bin = cal_r2(wave_samples[img_mask_sub], sample_k_vol[img_tag_used][img_mask_sub], sample_k_geom[img_tag_used][img_mask_sub], [fVol, fGeo, fIso])
                            r_squared_array[ibin + (img_order + 1) * total_bin, w + 3] = est_r2
                            r_squared_array[ibin + (img_order + 1) * total_bin, 2] = max(sample_nn, int(r_squared_array[ibin + (img_order + 1) * total_bin, 2]))  # update many times

                            rmse_array[ibin + (img_order + 1) * total_bin, w + 3] = rmse_bin
                            rmse_array[ibin + (img_order + 1) * total_bin, 2] = r_squared_array[ibin + (img_order + 1) * total_bin, 2]

                        brdf_coeffs_List[ibin]['fVol'].append(fVol)
                        brdf_coeffs_List[ibin]['fGeo'].append(fGeo)
                        brdf_coeffs_List[ibin]['fIso'].append(fIso)

                        # save the same coefficient information in diagnostic arrays
                        brdf_coeff_array[ibin, 2 + w] = fVol
                        brdf_coeff_array[ibin, 2 + w + len(hyObj.wavelengths)] = fGeo
                        brdf_coeff_array[ibin, 2 + w + 2 * len(hyObj.wavelengths)] = fIso

                    # update array for BRDF output diagnostic files
                    mid_ndvi_list = brdf_coeff_array[:total_bin, 0].astype(np.float)

                    # check linearity( NDVI as X v.s. kernel coefficients as Y ), save to diagnostic file, BIN by BIN, and wavelength by wavelength
                    if np.count_nonzero(brdf_coeff_array[:, 2 + w]) > 3:
                        # volumetric coefficients
                        temp_y = brdf_coeff_array[:total_bin, 2 + w].astype(np.float)
                        slope_ndvi_bin, intercept_ndvi_bin, r_value_ndvi_bin, p_value_ndvi_bin, std_err_ndvi_bin, rmse_ndvi_bin = cal_r2_single(mid_ndvi_list[(mid_ndvi_list > BRDF_VEG_lower_bound) & (temp_y != 0)], temp_y[(mid_ndvi_list > BRDF_VEG_lower_bound) & (temp_y != 0)])
                        brdf_coeff_array[total_bin + 1:total_bin + 7, 2 + w] = slope_ndvi_bin, intercept_ndvi_bin, r_value_ndvi_bin, p_value_ndvi_bin, std_err_ndvi_bin, rmse_ndvi_bin

                        # geometric coefficients
                        temp_y = brdf_coeff_array[:total_bin, 2 + w + 1 * len(hyObj.wavelengths)].astype(np.float)
                        slope_ndvi_bin, intercept_ndvi_bin, r_value_ndvi_bin, p_value_ndvi_bin, std_err_ndvi_bin, rmse_ndvi_bin = cal_r2_single(mid_ndvi_list[(mid_ndvi_list > BRDF_VEG_lower_bound) & (temp_y != 0)], temp_y[(mid_ndvi_list > BRDF_VEG_lower_bound) & (temp_y != 0)])
                        brdf_coeff_array[total_bin + 1:total_bin + 7, 2 + w + len(hyObj.wavelengths)] = slope_ndvi_bin, intercept_ndvi_bin, r_value_ndvi_bin, p_value_ndvi_bin, std_err_ndvi_bin, rmse_ndvi_bin

                        # isotropic coefficients
                        temp_y = brdf_coeff_array[:total_bin, 2 + w + 2 * len(hyObj.wavelengths)].astype(np.float)
                        slope_ndvi_bin, intercept_ndvi_bin, r_value_ndvi_bin, p_value_ndvi_bin, std_err_ndvi_bin, rmse_ndvi_bin = cal_r2_single(mid_ndvi_list[(mid_ndvi_list > BRDF_VEG_lower_bound) & (temp_y != 0)], temp_y[(mid_ndvi_list > BRDF_VEG_lower_bound) & (temp_y != 0)])
                        brdf_coeff_array[total_bin + 1:total_bin + 7, 2 + w + 2 * len(hyObj.wavelengths)] = slope_ndvi_bin, intercept_ndvi_bin, r_value_ndvi_bin, p_value_ndvi_bin, std_err_ndvi_bin, rmse_ndvi_bin

    # Export coefficients to JSON
    if args.topo:
        if (not args.topo_sep) or (len(args.img) == 1):
            topo_json = "%s%s_topo_coeffs.json" % (args.od, args.pref)
            with open(topo_json, 'w') as outfile:
                json.dump(topo_coeffs, outfile)
        else:
            for i_img in range(len(args.img)):
                # filename_pref = (os.path.basename(args.img[i_img])).split('_')[0]
                filename_pref = image_uniq_name_list[i_img]
                topo_json = "%s%s_topo_coeffs.json" % (args.od, filename_pref)
                with open(topo_json, 'w') as outfile:
                    json.dump(topo_coeff_list[i_img], outfile)

    if args.brdf:

        if len(args.img) > 1:
            # In grouping mode, save arrays for BRDF diagnostic information to ascii files
            np.savetxt("%s%s_brdf_coeffs_r2.csv" % (args.od, args.pref), r_squared_array, header=r2_header, delimiter=',', fmt='%s', comments='')
            np.savetxt("%s%s_brdf_coeffs_rmse.csv" % (args.od, args.pref), rmse_array, header=r2_header, delimiter=',', fmt='%s', comments='')
            np.savetxt("%s%s_brdf_coeffs_fit.csv" % (args.od, args.pref), brdf_coeff_array, header=brdf_coeffs_header, delimiter=',', fmt='%s', comments='')

        if total_bin > 0:
            for ibin in range(total_bin):
                if brdf_mask_stat[ibin] < MIN_SAMPLE_COUNT:
                    continue
                brdf_json = "%s%s_brdf_coeffs_%s.json" % (args.od, args.pref, str(ibin + 1))
                with open(brdf_json, 'w') as outfile:
                    json.dump(brdf_coeffs_List[ibin], outfile)
        else:
            brdf_json = "%s%s_brdf_coeffs_1.json" % (args.od, args.pref)
            with open(brdf_json, 'w') as outfile:
                json.dump(brdf_coeffs, outfile)
Beispiel #5
0
def main():
    '''
    Perform in-memory topographic/BRDF correction and trait estimation.

    Command-line entry point. Parses arguments, loads a hyperspectral image
    (HDF5 when ``-img`` ends in ``.h5``, otherwise ENVI), optionally builds a
    topographic correction from the ``--topo`` coefficient JSON and a BRDF
    correction from the ``--brdf`` coefficient-file prefix, then iterates the
    image chunk by chunk and:

    * applies the topographic and BRDF corrections (bad bands are dropped),
    * optionally writes an RGBI + mask GeoTIFF (``--rgbim``),
    * optionally writes the corrected image as ENVI (``--out`` suffix),
    * applies every trait model JSON found in ``-coeffs`` and writes one
      GeoTIFF per trait (band 1: model mean, band 2: model standard
      deviation, band 3: buffered mask when ``--buffer_neon`` is given).

    All outputs are written under the ``-od`` directory.

    Raises:
        ValueError: BRDF correction was requested but no
            ``<brdf>_brdf_coeffs_<n>.json`` file could be loaded.
    '''

    parser = argparse.ArgumentParser(description = "In memory trait mapping tool.")
    parser.add_argument("-img", help="Input image pathname", required=True, type=str)
    parser.add_argument("--obs", help="Input observables pathname", required=False, type=str)
    parser.add_argument("--out", help="Output full corrected image", required=False, type=str)
    parser.add_argument("-od", help="Output directory for all resulting products", required=True, type=str)
    parser.add_argument("--brdf", help="Perform BRDF correction", type=str, default='')
    parser.add_argument("--topo", help="Perform topographic correction", type=str, default='')
    parser.add_argument("--mask", help="Image mask type to use", action='store_true')
    parser.add_argument("--mask_threshold", help="Mask threshold value", nargs='*', type=float)
    parser.add_argument("--rgbim", help="Export RGBI +Mask image.", action='store_true')
    parser.add_argument("-coeffs", help="Trait coefficients directory", required=False, type=str)
    parser.add_argument("-nodata", help="New value to assign for no_data values", required=False, type=float, default=-9999)
    parser.add_argument("-smooth", help="BRDF smooth methods L: Linear regression; W: Weighted linear regression; I: Linear interpolation", required=False, choices=['L', 'W', 'I'])
    parser.add_argument("-sszn", help="standard solar zenith angle (degree)", required=False, type=float)
    parser.add_argument("--buffer_neon", help="neon buffer", action='store_true')
    parser.add_argument("-boxsszn", help="Use box average standard solar zenith angle (degree)", action='store_true')

    args = parser.parse_args()

    traits = glob.glob("%s/*.json" % args.coeffs)

    def wave_scaler(model):
        # Multiplier applied to a trait model's wavelengths / FWHM so they are
        # comparable with the image: 10**3 when the model is in micrometers.
        if model['wavelength_units'] == 'micrometers':
            return 10**3
        return 1

    # Standard solar zenith angle (radians) used for BRDF normalization.
    # Resolved from -sszn / -boxsszn while reading the BRDF coefficient files;
    # -9999 means "no solar zenith normalization".
    std_solar_zn = None

    # Load data objects into memory
    if args.img.endswith(".h5"):
        hyObj = ht.openHDF(args.img, load_obs=True)
        smoothing_factor = 1
    else:
        hyObj = ht.openENVI(args.img)

        smoothing_factor = hyObj.header_dict[NAME_FIELD_SMOOTH]
        if isinstance(smoothing_factor, (list, tuple, np.ndarray)):
            # CORR product has smoothing factors; all bands are converted back to
            # the uncorrected / unsmoothed version by dividing by these factors.
            smoothing_factor = np.array(smoothing_factor)
        else:
            # REFL version
            smoothing_factor = 1

    if (len(args.topo) != 0) | (len(args.brdf) != 0):
        hyObj.load_obs(args.obs)
    if not args.od.endswith("/"):
        args.od += "/"
    hyObj.create_bad_bands(BAD_RANGE)

    # no data / ignored values vary by product
    hyObj.no_data = NO_DATA_VALUE

    hyObj.load_data()

    # ------------------------------------------------------------------
    # NDVI and mask generation
    # ------------------------------------------------------------------
    # extra_mask stays True unless a NEON edge buffer is requested.
    extra_mask = True

    # NDVI is needed for the vegetation mask AND for BRDF NDVI binning, so it
    # is computed whenever either is requested (BRDF without --mask used to
    # fail with an undefined `ndvi`).
    need_ndvi = args.mask or (len(args.brdf) != 0)
    ndvi = None
    if need_ndvi:
        ir = hyObj.get_wave(BAND_IR_NM)
        red = hyObj.get_wave(BAND_RED_NM)
        ndvi = (ir - red) / (ir + red)

    if args.mask:
        if args.buffer_neon:
            print("Creating buffer of image edge...")
            # Count no-data pixels inside a disc around every pixel; the image
            # border counts as no data because fillvalue=1.
            buffer_edge = sig.convolve2d(ir <= 0.5 * hyObj.no_data, make_disc_for_buffer(30), mode='same', fillvalue=1)
            extra_mask = (buffer_edge == 0)

        hyObj.mask = ((ndvi > NDVI_APPLIED_BIN_MIN_THRESHOLD) & (ir != hyObj.no_data))
    else:
        hyObj.mask = np.ones((hyObj.lines, hyObj.columns)).astype(bool)
        print("Warning no mask specified, results may be unreliable!")

    if need_ndvi:
        del ir, red

    # ------------------------------------------------------------------
    # Topographic correction inputs (cosine i, c1, c2)
    # ------------------------------------------------------------------
    if len(args.topo) != 0:
        with open(args.topo) as json_file:
            topo_coeffs = json.load(json_file)

        topo_coeffs['c'] = np.array(topo_coeffs['c'])
        cos_i = calc_cosine_i(hyObj.solar_zn, hyObj.solar_az, hyObj.aspect, hyObj.slope)
        c1 = np.cos(hyObj.solar_zn)
        c2 = np.cos(hyObj.slope)

        topomask = hyObj.mask & (cos_i > COSINE_I_MIN_THRESHOLD) & (hyObj.slope > SLOPE_MIN_THRESHOLD)

    # ------------------------------------------------------------------
    # BRDF coefficients, NDVI bin masks and scattering kernels
    # ------------------------------------------------------------------
    if len(args.brdf) != 0:
        brdf_coeffs_List = []

        ndvi_thres_complete = False
        if args.mask_threshold:
            total_bin = len(args.mask_threshold) + 1
            ndvi_thres = [NDVI_APPLIED_BIN_MIN_THRESHOLD] + args.mask_threshold + [NDVI_APPLIED_BIN_MAX_THRESHOLD]
            ndvi_thres_complete = True
        else:
            # read NDVI binning info from the existing json files
            total_bin = len(glob.glob(args.brdf + '_brdf_coeffs_*.json'))
            ndvi_thres = [None] * total_bin + [NDVI_APPLIED_BIN_MAX_THRESHOLD]

        if args.smooth is None:
            brdfmask = np.ones((total_bin, hyObj.lines, hyObj.columns)).astype(bool)

        # Index of an available bin (some bins may be missing due to small
        # sample size). After the loop this is the LAST bin that was found; its
        # 'ross'/'li' settings and band count are used for every bin, presumably
        # identical across bins (TODO confirm).
        first_effective_ibin = 0

        for ibin in range(total_bin):
            brdf_json = args.brdf + '_brdf_coeffs_' + str(ibin + 1) + '.json'

            if not os.path.exists(brdf_json):
                brdf_coeffs_List.append(None)
                print('No ' + brdf_json)
                if args.smooth is None:
                    brdfmask[ibin, :, :] = False
                continue

            first_effective_ibin = ibin

            with open(brdf_json) as json_file:
                brdf_coeffs = json.load(json_file)
            brdf_coeffs['fVol'] = np.array(brdf_coeffs['fVol'])
            brdf_coeffs['fGeo'] = np.array(brdf_coeffs['fGeo'])
            brdf_coeffs['fIso'] = np.array(brdf_coeffs['fIso'])
            brdf_coeffs_List.append(brdf_coeffs)

            if not ndvi_thres_complete:
                ndvi_thres[ibin] = max(float(brdf_coeffs['ndvi_lower_bound']), NDVI_APPLIED_BIN_MIN_THRESHOLD)
                ndvi_thres[ibin + 1] = min(float(brdf_coeffs['ndvi_upper_bound']), NDVI_APPLIED_BIN_MAX_THRESHOLD)

            if std_solar_zn is None:
                # `is not None` so that an explicit zenith of 0 degrees is honored.
                if args.sszn is not None:
                    std_solar_zn = float(args.sszn) / 180 * np.pi
                elif args.boxsszn:
                    std_solar_zn = float(brdf_coeffs['flight_box_avg_sza']) / 180 * np.pi
                else:
                    std_solar_zn = -9999

            if args.smooth is None:
                brdfmask[ibin, :, :] = hyObj.mask & (ndvi > ndvi_thres[ibin]) & (ndvi <= ndvi_thres[ibin + 1])

        if all(bin_coeffs is None for bin_coeffs in brdf_coeffs_List):
            raise ValueError("No BRDF coefficient files found for prefix '%s'" % args.brdf)

        if args.smooth is not None:
            n_wave = len(brdf_coeffs_List[first_effective_ibin]['fVol'])
            fvol_y_n = np.zeros((total_bin, n_wave), dtype=np.float64)
            fiso_y_n = np.zeros((total_bin, n_wave), dtype=np.float64)
            fgeo_y_n = np.zeros((total_bin, n_wave), dtype=np.float64)
            upper_bound_y_n = np.zeros(total_bin)
            lower_bound_y_n = np.zeros(total_bin)

            min_lower_bound = 1000
            max_upper_bound = BRDF_VEG_upper_bound

            # Missing bins keep zero coefficients and zero NDVI bounds.
            for ibin, bin_coeffs in enumerate(brdf_coeffs_List):
                if bin_coeffs is None:
                    continue
                fvol_y_n[ibin, :] = bin_coeffs["fVol"]
                fgeo_y_n[ibin, :] = bin_coeffs["fGeo"]
                fiso_y_n[ibin, :] = bin_coeffs["fIso"]
                upper_bound_y_n[ibin] = float(bin_coeffs["ndvi_upper_bound"])
                lower_bound_y_n[ibin] = float(bin_coeffs["ndvi_lower_bound"])
                min_lower_bound = min(min_lower_bound, lower_bound_y_n[ibin])
                if lower_bound_y_n[ibin] > BRDF_VEG_upper_bound:
                    max_upper_bound = upper_bound_y_n[ibin] + 0.0001

            # NDVI bin centers are the x values for interpolation / regression.
            mid_pnt = (upper_bound_y_n + lower_bound_y_n) / 2.0

            if args.smooth == 'I':
                poly_bin = (upper_bound_y_n <= max_upper_bound) & (lower_bound_y_n >= BRDF_VEG_lower_bound)
                old_bin_low = (lower_bound_y_n < BRDF_VEG_lower_bound)
                old_bin_hi = (upper_bound_y_n > max_upper_bound)
                n_old_bin = np.count_nonzero(old_bin_low) + np.count_nonzero(old_bin_hi)
                outside_list = np.where(old_bin_low | old_bin_hi)[0].tolist()

                mid_x = mid_pnt[poly_bin]
                yy_vol = fvol_y_n[poly_bin, :]
                yy_geo = fgeo_y_n[poly_bin, :]
                yy_iso = fiso_y_n[poly_bin, :]

                # Layers 0..n_old_bin-1: bins kept with their own coefficients;
                # layer n_old_bin: pixels whose coefficients are interpolated.
                brdfmask = np.zeros((n_old_bin + 1, hyObj.lines, hyObj.columns)).astype(bool)
                for order_bin, ibin in enumerate(outside_list):
                    bin_mask = hyObj.mask & (ndvi > ndvi_thres[ibin]) & (ndvi <= ndvi_thres[ibin + 1])
                    if (upper_bound_y_n[ibin] == lower_bound_y_n[ibin]) and (upper_bound_y_n[ibin] == 0.0):
                        # missing bins are handled by interpolation / extrapolation:
                        # add their pixels to the last (interpolated) layer
                        brdfmask[n_old_bin, :, :] |= bin_mask
                    else:
                        # bins outside the bounds are kept without any extrapolation
                        brdfmask[order_bin, :, :] = bin_mask

                brdfmask[n_old_bin, :, :] |= hyObj.mask & (ndvi >= BRDF_VEG_lower_bound) & (ndvi <= max_upper_bound)

            elif args.smooth == 'L' or args.smooth == 'W':
                poly_bin = (upper_bound_y_n <= max_upper_bound) & (lower_bound_y_n > BRDF_VEG_lower_bound)
                old_bin = (upper_bound_y_n <= max(BRDF_VEG_lower_bound, min_lower_bound))
                n_old_bin = np.count_nonzero(old_bin)

                # 'W' weights each bin by its share of the samples when the
                # diagnostic r2 csv is available; otherwise uniform weights.
                weight_n = np.ones(total_bin)
                if args.smooth == 'W':
                    sample_size = get_sample_size(args.brdf + '_brdf_coeffs_r2.csv', total_bin)
                    if sample_size is not None:
                        weight_n = (1.0 * sample_size) / np.sum(sample_size)
                weight_n = weight_n[poly_bin]

                mid_x = mid_pnt[poly_bin]
                yy_vol = fvol_y_n[poly_bin, :]
                yy_geo = fgeo_y_n[poly_bin, :]
                yy_iso = fiso_y_n[poly_bin, :]

                # Per-band linear fit of each kernel coefficient against NDVI;
                # columns are ordered x^0, x^1 (numpy.polynomial convention).
                coeff_list_vol = np.zeros((n_wave, 2))
                coeff_list_geo = np.zeros((n_wave, 2))
                coeff_list_iso = np.zeros((n_wave, 2))

                for j in range(n_wave):
                    coeff_list_vol[j, :] = P.polyfit(mid_x, yy_vol[:, j], 1, full=False, w=weight_n)
                    coeff_list_geo[j, :] = P.polyfit(mid_x, yy_geo[:, j], 1, full=False, w=weight_n)
                    coeff_list_iso[j, :] = P.polyfit(mid_x, yy_iso[:, j], 1, full=False, w=weight_n)

                brdfmask = np.ones((n_old_bin + 1, hyObj.lines, hyObj.columns)).astype(bool)
                for ibin in range(n_old_bin):
                    brdfmask[ibin, :, :] = hyObj.mask & (ndvi > ndvi_thres[ibin]) & (ndvi <= ndvi_thres[ibin + 1])

                brdfmask[n_old_bin, :, :] = hyObj.mask & (ndvi > NDVI_APPLIED_BIN_MIN_THRESHOLD) & (ndvi < 1.0)

        reference_coeffs = brdf_coeffs_List[first_effective_ibin]
        k_vol = generate_volume_kernel(hyObj.solar_az, hyObj.solar_zn, hyObj.sensor_az, hyObj.sensor_zn, ross=reference_coeffs['ross'])
        k_geom = generate_geom_kernel(hyObj.solar_az, hyObj.solar_zn, hyObj.sensor_az, hyObj.sensor_zn, li=reference_coeffs['li'])

        print('std_solar_zn', std_solar_zn)
        if std_solar_zn == -9999:
            # NADIR without solar zenith angle normalization
            k_vol_nadir = generate_volume_kernel(hyObj.solar_az, hyObj.solar_zn, hyObj.sensor_az, 0, ross=reference_coeffs['ross'])
            k_geom_nadir = generate_geom_kernel(hyObj.solar_az, hyObj.solar_zn, hyObj.sensor_az, 0, li=reference_coeffs['li'])
        else:
            # solar zenith angle normalization, from the flight box average (json) or user input
            k_vol_nadir = generate_volume_kernel(np.pi, std_solar_zn, hyObj.sensor_az, 0, ross=reference_coeffs['ross'])
            k_geom_nadir = generate_geom_kernel(np.pi, std_solar_zn, hyObj.sensor_az, 0, li=reference_coeffs['li'])

    # ------------------------------------------------------------------
    # Trait models: collect wavelengths and build a resampler if needed
    # ------------------------------------------------------------------
    if len(traits) != 0:
        print("Calculating values for %s traits....." % len(traits))

        trait_waves_all = []
        trait_fwhm_all = []

        for trait in traits:
            with open(trait) as json_file:
                trait_model = json.load(json_file)

            trait_wave_scaler = wave_scaler(trait_model)

            # Wavelengths to compare against the image wavelengths
            if len(trait_model['vector_norm_wavelengths']) == 0:
                trait_waves_all += list(np.array(trait_model['model_wavelengths']) * trait_wave_scaler)
            else:
                trait_waves_all += list(np.array(trait_model['vector_norm_wavelengths']) * trait_wave_scaler)

            trait_fwhm_all += list(np.array(trait_model['fwhm']) * trait_wave_scaler)

        # Unique (wavelength, fwhm) pairs, sorted by wavelength
        trait_waves_fwhm = list(set(zip(trait_waves_all, trait_fwhm_all)))
        trait_waves_fwhm.sort(key=lambda x: x[0])

        # If wavelengths match there is no need to resample
        check_wave_match_result = check_wave_match(hyObj, trait_waves_fwhm)
        print("match_flag", check_wave_match_result['flag'])
        if check_wave_match_result['flag']:
            match_flag = True
        else:
            match_flag = False
            center_interpolate = check_wave_match_result['center_interpolate']
            if not center_interpolate:
                resampling_coeffs = est_transform_matrix(hyObj.wavelengths[hyObj.bad_bands], [x for (x, y) in trait_waves_fwhm], hyObj.fwhm[:hyObj.bands][hyObj.bad_bands], [y for (x, y) in trait_waves_fwhm], 2)

    # From here on the chunks only hold the good bands.
    hyObj.wavelengths = hyObj.wavelengths[hyObj.bad_bands]

    # Per-trait model parameters, filled while processing the first chunk.
    trait_dict = {}

    pixels_processed = 0
    iterator = hyObj.iterate(by='chunk', chunk_size=(CHUNK_EDGE_SIZE, hyObj.columns))

    while not iterator.complete:
        chunk = iterator.read_next()
        # Pixels at or below half the no-data value in the reference band are no data.
        chunk_nodata_mask = chunk[:, :, BAND_NO_DATA] <= 0.5 * hyObj.no_data
        pixels_processed += chunk.shape[0] * chunk.shape[1]
        progbar(pixels_processed, hyObj.columns * hyObj.lines, 100)

        chunk = chunk / smoothing_factor

        # Chunk array indices
        line_start = iterator.current_line
        line_end = iterator.current_line + chunk.shape[0]
        col_start = iterator.current_column
        col_end = iterator.current_column + chunk.shape[1]

        # Apply TOPO correction (both branches also drop the bad bands)
        if len(args.topo) != 0:
            cos_i_chunk = cos_i[line_start:line_end, col_start:col_end]
            c1_chunk = c1[line_start:line_end, col_start:col_end]
            c2_chunk = c2[line_start:line_end, col_start:col_end]
            topomask_chunk = topomask[line_start:line_end, col_start:col_end, np.newaxis]
            correctionFactor = (c2_chunk[:, :, np.newaxis] * c1_chunk[:, :, np.newaxis] + topo_coeffs['c']) / (cos_i_chunk[:, :, np.newaxis] + topo_coeffs['c'])
            # Pixels outside the topo mask are left unchanged (factor 1).
            correctionFactor = correctionFactor * topomask_chunk + 1.0 * (1 - topomask_chunk)
            chunk = chunk[:, :, hyObj.bad_bands] * correctionFactor
        else:
            chunk = chunk[:, :, hyObj.bad_bands]

        # Apply BRDF correction
        if len(args.brdf) != 0:
            # Scattering kernels for this chunk
            k_vol_nadir_chunk = k_vol_nadir[line_start:line_end, col_start:col_end]
            k_geom_nadir_chunk = k_geom_nadir[line_start:line_end, col_start:col_end]
            k_vol_chunk = k_vol[line_start:line_end, col_start:col_end]
            k_geom_chunk = k_geom[line_start:line_end, col_start:col_end]

            n_wavelength = brdf_coeffs_List[first_effective_ibin]['fVol'].shape[0]
            new_k_vol = np.zeros((chunk.shape[0], chunk.shape[1], n_wavelength), dtype=np.float32)
            new_k_geom = np.zeros((chunk.shape[0], chunk.shape[1], n_wavelength), dtype=np.float32)
            new_k_iso = np.zeros((chunk.shape[0], chunk.shape[1], n_wavelength), dtype=np.float32)

            if args.smooth is None:
                # Per-pixel coefficients from the NDVI bin each pixel falls in
                for ibin in range(total_bin):
                    if brdf_coeffs_List[ibin] is None:
                        continue
                    veg_mask = brdfmask[ibin, line_start:line_end, col_start:col_end][:, :, np.newaxis]
                    new_k_vol += brdf_coeffs_List[ibin]['fVol'] * veg_mask
                    new_k_geom += brdf_coeffs_List[ibin]['fGeo'] * veg_mask
                    new_k_iso += brdf_coeffs_List[ibin]['fIso'] * veg_mask

            elif args.smooth == 'I':
                # Bins kept as-is first, then interpolate over the remaining layer
                for ibin in range(n_old_bin):
                    bin_coeffs = brdf_coeffs_List[outside_list[ibin]]
                    if bin_coeffs is None:
                        continue
                    veg_mask = brdfmask[ibin, line_start:line_end, col_start:col_end][:, :, np.newaxis]
                    new_k_vol += bin_coeffs['fVol'] * veg_mask
                    new_k_geom += bin_coeffs['fGeo'] * veg_mask
                    new_k_iso += bin_coeffs['fIso'] * veg_mask

                veg_mask = brdfmask[n_old_bin, line_start:line_end, col_start:col_end][:, :, np.newaxis]
                ndvi_sub = ndvi[line_start:line_end, col_start:col_end]

                new_k_vol = interpol_1d_kernel_coeff(ndvi_sub, veg_mask, new_k_vol, mid_x, yy_vol)
                new_k_geom = interpol_1d_kernel_coeff(ndvi_sub, veg_mask, new_k_geom, mid_x, yy_geo)
                new_k_iso = interpol_1d_kernel_coeff(ndvi_sub, veg_mask, new_k_iso, mid_x, yy_iso)

            elif args.smooth == 'L' or args.smooth == 'W':
                veg_mask = brdfmask[n_old_bin, line_start:line_end, col_start:col_end][:, :, np.newaxis]
                ndvi_sub = ndvi[line_start:line_end, col_start:col_end]

                new_k_vol = interpol_kernel_coeff(ndvi_sub, veg_mask, new_k_vol, coeff_list_vol)
                new_k_geom = interpol_kernel_coeff(ndvi_sub, veg_mask, new_k_geom, coeff_list_geo)
                new_k_iso = interpol_kernel_coeff(ndvi_sub, veg_mask, new_k_iso, coeff_list_iso)

            # Apply brdf correction
            # eq 5. Weyermann et al. IEEE-TGARS 2015
            brdf = np.einsum('ijk,ij-> ijk', new_k_vol, k_vol_chunk) + np.einsum('ijk,ij-> ijk', new_k_geom, k_geom_chunk) + new_k_iso
            brdf_nadir = np.einsum('ijk,ij-> ijk', new_k_vol, k_vol_nadir_chunk) + np.einsum('ijk,ij-> ijk', new_k_geom, k_geom_nadir_chunk) + new_k_iso
            correctionFactor = brdf_nadir / brdf
            # Pixels with no model (all-zero coefficients) are left unchanged.
            correctionFactor[brdf == 0.0] = 1.0
            chunk = chunk * correctionFactor

        # Reassign no data values
        chunk[chunk_nodata_mask, :] = args.nodata

        if len(traits) > 0:
            if not match_flag:
                if center_interpolate:
                    # only the band centers are offset: interpolate the chunk
                    interp_func = interp1d(hyObj.wavelengths, chunk, kind='cubic', axis=2, fill_value="extrapolate")
                    chunk_r = interp_func(np.array([x for (x, y) in trait_waves_fwhm]))
                else:
                    # fwhm and band centers do not match: resample the chunk
                    chunk_r = np.dot(chunk, resampling_coeffs)
            else:
                # wavelengths match: take the matching subset of bands
                chunk_r = chunk[:, :, check_wave_match_result['index']]

        # Export RGBIM image
        if args.rgbim:
            dstFile = args.od + os.path.splitext(os.path.basename(args.img))[0] + '_rgbim.tif'
            nband_rgbim = len(RGBIM_BAND)

            if line_start + col_start == 0:
                driver = gdal.GetDriverByName("GTIFF")
                tiff = driver.Create(dstFile, hyObj.columns, hyObj.lines, nband_rgbim + 1, gdal.GDT_Float32)
                tiff.SetGeoTransform(hyObj.transform)
                tiff.SetProjection(hyObj.projection)
                for band in range(1, nband_rgbim + 2):
                    tiff.GetRasterBand(band).SetNoDataValue(args.nodata)
                # Last band holds the (optionally buffered) mask
                tiff.GetRasterBand(nband_rgbim + 1).WriteArray(hyObj.mask & extra_mask)

                del tiff, driver

            # Write rgbi chunk
            rgbi_geotiff = gdal.Open(dstFile, gdal.GA_Update)
            for i, wave in enumerate(RGBIM_BAND, start=1):
                band = hyObj.wave_to_band(wave)
                rgbi_geotiff.GetRasterBand(i).WriteArray(chunk[:, :, band], col_start, line_start)
            rgbi_geotiff = None

        # Export BRDF and topo corrected image
        if args.out:
            if line_start + col_start == 0:
                output_name = args.od + os.path.splitext(os.path.basename(args.img))[0] + args.out

                if isinstance(hyObj.header_dict, dict):
                    # ENVI
                    header_dict = hyObj.header_dict
                    header_dict['wavelength'] = header_dict['wavelength'][hyObj.bad_bands]
                else:
                    # HDF5
                    header_dict = h5_make_header_dict(hyObj)  # bad bands removed

                # Update header
                header_dict['fwhm'] = header_dict['fwhm'][hyObj.bad_bands]
                header_dict.pop('band names', None)
                header_dict['bands'] = int(hyObj.bad_bands.sum())

                # clean ENVI header
                header_dict.pop(NAME_FIELD_SMOOTH, None)
                header_dict.pop('bbl', None)
                header_dict.pop('smoothing factors', None)

                writer = writeENVI(output_name, header_dict)
            writer.write_chunk(chunk, iterator.current_line, iterator.current_column)
            if iterator.complete:
                writer.close()

        for i, trait in enumerate(traits):
            dstFile = args.od + os.path.splitext(os.path.basename(args.img))[0] + "_" + os.path.splitext(os.path.basename(trait))[0] + ".tif"

            # Trait estimation preparation (first chunk only)
            if line_start + col_start == 0:

                with open(trait) as json_file:
                    trait_model = json.load(json_file)

                intercept = np.array(trait_model['intercept'])
                coefficients = np.array(trait_model['coefficients'])
                transform = trait_model['transform']

                # Scale with THIS model's units (previously the last model's
                # scaler from the collection loop was reused for every trait).
                trait_wave_scaler = wave_scaler(trait_model)

                # Wavelengths to compare against the image wavelengths
                if len(trait_model['vector_norm_wavelengths']) == 0:
                    dst_waves = np.array(trait_model['model_wavelengths']) * trait_wave_scaler
                else:
                    dst_waves = np.array(trait_model['vector_norm_wavelengths']) * trait_wave_scaler

                dst_fwhm = np.array(trait_model['fwhm']) * trait_wave_scaler
                model_waves = np.array(trait_model['model_wavelengths']) * trait_wave_scaler
                wave_to_fwhm = dict(zip(dst_waves, dst_fwhm))
                model_fwhm = [wave_to_fwhm[x] for x in model_waves]

                vnorm_pairs = list(zip(dst_waves, dst_fwhm))
                model_pairs = list(zip(model_waves, model_fwhm))

                # Boolean numpy arrays; a plain list of True/False would be
                # treated as integer indices 1/0 (the 2nd/1st band).
                vnorm_band_mask = np.array([x in vnorm_pairs for x in trait_waves_fwhm])
                model_band_mask = np.array([x in model_pairs for x in trait_waves_fwhm])

                if trait_model['vector_norm']:
                    vnorm_scaler = trait_model["vector_scaler"]
                else:
                    vnorm_scaler = None

                trait_dict[i] = [coefficients, intercept, trait_model['vector_norm'], vnorm_scaler, vnorm_band_mask, model_band_mask, transform]

                # Create geotiff
                driver = gdal.GetDriverByName("GTIFF")

                if args.buffer_neon:
                    tiff = driver.Create(dstFile, hyObj.columns, hyObj.lines, 3, gdal.GDT_Float32, options=["INTERLEAVE=BAND"])
                    tiff.GetRasterBand(3).WriteArray(hyObj.mask & extra_mask)
                    tiff.GetRasterBand(3).SetDescription("Buffered Mask")
                else:
                    tiff = driver.Create(dstFile, hyObj.columns, hyObj.lines, 2, gdal.GDT_Float32, options=["INTERLEAVE=BAND"])

                tiff.SetGeoTransform(hyObj.transform)
                tiff.SetProjection(hyObj.projection)
                tiff.GetRasterBand(1).SetNoDataValue(args.nodata)
                tiff.GetRasterBand(2).SetNoDataValue(args.nodata)
                tiff.GetRasterBand(1).SetDescription("Model Mean")
                tiff.GetRasterBand(2).SetDescription("Model Standard Deviation")

                del tiff, driver

            coefficients, intercept, vnorm, vnorm_scaler, vnorm_band_mask, model_band_mask, transform = trait_dict[i]

            chunk_t = np.copy(chunk_r)

            if vnorm:
                chunk_t[:, :, vnorm_band_mask] = vector_normalize_chunk(chunk_t[:, :, vnorm_band_mask], vnorm_scaler)

            if transform == "log(1/R)":
                chunk_t[:, :, model_band_mask] = np.log(1 / chunk_t[:, :, model_band_mask])

            trait_mean, trait_std = apply_plsr_chunk(chunk_t[:, :, model_band_mask], coefficients, intercept)

            # Change no data pixel values
            trait_mean[chunk_nodata_mask] = args.nodata
            trait_std[chunk_nodata_mask] = args.nodata

            # Write trait estimate to file
            trait_geotiff = gdal.Open(dstFile, gdal.GA_Update)
            trait_geotiff.GetRasterBand(1).WriteArray(trait_mean, col_start, line_start)
            trait_geotiff.GetRasterBand(2).WriteArray(trait_std, col_start, line_start)
            trait_geotiff = None
Beispiel #6
0
def main():
    '''
    Perform in-memory trait estimation.

    Command-line entry point. Opens the image given by ``-img`` (HDF5 when
    the name ends in ".h5", ENVI otherwise), optionally builds an NDVI mask
    (``--mask`` / ``--mask_threshold``), optionally applies topographic
    (``--topo``) and BRDF (``--brdf``) corrections read from JSON coefficient
    files, resamples every chunk to the wavelength/FWHM pairs required by the
    trait model JSON files found in ``-coeffs`` and writes, per trait model,
    a two-band GeoTIFF into ``-od`` (band 1: mean, band 2: standard deviation
    as returned by ``apply_plsr_chunk``). ``--rgbim`` additionally exports an
    RGBI + mask GeoTIFF and ``--out`` the corrected image via ``writeENVI``.
    '''

    parser = argparse.ArgumentParser(description = "In memory trait mapping tool.")
    parser.add_argument("-img", help="Input image pathname",required=True, type = str)
    parser.add_argument("--obs", help="Input observables pathname", required=False, type = str)
    parser.add_argument("--out", help="Output full corrected image", required=False, type = str)
    parser.add_argument("-od", help="Output directory for all resulting products", required=True, type = str)
    parser.add_argument("--brdf", help="Perform BRDF correction",type = str, default = '')
    parser.add_argument("--topo", help="Perform topographic correction", type = str, default = '')
    parser.add_argument("--mask", help="Image mask type to use", action='store_true')
    parser.add_argument("--mask_threshold", help="Mask threshold value", type = float)
    parser.add_argument("--rgbim", help="Export RGBI +Mask image.", action='store_true')
    parser.add_argument("-coeffs", help="Trait coefficients directory", required=True, type = str)
    args = parser.parse_args()

    # The NDVI mask below compares against this threshold; without it the
    # comparison with None would fail only after the image has been loaded.
    if args.mask and args.mask_threshold is None:
        parser.error("--mask_threshold is required when --mask is set")

    traits = glob.glob("%s/*.json" % args.coeffs)

    # Load data objects into memory
    if args.img.endswith(".h5"):
        hyObj = ht.openHDF(args.img,load_obs = True)
    else:
        hyObj = ht.openENVI(args.img)
    # Observables are only loaded when a topographic or BRDF correction is requested
    if args.topo or args.brdf:
        hyObj.load_obs(args.obs)
    if not args.od.endswith("/"):
        args.od += "/"
    hyObj.create_bad_bands([[300,400],[1330,1430],[1800,1960],[2450,2600]])
    hyObj.load_data()

    # Generate mask: NDVI above threshold and not no-data
    if args.mask:
        ir = hyObj.get_wave(850)
        red = hyObj.get_wave(665)
        ndvi = (ir-red)/(ir+red)
        mask = (ndvi > args.mask_threshold) & (ir != hyObj.no_data)
        hyObj.mask = mask
        del ir,red,ndvi
    else:
        hyObj.mask = np.ones((hyObj.lines,hyObj.columns)).astype(bool)
        print("Warning no mask specified, results may be unreliable!")

    # Generate cosine i and c1 image for topographic correction
    if args.topo:
        with open(args.topo) as json_file:
            topo_coeffs = json.load(json_file)

        topo_coeffs['c'] = np.array(topo_coeffs['c'])
        cos_i = calc_cosine_i(hyObj.solar_zn, hyObj.solar_az, hyObj.azimuth , hyObj.slope)
        c1 = np.cos(hyObj.solar_zn) * np.cos( hyObj.slope)

    # Generate scattering kernel images for BRDF correction
    if args.brdf:
        with open(args.brdf) as json_file:
            brdf_coeffs = json.load(json_file)

        brdf_coeffs['fVol'] = np.array(brdf_coeffs['fVol'])
        brdf_coeffs['fGeo'] = np.array(brdf_coeffs['fGeo'])
        brdf_coeffs['fIso'] = np.array(brdf_coeffs['fIso'])

        k_vol = generate_volume_kernel(hyObj.solar_az,hyObj.solar_zn,hyObj.sensor_az,hyObj.sensor_zn, ross = brdf_coeffs['ross'])
        k_geom = generate_geom_kernel(hyObj.solar_az,hyObj.solar_zn,hyObj.sensor_az,hyObj.sensor_zn,li = brdf_coeffs['li'])
        k_vol_nadir = generate_volume_kernel(hyObj.solar_az,hyObj.solar_zn,hyObj.sensor_az,0, ross = brdf_coeffs['ross'])
        k_geom_nadir = generate_geom_kernel(hyObj.solar_az,hyObj.solar_zn,hyObj.sensor_az,0,li = brdf_coeffs['li'])

    #Cycle through the chunks and apply topo, brdf, vnorm,resampling and trait estimation steps
    print("Calculating values for %s traits....." % len(traits))

    # Collect the wavelength/FWHM pairs of every trait model so that a single
    # resampling matrix serves all models.
    trait_waves_all = []
    trait_fwhm_all = []
    # Wavelength scaler of each trait file, in the same order as ``traits``.
    # Kept per model so each model is later prepared with its OWN units
    # instead of the units of whichever file happened to be read last.
    trait_wave_scalers = []

    for trait in traits:
        with open(trait) as json_file:
            trait_model = json.load(json_file)

        # Micrometer wavelengths are scaled by 1000 (presumably to match the
        # image's nanometer wavelengths -- TODO confirm)
        if trait_model['wavelength_units'] == 'micrometers':
            trait_wave_scaler = 10**3
        else:
            trait_wave_scaler = 1
        trait_wave_scalers.append(trait_wave_scaler)

        # Vector-normalized models need every vnorm wavelength; otherwise only
        # the model wavelengths are required
        if len(trait_model['vector_norm_wavelengths']) == 0:
            trait_waves_all += list(np.array(trait_model['model_wavelengths'])*trait_wave_scaler)
        else:
            trait_waves_all += list(np.array(trait_model['vector_norm_wavelengths'])*trait_wave_scaler)

        trait_fwhm_all += list(np.array(trait_model['fwhm'])* trait_wave_scaler)

    # All unique (wavelength, fwhm) pairs, sorted by wavelength
    trait_waves_fwhm = sorted(set(zip(trait_waves_all,trait_fwhm_all)), key = lambda x: x[0])
    # Create a single set of resampling coefficients for all wavelength and fwhm combos
    resampling_coeffs = est_transform_matrix(hyObj.wavelengths[hyObj.bad_bands],[x for (x,y) in trait_waves_fwhm] ,hyObj.fwhm[hyObj.bad_bands],[y for (x,y) in trait_waves_fwhm],1)

    base_name = os.path.splitext(os.path.basename(args.img))[0]
    # Per-trait model parameters, filled while processing the first chunk
    trait_dict = {}
    pixels_processed = 0
    iterator = hyObj.iterate(by = 'chunk',chunk_size = (100,100))

    while not iterator.complete:
        chunk = iterator.read_next()
        chunk_nodata_mask = chunk[:,:,1] == hyObj.no_data
        pixels_processed += chunk.shape[0]*chunk.shape[1]
        progbar(pixels_processed, hyObj.columns*hyObj.lines, 100)

        # Chunk array indices
        line_start = iterator.current_line
        line_end = iterator.current_line + chunk.shape[0]
        col_start = iterator.current_column
        col_end = iterator.current_column + chunk.shape[1]
        # Output files are created (and models prepared) on the first chunk only
        first_chunk = line_start + col_start == 0

        # Apply TOPO correction (also drops the bad bands)
        if args.topo:
            cos_i_chunk = cos_i[line_start:line_end,col_start:col_end]
            c1_chunk = c1[line_start:line_end,col_start:col_end]
            correctionFactor = (c1_chunk[:,:,np.newaxis]+topo_coeffs['c'])/(cos_i_chunk[:,:,np.newaxis] + topo_coeffs['c'])
            chunk = chunk[:,:,hyObj.bad_bands]* correctionFactor
        else:
            chunk = chunk[:,:,hyObj.bad_bands] *1

        # Apply BRDF correction
        if args.brdf:
            # Get scattering kernel for chunks
            k_vol_nadir_chunk = k_vol_nadir[line_start:line_end,col_start:col_end]
            k_geom_nadir_chunk = k_geom_nadir[line_start:line_end,col_start:col_end]
            k_vol_chunk = k_vol[line_start:line_end,col_start:col_end]
            k_geom_chunk = k_geom[line_start:line_end,col_start:col_end]

            # eq 5. Weyermann et al. IEEE-TGARS 2015)
            brdf = np.einsum('i,jk-> jki', brdf_coeffs['fVol'],k_vol_chunk) + np.einsum('i,jk-> jki', brdf_coeffs['fGeo'],k_geom_chunk)  + brdf_coeffs['fIso']
            brdf_nadir = np.einsum('i,jk-> jki', brdf_coeffs['fVol'],k_vol_nadir_chunk) + np.einsum('i,jk-> jki', brdf_coeffs['fGeo'],k_geom_nadir_chunk)  +brdf_coeffs['fIso']
            correctionFactor = brdf_nadir/brdf
            chunk = chunk* correctionFactor

        # Reassign no data values
        chunk[chunk_nodata_mask,:] = 0

        # Resample chunk onto the union of trait wavelengths
        chunk_r = np.dot(chunk, resampling_coeffs)

        # Export RGBIM image
        if args.rgbim:
            dstFile = args.od + base_name + '_rgbim.tif'
            if first_chunk:
                driver = gdal.GetDriverByName("GTIFF")
                tiff = driver.Create(dstFile,hyObj.columns,hyObj.lines,5,gdal.GDT_Float32)
                tiff.SetGeoTransform(hyObj.transform)
                tiff.SetProjection(hyObj.projection)
                for band in range(1,6):
                    tiff.GetRasterBand(band).SetNoDataValue(0)
                tiff.GetRasterBand(5).WriteArray(hyObj.mask)
                del tiff,driver
            # Write rgbi chunk
            # NOTE(review): ``chunk`` has already been reduced to
            # ``hyObj.bad_bands`` above; confirm that wave_to_band returns an
            # index into that reduced band set and not into the full image.
            rgbi_geotiff = gdal.Open(dstFile, gdal.GA_Update)
            for out_band, wave in enumerate([480,560,660,850],start=1):
                band = hyObj.wave_to_band(wave)
                rgbi_geotiff.GetRasterBand(out_band).WriteArray(chunk[:,:,band], col_start, line_start)
            rgbi_geotiff = None

        # Export BRDF and topo corrected image
        if args.out:
            if first_chunk:
                output_name = args.od + base_name + "_topo_brdf"
                header_dict = hyObj.header_dict
                # Update header to describe only the retained bands
                header_dict['wavelength']= header_dict['wavelength'][hyObj.bad_bands]
                header_dict['fwhm'] = header_dict['fwhm'][hyObj.bad_bands]
                header_dict['bbl'] = header_dict['bbl'][hyObj.bad_bands]
                header_dict['bands'] = hyObj.bad_bands.sum()
                writer = writeENVI(output_name,header_dict)
            writer.write_chunk(chunk,iterator.current_line,iterator.current_column)
            if iterator.complete:
                writer.close()

        for i,trait in enumerate(traits):
            dstFile = args.od + base_name +"_" +os.path.splitext(os.path.basename(trait))[0] +".tif"

            # Trait estimation preparation
            if first_chunk:

                with open(trait) as json_file:
                    trait_model = json.load(json_file)

                # Use this model's own wavelength units
                trait_wave_scaler = trait_wave_scalers[i]

                intercept = np.array(trait_model['intercept'])
                coefficients = np.array(trait_model['coefficients'])
                transform = trait_model['transform']

                # Get list of wavelengths to compare against image wavelengths
                if len(trait_model['vector_norm_wavelengths']) == 0:
                    dst_waves = np.array(trait_model['model_wavelengths'])*trait_wave_scaler
                else:
                    dst_waves = np.array(trait_model['vector_norm_wavelengths'])*trait_wave_scaler

                dst_fwhm = np.array(trait_model['fwhm'])* trait_wave_scaler
                model_waves = np.array(trait_model['model_wavelengths'])* trait_wave_scaler
                wave_to_fwhm = dict(zip(dst_waves, dst_fwhm))
                model_fwhm = [wave_to_fwhm[x] for x in model_waves]

                # Boolean masks selecting this model's bands from the resampled chunk
                vnorm_band_mask = [x in zip(dst_waves,dst_fwhm) for x in trait_waves_fwhm]
                model_band_mask = [x in zip(model_waves,model_fwhm) for x in trait_waves_fwhm]

                if trait_model['vector_norm']:
                    vnorm_scaler = trait_model["vector_scaler"]
                else:
                    vnorm_scaler = None

                trait_dict[i] = [coefficients,intercept,trait_model['vector_norm'],vnorm_scaler,vnorm_band_mask,model_band_mask,transform]

                # Create the two-band output GeoTIFF for this trait
                driver = gdal.GetDriverByName("GTIFF")
                tiff = driver.Create(dstFile,hyObj.columns,hyObj.lines,2,gdal.GDT_Float32)
                tiff.SetGeoTransform(hyObj.transform)
                tiff.SetProjection(hyObj.projection)
                tiff.GetRasterBand(1).SetNoDataValue(0)
                tiff.GetRasterBand(2).SetNoDataValue(0)
                del tiff,driver

            coefficients,intercept,vnorm,vnorm_scaler,vnorm_band_mask,model_band_mask,transform = trait_dict[i]

            # Work on a copy so every trait starts from the same resampled chunk
            chunk_t = np.copy(chunk_r)

            if vnorm:
                chunk_t[:,:,vnorm_band_mask] = vector_normalize_chunk(chunk_t[:,:,vnorm_band_mask],vnorm_scaler)

            if transform == "log(1/R)":
                chunk_t[:,:,model_band_mask] = np.log(1/chunk_t[:,:,model_band_mask] )

            trait_mean,trait_std = apply_plsr_chunk(chunk_t[:,:,model_band_mask],coefficients,intercept)

            # Change no data pixel values
            trait_mean[chunk_nodata_mask] = 0
            trait_std[chunk_nodata_mask] = 0

            # Write trait estimate to file
            trait_geotiff = gdal.Open(dstFile, gdal.GA_Update)
            trait_geotiff.GetRasterBand(1).WriteArray(trait_mean, col_start, line_start)
            trait_geotiff.GetRasterBand(2).WriteArray(trait_std, col_start, line_start)
            trait_geotiff = None