Example #1
 def __init__(self, source, target=None, varlist=None, ignorelist=None, tmp=True, feedback=True):
   ''' Initialize processor and pass input and output datasets. '''
   # check varlist
   if varlist is None: varlist = source.variables.keys() # all source variables
   elif not isinstance(varlist,(list,tuple)): raise TypeError
   self.varlist = varlist # list of variables to be processed
   # ignore list (e.g. variables that will cause errors)
   if ignorelist is None: ignorelist = [] # an empty list
   elif not isinstance(ignorelist,(list,tuple)): raise TypeError
   self.ignorelist = ignorelist # list of variables *not* to be processed
   # check input
   if not isinstance(source,Dataset): raise TypeError
   if isinstance(source,DatasetNetCDF) and not 'r' in source.mode: raise PermissionError
   self.input = source
   self.source = source
   # check output
   if target is not None:       
     if not isinstance(target,Dataset): raise TypeError
     if isinstance(target,DatasetNetCDF) and not 'w' in target.mode: raise PermissionError
   else:
     if not tmp: raise DatasetError, "Need a target location if temporary storage is disabled (tmp=False)."
   self.output = target
   # temporary dataset
   self.tmp = tmp
   if tmp: self.tmpput = Dataset(name='tmp', title='Temporary Dataset', varlist=[], atts={})
   else: self.tmpput = None
   # determine if temporary storage is used and assign target dataset
   if self.tmp: self.target = self.tmpput
   else: self.target = self.output 
   # whether or not to print status output
   self.feedback = feedback
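
A minimal usage sketch for the constructor above; the module path, file names, and variable names are hypothetical placeholders, and DatasetNetCDF is assumed to live in geodata.netcdf next to writeNetCDF (see example #6). Only keyword arguments that appear on this page are used.

# hypothetical paths and import location
from geodata.netcdf import DatasetNetCDF

source = DatasetNetCDF(folder='/data/wrf/', filelist=['wrfavg_d01_clim.nc'], mode='r')
CPU = CentralProcessingUnit(source, target=None,        # no output dataset attached yet
                            varlist=['precip', 'T2'],   # only process these variables
                            tmp=True, feedback=True)    # buffer results in memory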
Example #2
def performExtraction(dataset,
                      mode,
                      stnfct,
                      dataargs,
                      loverwrite=False,
                      varlist=None,
                      lwrite=True,
                      lreturn=False,
                      ldebug=False,
                      lparallel=False,
                      pidstr='',
                      logger=None):
    ''' worker function to extract point data from gridded dataset '''
    # input checking
    if not isinstance(dataset, basestring): raise TypeError
    if not isinstance(dataargs, dict):
        raise TypeError  # all dataset arguments are kwargs
    if not callable(stnfct):
        raise TypeError  # function to load station dataset
    if lparallel:
        if not lwrite:
            raise IOError, 'In parallel mode we can only write to disk (i.e. lwrite = True).'
        if lreturn:
            raise IOError, 'Can not return datasets in parallel mode (i.e. lreturn = False).'

    # logging
    if logger is None:  # make new logger
        logger = logging.getLogger()  # new logger
        logger.addHandler(logging.StreamHandler())
    else:
        if isinstance(logger, basestring):
            logger = logging.getLogger(name=logger)  # connect to existing one
        elif not isinstance(logger, logging.Logger):
            raise TypeError, 'Expected logger ID/handle in logger KW; got {}'.format(
                str(logger))

    lclim = False
    lts = False
    if mode == 'climatology': lclim = True
    elif mode == 'time-series': lts = True
    else: raise NotImplementedError, "Unrecognized Mode: '{:s}'".format(mode)

    ## extract meta data from arguments
    dataargs, loadfct, srcage, datamsgstr = getMetaData(
        dataset, mode, dataargs)
    dataset_name = dataargs.dataset_name
    periodstr = dataargs.periodstr
    avgfolder = dataargs.avgfolder

    # load template dataset
    stndata = stnfct()  # load station dataset from function
    if not isinstance(stndata, Dataset): raise TypeError
    # N.B.: the loading function is necessary, because DatasetNetCDF instances do not pickle well

    # get filename for target dataset and do some checks
    filename = getTargetFile(dataset=dataset,
                             mode=mode,
                             dataargs=dataargs,
                             lwrite=lwrite,
                             station=stndata.name)

    if ldebug: filename = 'test_' + filename
    if not os.path.exists(avgfolder):
        raise IOError, "Dataset folder '{:s}' does not exist!".format(
            avgfolder)
    lskip = False  # else just go ahead
    if lwrite:
        if lreturn:
            tmpfilename = filename  # no temporary file if dataset is passed on (can't rename the file while it is open!)
        else:
            if lparallel: tmppfx = 'tmp_exstns_{:s}_'.format(pidstr[1:-1])
            else: tmppfx = 'tmp_exstns_'
            tmpfilename = tmppfx + filename
        filepath = avgfolder + filename
        tmpfilepath = avgfolder + tmpfilename
        if os.path.exists(filepath):
            if not loverwrite:
                age = datetime.fromtimestamp(os.path.getmtime(filepath))
                # if source file is newer than sink file or if sink file is a stub, recompute, otherwise skip
                if age > srcage and os.path.getsize(filepath) > 1e5:
                    lskip = True
                # N.B.: NetCDF files smaller than 100kB are usually incomplete header fragments from a previously crashed process

    # depending on last modification time of file or overwrite setting, start computation, or skip
    if lskip:
        # print message
        skipmsg = "\n{:s}   >>>   Skipping: file '{:s}' in dataset '{:s}' already exists and is newer than source file.".format(
            pidstr, filename, dataset_name)
        skipmsg += "\n{:s}   >>>   ('{:s}')\n".format(pidstr, filepath)
        logger.info(skipmsg)
    else:

        ## actually load datasets
        source = loadfct()  # load source
        # check period
        if 'period' in source.atts and dataargs.periodstr != source.atts.period:  # a NetCDF attribute
            raise DateError, "Specified period is inconsistent with netcdf records: '{:s}' != '{:s}'".format(
                periodstr, source.atts.period)

        # print message
        if lclim:
            opmsgstr = "Extracting '{:s}'-type Point Data from Climatology ({:s})".format(
                stndata.name, periodstr)
        elif lts:
            opmsgstr = "Extracting '{:s}'-type Point Data from Time-series".format(
                stndata.name)
        else:
            raise NotImplementedError, "Unrecognized Mode: '{:s}'".format(mode)
        # print feedback to logger
        logger.info(
            '\n{0:s}   ***   {1:^65s}   ***   \n{0:s}   ***   {2:^65s}   ***   \n'
            .format(pidstr, datamsgstr, opmsgstr))
        if not lparallel and ldebug: logger.info('\n' + str(source) + '\n')

        ## create new sink/target file
        # set attributes
        atts = source.atts.copy()
        atts[
            'period'] = dataargs.periodstr if dataargs.periodstr else 'time-series'
        atts['name'] = dataset_name
        atts['station'] = stndata.name
        atts['title'] = '{:s} (Stations) from {:s} {:s}'.format(
            stndata.title, dataset_name, mode.title())
        # make new dataset
        if lwrite:  # write to NetCDF file
            if os.path.exists(tmpfilepath):
                os.remove(tmpfilepath)  # remove old temp files
            sink = DatasetNetCDF(folder=avgfolder,
                                 filelist=[tmpfilename],
                                 atts=atts,
                                 mode='w')
        else:
            sink = Dataset(atts=atts)  # only create dataset in memory

        # initialize processing
        CPU = CentralProcessingUnit(source,
                                    sink,
                                    varlist=varlist,
                                    tmp=False,
                                    feedback=ldebug)

        # extract data at station locations
        CPU.Extract(template=stndata, flush=True)
        # get results
        CPU.sync(flush=True)

        # print dataset
        if not lparallel and ldebug:
            logger.info('\n' + str(sink) + '\n')
        # write results to file
        if lwrite:
            sink.sync()
            writemsg = "\n{:s}   >>>   Writing to file '{:s}' in dataset {:s}".format(
                pidstr, filename, dataset_name)
            writemsg += "\n{:s}   >>>   ('{:s}')\n".format(pidstr, filepath)
            logger.info(writemsg)

            # rename file to proper name
            if not lreturn:
                sink.unload()
                sink.close()
                del sink  # destroy all references
                if os.path.exists(filepath):
                    os.remove(filepath)  # remove old file
                os.rename(tmpfilepath, filepath)
            # N.B.: there is no temporary file if the dataset is returned, because an open file can't be renamed

        # clean up and return
        source.unload()
        del source  #, CPU
        if lreturn:
            return sink  # return dataset for further use (netcdf file still open!)
        else:
            return 0  # "exit code"
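
As the note above explains, stnfct has to be a zero-argument callable because open NetCDF datasets do not pickle well in parallel mode. A hedged sketch of how a caller might bind a station loader and invoke the worker; the loader name and the dataargs keys are hypothetical.

import functools

# hypothetical station loader; any callable that returns a Dataset will do
stnfct = functools.partial(loadStationDataset, name='ecprecip')
sink = performExtraction(dataset='WRF', mode='climatology', stnfct=stnfct,
                         dataargs=dict(experiment='max-ctrl', period=15),  # assumed keys
                         loverwrite=False, lwrite=True, lreturn=True, ldebug=True)
# with lreturn=True the extracted station dataset is returned (NetCDF file still open)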
Example #3
def loadKister_StnTS(station=None,
                     well=None,
                     folder=None,
                     varlist='default',
                     varatts=None,
                     name='observations',
                     title=None,
                     basin=None,
                     start_date=None,
                     end_date=None,
                     sampling=None,
                     period=None,
                     date_range=None,
                     llastIncl=True,
                     WSC_station=None,
                     basin_list=None,
                     filenames=None,
                     time_axis='datetime',
                     scalefactors=None,
                     metadata=None,
                     lkgs=False,
                     ntime=None,
                     **kwargs):
    ''' load Kister station and well time-series data as a formatted GeoPy Dataset '''
    if folder and not os.path.exists(folder): raise IOError(folder)
    # default values
    if isinstance(varlist, str) and varlist == 'default':
        varlist = []
        if station: varlist += ['discharge']
        if well: varlist += ['head']
    if varatts is None: varatts = variable_attributes.copy()
    # figure out time axis
    if date_range: start_date, end_date, sampling = date_range
    time = timeAxis(start_date=start_date,
                    end_date=end_date,
                    sampling=sampling,
                    date_range=date_range,
                    time_axis=time_axis,
                    llastIncl=llastIncl,
                    ntime=ntime,
                    varatts=varatts)
    ntime = len(time)
    # load WSC station meta data
    pass
    # initialize Dataset
    dataset = Dataset(name=name,
                      title=title if title else name.title(),
                      atts=metadata)
    # load well data
    if 'head' in varlist:
        if not well: raise ArgumentError
        if folder:
            filepath = os.path.join(folder, well)  # default output folder
        else:
            filepath = well
        data = readKister(filepath=filepath,
                          period=(start_date, end_date),
                          resample=sampling,
                          lvalues=True)
        assert ntime == len(data), data.shape
        atts = varatts['head']
        dataset += Variable(atts=atts, data=data, axes=(time, ))
    # load discharge/hydrograph data
    if 'discharge' in varlist:
        if not station: raise ArgumentError
        if folder:
            filepath = os.path.join(folder, station)  # default output folder
        else:
            filepath = station
        data = readKister(filepath=filepath,
                          period=(start_date, end_date),
                          resample=sampling,
                          lvalues=True)
        assert ntime == len(data), data.shape
        atts = varatts['discharge']
        if lkgs:
            data *= 1000.
            if atts['units'] == 'm^3/s': atts['units'] = 'kg/s'
        dataset += Variable(atts=atts, data=data, axes=(time, ))
    # return formatted Dataset
    if scalefactors is not None and scalefactors != 1:
        raise NotImplementedError
    return dataset
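
A sketch of a typical call, assuming hypothetical file names and folder and a pandas-style sampling frequency; date_range is unpacked into start_date, end_date and sampling as shown above.

# file names, folder and sampling frequency are placeholders
ds = loadKister_StnTS(station='station_discharge.csv', well='W0000001.csv',
                      folder='/data/kister/', varlist='default',      # -> ['discharge', 'head']
                      date_range=('2010-01-01', '2017-12-31', '1D'),  # start, end, sampling
                      name='obs', lkgs=True)                          # convert m^3/s to kg/s
print(ds)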
Example #4
def loadEnKF_StnTS(folder=None,
                   varlist='all',
                   varatts=None,
                   name='enkf',
                   title='EnKF',
                   basin=None,
                   start_date=None,
                   end_date=None,
                   sampling=None,
                   period=None,
                   date_range=None,
                   llastIncl=True,
                   WSC_station=None,
                   basin_list=None,
                   filenames=None,
                   prefix=None,
                   time_axis='datetime',
                   scalefactors=None,
                   metadata=None,
                   lkgs=False,
                   out_dir='out/',
                   yaml_file='../input_data/obs_meta.yaml',
                   lYAML=True,
                   nreal=None,
                   ntime=None,
                   **kwargs):
    ''' load EnKF ensemble data as formatted GeoPy Dataset '''
    out_folder = os.path.join(folder, 'out/')  # default output folder
    if not os.path.exists(out_folder): raise IOError(out_folder)
    # default values
    if isinstance(varlist, str) and varlist == 'hydro':
        varlist = Hydro.varlist
    elif isinstance(varlist, str) and varlist == 'obs':
        varlist = Obs.varlist
    elif isinstance(varlist, str) and varlist == 'all':
        varlist = Hydro.varlist + Obs.varlist
    elif not isinstance(varlist, (tuple, list)):
        raise TypeError(varlist)
    if varatts is None: varatts = variable_attributes.copy()
    varmap = {
        varatt['name']: enkf_name
        for enkf_name, varatt in list(varatts.items())
    }
    varlist = [varmap[var] for var in varlist]
    # load WSC station meta data
    pass
    # initialize Dataset
    dataset = Dataset(name=name,
                      title=title if title else name.title(),
                      atts=metadata)
    ensemble = None
    time = None
    observation = None
    # load observation/innovation data
    if any([var in Obs.atts for var in varlist]):
        # load data
        vardata = loadObs(varlist=[var for var in varlist if var in Obs.atts],
                          folder=out_folder,
                          lpandas=False)
        ntime, nobs, nreal = list(vardata.values())[0].shape
        # create Axes
        if time is None:
            # figure out time axis
            time = timeAxis(start_date=start_date,
                            end_date=end_date,
                            sampling=sampling,
                            date_range=date_range,
                            time_axis=time_axis,
                            llastIncl=llastIncl,
                            ntime=ntime,
                            varatts=varatts)
        elif len(time) != ntime:
            raise AxisError(time)
        if ensemble is None:
            # construct ensemble axis
            ensemble = Axis(atts=varatts['ensemble'],
                            coord=np.arange(1, nreal + 1))
        elif len(ensemble) != nreal:
            raise AxisError(ensemble)
        if observation is None:
            # construct observation axis
            observation = Axis(atts=varatts['observation'],
                               coord=np.arange(1, nobs + 1))
        elif len(observation) != nobs:
            raise AxisError(observation)
        # create variables
        for varname, data in list(vardata.items()):
            dataset += Variable(atts=varatts[varname],
                                data=data,
                                axes=(time, observation, ensemble))
        # load YAML data, if available
        if lYAML:
            # load YAML file
            yaml_path = os.path.join(out_folder, yaml_file)
            if not os.path.exists(yaml_path): raise IOError(yaml_path)
            with open(yaml_path, 'r') as yf:
                obs_meta = yaml.load(yf)
            if obs_meta is None: raise IOError(yaml_path)  # not a YAML file?
            # create variables for constant (per-observation) meta data
            for cvar, cval in list(obs_meta[0].items()):
                if isinstance(cval, str): dtype, missing = np.string_, ''
                elif isinstance(cval, (np.integer, int)):
                    dtype, missing = np.int_, 0
                elif isinstance(cval, (np.inexact, float)):
                    dtype, missing = np.float_, np.NaN
                else:
                    dtype = None  # skip
                if dtype:
                    data = np.asarray([
                        missing if obs[cvar] is None else obs[cvar]
                        for obs in obs_meta
                    ],
                                      dtype=dtype)
                    if cvar in varatts: atts = varatts[cvar]
                    else: atts = dict(name=cvar, units='')
                    dataset += Variable(atts=atts,
                                        data=data,
                                        axes=(observation, ))
    elif ntime is None:
        # try to infer time dimension from backup.info file
        backup_info = os.path.join(folder, 'backup.info')
        if os.path.exists(backup_info):
            with open(backup_info, 'r') as bf:
                ntime = int(bf.readline())
    # load discharge/hydrograph data
    if 'discharge' in varlist:
        data = loadHydro(folder=out_folder, nreal=nreal, ntime=ntime)
        ntime, nreal = data.shape
        if time is None:
            # figure out time axis
            time = timeAxis(start_date=start_date,
                            end_date=end_date,
                            sampling=sampling,
                            date_range=date_range,
                            time_axis=time_axis,
                            llastIncl=llastIncl,
                            ntime=ntime,
                            varatts=varatts)
        elif len(time) != ntime:
            raise AxisError(time)
        if ensemble is None:
            # construct ensemble axis
            ensemble = Axis(atts=varatts['ensemble'],
                            coord=np.arange(1, nreal + 1))
        elif len(ensemble) != nreal:
            raise AxisError(ensemble)
        atts = varatts['discharge']
        if lkgs:
            data *= 1000.
            if atts['units'] == 'm^3/s': atts['units'] = 'kg/s'
        dataset += Variable(atts=atts, data=data, axes=(time, ensemble))
    # return formatted Dataset
    if scalefactors is not None and scalefactors != 1:
        raise NotImplementedError
    return dataset
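
A sketch of how the EnKF loader might be called; the folder layout and dates are hypothetical, and the 'all' shortcut expands to Hydro.varlist + Obs.varlist as shown above.

# the folder must contain an 'out/' subfolder with the ensemble output (placeholder path)
enkf = loadEnKF_StnTS(folder='/data/enkf_run/', varlist='all',
                      start_date='2017-05-01', sampling='1D',  # assumed daily sampling
                      lYAML=True,                              # read per-observation meta data
                      lkgs=False)
# observation variables are dimensioned (time, observation, ensemble),
# discharge is dimensioned (time, ensemble)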
Example #5
def performExport(dataset, mode, dataargs, expargs, bcargs, loverwrite=False, 
                  ldebug=False, lparallel=False, pidstr='', logger=None):
    ''' worker function to export ASCII rasters for a given dataset '''
    # input checking
    if not isinstance(dataset,basestring): raise TypeError
    if not isinstance(dataargs,dict): raise TypeError # all dataset arguments are kwargs 
    
    # logging
    if logger is None: # make new logger     
        logger = logging.getLogger() # new logger
        logger.addHandler(logging.StreamHandler())
    else:
        if isinstance(logger,basestring): 
            logger = logging.getLogger(name=logger) # connect to existing one
        elif not isinstance(logger,logging.Logger): 
            raise TypeError, 'Expected logger ID/handle in logger KW; got {}'.format(str(logger))
  
    ## extract meta data from arguments
    dataargs, loadfct, srcage, datamsgstr = getMetaData(dataset, mode, dataargs, lone=False)
    dataset_name = dataargs.dataset_name; periodstr = dataargs.periodstr; domain = dataargs.domain
    
    # figure out bias correction parameters
    if bcargs:
        bcargs = bcargs.copy() # first copy, then modify...
        bc_method = bcargs.pop('method',None)
        if bc_method is None: raise ArgumentError("Need to specify bias-correction method to use bias correction!")
        bc_obs = bcargs.pop('obs_dataset',None)
        if bc_obs is None: raise ArgumentError("Need to specify observational dataset to use bias correction!")
        bc_reference = bcargs.pop('reference',None)
        if bc_reference is None: # infer from experiment name
            if dataset_name[-5:] in ('-2050','-2100'): bc_reference = dataset_name[:-5] # cut off the period indicator and hope for the best
            else: bc_reference = dataset_name 
        bc_grid = bcargs.pop('grid',None)
        if bc_grid is None: bc_grid = dataargs.grid
        bc_domain = bcargs.pop('domain',None)
        if bc_domain is None: bc_domain = domain
        bc_varlist = bcargs.pop('varlist',None)
        bc_varmap = bcargs.pop('varmap',None)       
        bc_tag = bcargs.pop('tag',None) # an optional name extension/tag
        bc_pattern = bcargs.pop('file_pattern',None) # usually default in getPickleFile
        lgzip = bcargs.pop('lgzip',None) # if pickle is gzipped (None: auto-detect based on file name extension)
        # get name of pickle file (and folder)
        picklefolder = dataargs.avgfolder.replace(dataset_name,bc_reference)
        picklefile = getPickleFileName(method=bc_method, obs_name=bc_obs, gridstr=bc_grid, domain=bc_domain, 
                                       tag=bc_tag, pattern=bc_pattern)
        picklepath = '{:s}/{:s}'.format(picklefolder,picklefile)
        if lgzip:
            picklepath += '.gz' # add extension
            if not os.path.exists(picklepath): raise IOError(picklepath)
        elif lgzip is None:
            lgzip = False
            if not os.path.exists(picklepath):
                lgzip = True # assume gzipped file
                picklepath += '.gz' # try with extension...
                if not os.path.exists(picklepath): raise IOError(picklepath)
        elif not os.path.exists(picklepath): raise IOError(picklepath)
        pickleage = datetime.fromtimestamp(os.path.getmtime(picklepath))
        # determine age of pickle file and compare against source age
    else:
      bc_method = False 
      pickleage = srcage
    
    # parse export options
    expargs = expargs.copy() # first copy, then modify...
    lm3 = expargs.pop('lm3') # convert kg/m^2/s to m^3/m^2/s (water flux)
    expformat = expargs.pop('format') # needed to get FileFormat object
    exp_list = expargs.pop('exp_list') # this is handled outside of export
    compute_list = expargs.pop('compute_list', []) # variables to be (re-)computed - by default all
    # initialize FileFormat class instance
    fileFormat = getFileFormat(expformat, bc_method=bc_method, **expargs)
    # get folder for target dataset and do some checks
    expname = '{:s}_d{:02d}'.format(dataset_name,domain) if domain else dataset_name
    expfolder = fileFormat.defineDataset(dataset=dataset, mode=mode, dataargs=dataargs, lwrite=True, ldebug=ldebug)
  
    # prepare destination for new dataset
    lskip = fileFormat.prepareDestination(srcage=max(srcage,pickleage), loverwrite=loverwrite)
  
    # depending on last modification time of file or overwrite setting, start computation, or skip
    if lskip:        
        # print message
        skipmsg =  "\n{:s}   >>>   Skipping: Format '{:s}' for dataset '{:s}' already exists and is newer than source file.".format(pidstr,expformat,dataset_name)
        skipmsg += "\n{:s}   >>>   ('{:s}')\n".format(pidstr,expfolder)
        logger.info(skipmsg)              
    else:
            
      ## actually load datasets
      source = loadfct() # load source data
      # check period
      if 'period' in source.atts and dataargs.periodstr != source.atts.period: # a NetCDF attribute
          raise DateError, "Specified period is inconsistent with netcdf records: '{:s}' != '{:s}'".format(periodstr,source.atts.period)
      
      # load BiasCorrection object from pickle
      if bc_method:      
          op = gzip.open if lgzip else open
          with op(picklepath, 'r') as filehandle:
              BC = pickle.load(filehandle) 
          # assemble logger entry
          bcmsgstr = "(performing bias-correction using {:s} from {:s} towards {:s})".format(BC.long_name,bc_reference,bc_obs)
      
      # print message
      if mode == 'climatology': opmsgstr = 'Exporting Climatology ({:s}) to {:s} Format'.format(periodstr, expformat)
      elif mode == 'time-series': opmsgstr = 'Exporting Time-series to {:s} Format'.format(expformat)
      elif mode[-5:] == '-mean': opmsgstr = 'Exporting {:s}-Mean ({:s}) to {:s} Format'.format(mode[:-5], periodstr, expformat)
      else: raise NotImplementedError, "Unrecognized Mode: '{:s}'".format(mode)        
      # print feedback to logger
      logmsg = '\n{0:s}   ***   {1:^65s}   ***   \n{0:s}   ***   {2:^65s}   ***   \n'.format(pidstr,datamsgstr,opmsgstr)
      if bc_method:
          logmsg += "{0:s}   ***   {1:^65s}   ***   \n".format(pidstr,bcmsgstr)
      logger.info(logmsg)
      if not lparallel and ldebug: logger.info('\n'+str(source)+'\n')
      
      # create GDAL-enabled target dataset
      sink = Dataset(axes=(source.xlon,source.ylat), name=expname, title=source.title, atts=source.atts.copy())
      addGDALtoDataset(dataset=sink, griddef=source.griddef)
      assert sink.gdal, sink
      
      # apply bias-correction
      if bc_method:
          source = BC.correct(source, asNC=False, varlist=bc_varlist, varmap=bc_varmap) # load bias-corrected variables into memory
        
      # N.B.: for variables that are not bias-corrected, data are not loaded immediately but on demand; this way 
      #       I/O and computing can be further disentangled and not all variables are always needed
      
      # compute intermediate variables, if necessary
      for varname in exp_list:
          variables = None # variable list
          var = None
          # (re-)compute variable, if desired...
          if varname in compute_list:
              if varname == 'precip': var = newvars.computeTotalPrecip(source)
              elif varname == 'waterflx': var = newvars.computeWaterFlux(source)
              elif varname == 'liqwatflx': var = newvars.computeLiquidWaterFlux(source)
              elif varname == 'netrad': var = newvars.computeNetRadiation(source, asVar=True)
              elif varname == 'netrad_bb': var = newvars.computeNetRadiation(source, asVar=True, lrad=False, name='netrad_bb')
              elif varname == 'netrad_bb0': var = newvars.computeNetRadiation(source, asVar=True, lrad=False, lA=False, name='netrad_bb0')
              elif varname == 'vapdef': var = newvars.computeVaporDeficit(source)
              elif varname in ('pet','pet_pm','petrad','petwnd') and 'pet' not in sink:
                  if 'petrad' in exp_list or 'petwnd' in exp_list:
                      variables = newvars.computePotEvapPM(source, lterms=True) # default; returns multiple PET terms
                  else: var = newvars.computePotEvapPM(source, lterms=False) # returns only PET
              elif varname == 'pet_th': var = None # skip for now
                  #var = computePotEvapTh(source) # simplified formula (less prerequisites)
          # ... otherwise load from source file
          if var is None and variables is None and varname in source:
              var = source[varname].load() # load data (may not have to load all)
          #else: raise VariableError, "Unsupported Variable '{:s}'.".format(varname)
          # for now, skip variables that are None
          if var or variables:
              # handle lists as well
              if var and variables: raise VariableError, (var,variables)
              elif var: variables = (var,)
              for var in variables:
                  addGDALtoVar(var=var, griddef=sink.griddef)
                  if not var.gdal and isinstance(fileFormat,ASCII_raster):
                      raise GDALError, "Exporting to ASCII_raster format requires GDAL-enabled variables."
                  # add to new dataset
                  sink += var
      # convert units
      if lm3:
          for var in sink:
              if var.units == 'kg/m^2/s':
                  var /= 1000. # divide to get m^3/m^2/s
                  var.units = 'm^3/m^2/s' # update units
      
      # compute seasonal mean if we are in mean-mode
      if mode[-5:] == '-mean': 
          sink = sink.seasonalMean(season=mode[:-5], lclim=True)
          # N.B.: to remain consistent with other output modes, 
          #       we need to prevent renaming of the time axis
          sink = concatDatasets([sink,sink], axis='time', lensembleAxis=True)
          sink.squeeze() # we need the year-axis until now to distinguish constant fields; now remove
      
      # print dataset
      if not lparallel and ldebug:
          logger.info('\n'+str(sink)+'\n')
        
      # export new dataset to selected format
      fileFormat.exportDataset(sink)
        
      # write results to file
      writemsg =  "\n{:s}   >>>   Export of Dataset '{:s}' to Format '{:s}' complete.".format(pidstr,expname, expformat)
      writemsg += "\n{:s}   >>>   ('{:s}')\n".format(pidstr,expfolder)
      logger.info(writemsg)      
         
      # clean up and return
      source.unload(); #del source
      return 0 # "exit code"
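
The export and bias-correction options arrive as plain dictionaries whose keys are popped one by one above; a hedged sketch of what a caller might assemble, with all names and values hypothetical.

# hypothetical argument dictionaries, mirroring the keys popped in performExport
dataargs = dict(experiment='max-ctrl', period=15, domains=2)      # consumed by getMetaData
expargs  = dict(format='ASCII_raster', lm3=True,
                exp_list=['precip', 'pet', 'liqwatflx'],          # variables to export
                compute_list=['pet', 'liqwatflx'])                # variables to (re-)compute
bcargs   = dict(method='AABC', obs_dataset='NRCan',               # assumed method/obs names
                lgzip=None)                                       # auto-detect gzipped pickle
ec = performExport(dataset='WRF', mode='climatology', dataargs=dataargs,
                   expargs=expargs, bcargs=bcargs, loverwrite=False, ldebug=True)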
Example #6
 # imports
 import os
 import numpy as np
 from glob import glob
 from geodata.base import Dataset, Axis, Variable
 from geodata.netcdf import writeNetCDF
 
 
 # load list of well files and generate list of wells
 well_files = glob(os.path.join(data_folder,'W*.xlsx'))
 well_files.sort()
 wells = [os.path.basename(name[:-5]) for name in well_files]
 print(wells)
 
 # dataset
 time_ax = Axis(coord=np.arange(12*(period[1]-period[0]))+252, **varatts['time']) # origin: 1979-01
 well_ax = Axis(coord=np.arange(len(wells))+1, name='well', units='') 
 dataset = Dataset(name=conservation_authority, title=conservation_authority+' Observation Wells')
 # add meta data
 meta_dicts = [loadMetadata(well, conservation_authority=conservation_authority) for well in wells]
 for key in meta_dicts[0].keys():
     if key in varatts: atts = varatts[key]
     elif key.lower() in varatts: atts = varatts[key.lower()]
     else: atts = dict(name=key, units='')
     if atts['units']: data = np.asarray([wmd[key] for wmd in meta_dicts], dtype=np.float64)
     else: data = np.asarray([wmd[key] for wmd in meta_dicts])
     try: 
       dataset += Variable(data=data, axes=(well_ax,), **atts)
     except Exception:
       pass # skip meta data that cannot be converted to a Variable
 # add names
 dataset += Variable(data=wells, axes=(well_ax,), name='well_name', units='', 
                     atts=dict(long_name='Short Well Name'))
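
A possible continuation of this script (the output file name is hypothetical): once the per-well time series have been added along time_ax, the dataset can be written out with the writeNetCDF helper that is already imported above.

 # possible continuation (output file name is a placeholder)
 nc_path = os.path.join(data_folder, conservation_authority + '_wells_monthly.nc')
 writeNetCDF(dataset, nc_path, ncformat='NETCDF4', zlib=True, writeData=True, close=True)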
Example #7
def performExport(dataset, mode, dataargs, expargs, loverwrite=False, 
                  ldebug=False, lparallel=False, pidstr='', logger=None):
  ''' worker function to export data from a given dataset to the selected file format '''
  # input checking
  if not isinstance(dataset,basestring): raise TypeError
  if not isinstance(dataargs,dict): raise TypeError # all dataset arguments are kwargs 
  
  # logging
  if logger is None: # make new logger     
    logger = logging.getLogger() # new logger
    logger.addHandler(logging.StreamHandler())
  else:
    if isinstance(logger,basestring): 
      logger = logging.getLogger(name=logger) # connect to existing one
    elif not isinstance(logger,logging.Logger): 
      raise TypeError, 'Expected logger ID/handle in logger KW; got {}'.format(str(logger))

  ## extract meta data from arguments
  dataargs, loadfct, srcage, datamsgstr = getMetaData(dataset, mode, dataargs, lone=False)
  dataset_name = dataargs.dataset_name; periodstr = dataargs.periodstr; domain = dataargs.domain
  
  # parse export options
  expargs = expargs.copy() # first copy, then modify...
  lm3 = expargs.pop('lm3') # convert kg/m^2/s to m^3/m^2/s (water flux)
  expformat = expargs.pop('format') # needed to get FileFormat object
  varlist = expargs.pop('varlist') # this handled outside of export
  # initialize FileFormat class instance
  fileFormat = getFileFormat(expformat, **expargs)
  # get folder for target dataset and do some checks
  expname = '{:s}_d{:02d}'.format(dataset_name,domain) if domain else dataset_name
  expfolder = fileFormat.defineDataset(dataset=dataset, mode=mode, dataargs=dataargs, lwrite=True, ldebug=ldebug)

  # prepare destination for new dataset
  lskip = fileFormat.prepareDestination(srcage=srcage, loverwrite=loverwrite)
  
  # depending on last modification time of file or overwrite setting, start computation, or skip
  if lskip:        
    # print message
    skipmsg =  "\n{:s}   >>>   Skipping: Format '{:s}' for dataset '{:s}' already exists and is newer than source file.".format(pidstr,expformat,dataset_name)
    skipmsg += "\n{:s}   >>>   ('{:s}')\n".format(pidstr,expfolder)
    logger.info(skipmsg)              
  else:
          
    ## actually load datasets
    source = loadfct() # load source data
    # check period
    if 'period' in source.atts and dataargs.periodstr != source.atts.period: # a NetCDF attribute
      raise DateError, "Specified period is inconsistent with netcdf records: '{:s}' != '{:s}'".format(periodstr,source.atts.period)

    # print message
    if mode == 'climatology': opmsgstr = 'Exporting Climatology ({:s}) to {:s} Format'.format(periodstr, expformat)
    elif mode == 'time-series': opmsgstr = 'Exporting Time-series to {:s} Format'.format(expformat)
    elif mode[-5:] == '-mean': opmsgstr = 'Exporting {:s}-Mean ({:s}) to {:s} Format'.format(mode[:-5], periodstr, expformat)
    else: raise NotImplementedError, "Unrecognized Mode: '{:s}'".format(mode)        
    # print feedback to logger
    logger.info('\n{0:s}   ***   {1:^65s}   ***   \n{0:s}   ***   {2:^65s}   ***   \n'.format(pidstr,datamsgstr,opmsgstr))
    if not lparallel and ldebug: logger.info('\n'+str(source)+'\n')
    
    # create GDAL-enabled target dataset
    sink = Dataset(axes=(source.xlon,source.ylat), name=expname, title=source.title)
    addGDALtoDataset(dataset=sink, griddef=source.griddef)
    assert sink.gdal, sink
    
    # N.B.: data are not loaded immediately but on demand; this way I/O and computing are further
    #       disentangled and not all variables are always needed
    
    # Compute intermediate variables, if necessary
    for varname in varlist:
      vars = None # variable list
      if varname in source:
        var = source[varname].load() # load data (may not have to load all)
      else:
        var = None
        if varname == 'waterflx': var = newvars.computeWaterFlux(source)
        elif varname == 'liqwatflx': var = newvars.computeLiquidWaterFlux(source)
        elif varname == 'netrad': var = newvars.computeNetRadiation(source, asVar=True)
        elif varname == 'netrad_0': var = newvars.computeNetRadiation(source, asVar=True, lA=False, name='netrad_0')
        elif varname == 'netrad_bb': var = newvars.computeNetRadiation(source, asVar=True, lrad=False, name='netrad_bb')
        elif varname == 'vapdef': var = newvars.computeVaporDeficit(source)
        elif varname == 'pet' or varname == 'pet_pm':
          if 'petrad' in varlist or 'petwnd' in varlist:
            vars = newvars.computePotEvapPM(source, lterms=True) # default; returns multiple PET terms
          else: var = newvars.computePotEvapPM(source, lterms=False) # returns only PET
        elif varname == 'pet_th': var = None # skip for now
          #var = computePotEvapTh(source) # simplified formula (less prerequisites)
        else: raise VariableError, "Unsupported Variable '{:s}'.".format(varname)
      # for now, skip variables that are None
      if var or vars:
        # handle lists as well
        if var and vars: raise VariableError, (var,vars)
        elif var: vars = (var,)
        for var in vars:
          addGDALtoVar(var=var, griddef=sink.griddef)
          if not var.gdal and isinstance(fileFormat,ASCII_raster):
            raise GDALError, "Exporting to ASCII_raster format requires GDAL-enabled variables."
          # add to new dataset
          sink += var
    # convert units
    if lm3:
      for var in sink:
        if var.units == 'kg/m^2/s':
          var /= 1000. # divide to get m^3/m^2/s
          var.units = 'm^3/m^2/s' # update units
    
    # compute seasonal mean if we are in mean-mode
    if mode[-5:] == '-mean': 
      sink = sink.seasonalMean(season=mode[:-5], taxatts=dict(name='time'))
      # N.B.: to remain consistent with other output modes, 
      #       we need to prevent renaming of the time axis
    
    # print dataset
    if not lparallel and ldebug:
      logger.info('\n'+str(sink)+'\n')
      
    # export new dataset to selected format
    fileFormat.exportDataset(sink)
      
    # write results to file
    writemsg =  "\n{:s}   >>>   Export of Dataset '{:s}' to Format '{:s}' complete.".format(pidstr,expname, expformat)
    writemsg += "\n{:s}   >>>   ('{:s}')\n".format(pidstr,expfolder)
    logger.info(writemsg)      
       
    # clean up and return
    source.unload(); #del source
    return 0 # "exit code"
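
This variant takes a flat 'varlist' in expargs instead of the 'exp_list'/'compute_list' pair used in example #5 and has no bias-correction stage; a sketch of the corresponding (hypothetical) arguments.

# hypothetical arguments for this variant of performExport
dataargs = dict(experiment='max-ctrl', period=15)          # consumed by getMetaData
expargs  = dict(format='ASCII_raster', lm3=True,
                varlist=['waterflx', 'liqwatflx', 'pet'])  # popped as 'varlist' above
ec = performExport(dataset='WRF', mode='time-series', dataargs=dataargs,
                   expargs=expargs, loverwrite=False, ldebug=True)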
Example #8
def rasterDataset(name=None,
                  title=None,
                  vardefs=None,
                  axdefs=None,
                  atts=None,
                  projection=None,
                  griddef=None,
                  lgzip=None,
                  lgdal=True,
                  lmask=True,
                  fillValue=None,
                  lskipMissing=True,
                  lgeolocator=True,
                  file_pattern=None,
                  lfeedback=True,
                  **kwargs):
    ''' function to load a set of variables that are stored in raster format in a systematic directory tree into a Dataset
        Variables and Axes are defined as follows:
          vardefs[varname] = dict(name=string, units=string, axes=tuple of strings, atts=dict, plot=dict, dtype=np.dtype, fillValue=value)
          axdefs[axname]   = dict(name=string, units=string, atts=dict, coord=array or list) or None
        The path to the raster files is constructed as variable_pattern+axes_pattern, where axes_pattern is defined through the axes
        (as in rasterVariable) and variable_pattern takes the special keyword VAR, which is the variable key in vardefs.
    '''

    ## prepare input data and axes
    if griddef:
        xlon, ylat = griddef.xlon, griddef.ylat
        if projection is None:
            projection = griddef.projection
        elif projection != griddef.projection:
            raise ArgumentError("Conflicting projection and GridDef!")
        geotransform = griddef.geotransform
        isProjected = griddef.isProjected
    else:
        xlon = ylat = geotransform = None
        isProjected = False if projection is None else True
    # construct axes dict
    axes = dict()
    for axname, axdef in axdefs.items():
        if axdef is None:
            axes[axname] = None  # axis will be supplied later (e.g. by the grid definition)
        else:
            assert 'coord' in axdef, axdef
            assert ('name' in axdef and 'units' in axdef) or 'atts' in axdef, axdef
            ax = Axis(**axdef)
            axes[ax.name] = ax
    # check for map Axis
    if isProjected:
        if 'x' not in axes: axes['x'] = xlon
        if 'y' not in axes: axes['y'] = ylat
    else:
        if 'lon' not in axes: axes['lon'] = xlon
        if 'lat' not in axes: axes['lat'] = ylat

    ## load raster data into Variable objects
    varlist = []
    for varname, vardef in vardefs.items():
        # check definitions
        assert 'axes' in vardef and 'dtype' in vardef, vardef
        assert ('name' in vardef
                and 'units' in vardef) or 'atts' in vardef, vardef
        # determine relevant axes
        vardef = vardef.copy()
        axes_list = [
            None if ax is None else axes[ax] for ax in vardef.pop('axes')
        ]
        # define path parameters (with varname)
        path_params = vardef.pop('path_params', None)
        path_params = dict() if path_params is None else path_params.copy()
        if 'VAR' not in path_params:
            path_params['VAR'] = varname  # a special key
        # add kwargs and relevant axis indices
        relaxes = [ax.name for ax in axes_list
                   if ax is not None]  # relevant axes
        for key, value in kwargs.items():
            if key not in axes or key in relaxes:
                vardef[key] = value
        # create Variable object
        var = rasterVariable(projection=projection,
                             griddef=griddef,
                             file_pattern=file_pattern,
                             lgzip=lgzip,
                             lgdal=lgdal,
                             lmask=lmask,
                             lskipMissing=lskipMissing,
                             axes=axes_list,
                             path_params=path_params,
                             lfeedback=lfeedback,
                             **vardef)
        # vardef components: name, units, atts, plot, dtype, fillValue
        varlist.append(var)
        # check that map axes are correct
        for ax in var.xlon, var.ylat:
            if axes[ax.name] is None: axes[ax.name] = ax
            elif axes[ax.name] != ax:
                raise AxisError("{} axes are incompatible.".format(ax.name))
        if griddef is None: griddef = var.griddef
        elif griddef != var.griddef:
            raise AxisError("GridDefs are inconsistent.")
        if geotransform is None: geotransform = var.geotransform
        elif geotransform != var.geotransform:
            raise AxisError(
                "Conflicting geotransform (from Variable) and GridDef!\n {} != {}"
                .format(var.geotransform, geotransform))

    ## create Dataset
    # create dataset
    dataset = Dataset(name=name,
                      title=title,
                      varlist=varlist,
                      axes=axes,
                      atts=atts)
    # add GDAL functionality
    dataset = addGDALtoDataset(dataset,
                               griddef=griddef,
                               projection=projection,
                               geotransform=geotransform,
                               gridfolder=None,
                               lwrap360=None,
                               geolocator=lgeolocator,
                               lforce=False)
    # N.B.: for some reason we also need to pass the geotransform, otherwise it is recomputed internally and some consistency
    #       checks fail due to machine-precision differences

    # return GDAL-enabled Dataset
    return dataset
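
A sketch of the vardefs/axdefs dictionaries described in the docstring, for a single variable stored as one raster per month; all names, the file pattern, and the grid definition are hypothetical.

import numpy as np

# hypothetical definitions following the docstring above
axdefs  = dict(time=dict(name='time', units='month', coord=np.arange(1, 13)))
vardefs = dict(precip=dict(name='precip', units='kg/m^2/s', dtype=np.float32,
                           axes=('time', 'lat', 'lon'),             # map axes come from griddef
                           path_params=dict(VAR='precipitation')))  # overrides the VAR keyword
ds = rasterDataset(name='clim', title='Monthly Climatology',
                   vardefs=vardefs, axdefs=axdefs,
                   griddef=my_griddef,                     # an existing lat/lon GridDefinition
                   file_pattern='{VAR:s}/{time:02d}.asc',  # assumed variable/axes pattern
                   lgzip=None, lmask=True, lskipMissing=True)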
Example #9
 def Regrid(self, griddef=None, projection=None, geotransform=None, size=None, xlon=None, ylat=None, 
            lmask=True, int_interp=None, float_interp=None, **kwargs):
   ''' Set up regridding and start computation; calls processRegrid. '''
   # make temporary gdal dataset
   if self.source is self.target:
     if self.tmp: assert self.source == self.tmpput and self.target == self.tmpput
     # the operation can not be performed "in-place"!
     self.target = Dataset(name='tmptoo', title='Temporary target dataset for non-in-place operations', varlist=[], atts={})
     ltmptoo = True
   else: ltmptoo = False 
   # make sure the target dataset is a GDAL-enabled dataset
   if 'gdal' in self.target.__dict__: 
     # gdal info already present
     if griddef is not None or projection is not None or geotransform is not None: 
       raise AttributeError, "Target Dataset '%s' is already GDAL enabled - cannot overwrite settings!"%self.target.name
     if self.target.xlon is None: raise GDALError, "Map axis 'xlon' not found!"
     if self.target.ylat is None: raise GDALError, "Map axis 'ylat' not found!"
     xlon = self.target.xlon; ylat = self.target.ylat
   else:
     # need to set GDAL parameters
     if self.tmp and 'gdal' in self.output.__dict__:
       # transfer gdal settings from output to temporary dataset 
       assert self.target is not self.output 
       projection = self.output.projection; geotransform = self.output.geotransform
       xlon = self.output.xlon; ylat = self.output.ylat
     else:
       # figure out grid definition from input 
       if griddef is None: 
         griddef = GridDefinition(projection=projection, geotransform=geotransform, size=size, xlon=xlon, ylat=ylat)
       # pass arguments through GridDefinition, if not provided
       projection=griddef.projection; geotransform=griddef.geotransform
       xlon=griddef.xlon; ylat=griddef.ylat                     
     # apply GDAL settings to the target dataset
     for ax in (xlon,ylat): self.target.addAxis(ax, loverwrite=True) # i.e. replace if already present
     self.target = addGDALtoDataset(self.target, projection=projection, geotransform=geotransform)
   # use these map axes
   xlon = self.target.xlon; ylat = self.target.ylat
   assert isinstance(xlon,Axis) and isinstance(ylat,Axis)
   # determine source dataset grid definition
   if self.source.griddef is None:  
     srcgrd = GridDefinition(projection=self.source.projection, geotransform=self.source.geotransform, 
                             size=self.source.mapSize, xlon=self.source.xlon, ylat=self.source.ylat)
   else: srcgrd = self.source.griddef
   srcres = srcgrd.scale; tgtres = griddef.scale
   # determine if a shift is necessary to ensure correct wrapping
   if not srcgrd.isProjected and not griddef.isProjected:
     lwrapSrc = srcgrd.wrap360
     lwrapTgt = griddef.wrap360
     # check grids
     for grd in (srcgrd,griddef):
       if grd.wrap360:            
         assert grd.geotransform[0] + grd.geotransform[1]*(len(grd.xlon)-1) > 180        
         assert np.round(grd.geotransform[1]*len(grd.xlon), decimals=2) == 360 # require 360 deg. to some accuracy... 
         assert any( grd.xlon.getArray() > 180 ) # need to wrap around
         assert all( grd.xlon.getArray() >= 0 )
         assert all( grd.xlon.getArray() <= 360 )
       else:
         assert grd.geotransform[0] + grd.geotransform[1]*(len(grd.xlon)-1) < 180
         assert all( grd.xlon.getArray() >= -180 )
         assert all( grd.xlon.getArray() <= 180 )  
   else: 
     lwrapSrc = False # no need to shift, if a projected grid is involved!
     lwrapTgt = False # no need to shift, if a projected grid is involved!
   # determine GDAL interpolation
   if int_interp is None: int_interp = gdalInterp('nearest')
   else: int_interp = gdalInterp(int_interp)
   if float_interp is None:
     if srcres < tgtres: float_interp = gdalInterp('convolution') # down-sampling: 'convolution'
     else: float_interp = gdalInterp('cubicspline') # up-sampling
   else: float_interp = gdalInterp(float_interp)      
   # prepare function call    
   function = functools.partial(self.processRegrid, ylat=ylat, xlon=xlon, lwrapSrc=lwrapSrc, lwrapTgt=lwrapTgt, # already set parameters
                                lmask=lmask, int_interp=int_interp, float_interp=float_interp)
   # start process
   if self.feedback: print('\n   +++   processing regridding   +++   ') 
   self.process(function, **kwargs) # currently 'flush' is the only kwarg
   # now make sure we have a GDAL dataset!
   self.target = addGDALtoDataset(self.target, griddef=griddef)
   if self.feedback: print('\n')
   if self.tmp: self.tmpput = self.target
   if ltmptoo: assert self.tmpput.name == 'tmptoo' # set above, when temp. dataset is created    
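
A minimal sketch of a regridding call, with hypothetical grid parameters; 'flush' is currently the only keyword passed through to process.

# hypothetical target grid: a regular 0.125-degree lat/lon grid
tgt_grid = GridDefinition(projection=None,
                          geotransform=(-130.0, 0.125, 0.0, 40.0, 0.0, 0.125),
                          size=(400, 200))   # assumed (nx, ny) ordering
CPU.Regrid(griddef=tgt_grid, lmask=True,
           int_interp='nearest', float_interp='cubicspline', flush=True)
CPU.sync(flush=True)   # transfer results to the output dataset (if one was attached)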
Example #10
  def Extract(self, template=None, stnax=None, xlon=None, ylat=None, laltcorr=True, **kwargs):
    ''' Extract station data points from gridded datasets; calls processExtract. 
        A station dataset can be passed as template (must have station coordinates). '''
    if not self.source.gdal: raise DatasetError, "Source dataset must be GDAL enabled! {:s} is not.".format(self.source.name)
    if template is None: raise NotImplementedError
    elif isinstance(template, Dataset):
      if not template.hasAxis('station'): raise DatasetError, "Template station dataset needs to have a station axis."
      if not ( (template.hasVariable('lat') or template.hasVariable('stn_lat')) and 
               (template.hasVariable('lon') or template.hasVariable('stn_lon')) ): 
        raise DatasetError, "Template station dataset needs to have lat/lon arrays for the stations."      
    else: raise TypeError
    # make temporary dataset
    if self.source is self.target:
      if self.tmp: assert self.source == self.tmpput and self.target == self.tmpput
      # the operation can not be performed "in-place"!
      self.target = Dataset(name='tmptoo', title='Temporary target dataset for non-in-place operations', varlist=[], atts={})
      ltmptoo = True
    else: ltmptoo = False
    src = self.source; tgt = self.target # short-cuts 
    # determine source dataset grid definition
    if src.griddef is None:  
      srcgrd = GridDefinition(projection=self.source.projection, geotransform=self.source.geotransform, 
                              size=self.source.mapSize, xlon=self.source.xlon, ylat=self.source.ylat)
    else: srcgrd = src.griddef
    # figure out horizontal axes (will be replaced with station axis)
    if isinstance(xlon,Axis): 
      if not src.hasAxis(xlon, check=True): raise DatasetError
    elif isinstance(xlon,basestring): xlon = src.getAxis(xlon)
    else: xlon = src.x if srcgrd.isProjected else src.lon
    if isinstance(ylat,Axis):
      if not src.hasAxis(ylat, check=True): raise DatasetError
    elif isinstance(ylat,basestring): ylat = src.getAxis(ylat)
    else: ylat = src.y if srcgrd.isProjected else src.lat
    if stnax: # not in source dataset!
      if src.hasAxis(stnax, check=True): raise DatasetError, "Source dataset must not have a 'station' axis!"
    elif template: stnax = template.station # station axis
    else: raise ArgumentError, "A station axis needs to be supplied." 
    assert isinstance(xlon,Axis) and isinstance(ylat,Axis) and isinstance(stnax,Axis)
    # transform to dataset-native coordinate system
    if template: 
      if template.hasVariable('lat'): lats = template.lat.getArray()
      else: lats = template.stn_lat.getArray()
      if template.hasVariable('lon'): lons = template.lon.getArray()
      else: lons = template.stn_lon.getArray()
    else: raise NotImplementedError, "Cannot extract station data without a station template Dataset"
    # adjust longitudes
    if srcgrd.isProjected:
      if lons.max() > 180.: lons = np.where(lons > 180., 360.-lons, lons)
      # reproject coordinate
      latlon = osr.SpatialReference() 
      latlon.SetWellKnownGeogCS('WGS84') # a normal lat/lon coordinate system
      tx = osr.CoordinateTransformation(latlon,srcgrd.projection)
      xs = []; ys = [] 
      for i in xrange(len(lons)):
        x,y,z = tx.TransformPoint(lons[i].astype(np.float64),lats[i].astype(np.float64))
        xs.append(x); ys.append(y); del z
      lons = np.array(xs); lats = np.array(ys)
      #lons,lats = tx.TransformPoints(lons,lats) # doesn't seem to work...
    else:
      if lons.min() < 0. and xlon.coord.max() > 180.: lons = np.where(lons < 0., lons + 360., lons)
      elif lons.max() > 180. and xlon.coord.min() < 0.: lons = np.where(lons > 180., 360.-lons, lons)
      else: pass # source and template do not conflict
    # generate index list
    ixlon = []; iylat = []; istn = []; zs_err = [] # also record elevation error
    lzs = src.hasVariable('zs')
    lstnzs = template.hasVariable('zs') or  template.hasVariable('stn_zs')
    if laltcorr and lzs and lstnzs:
      if src.zs.ndim > 2: src.zs = src.zs(time=0, lidx=True) # first time-slice (for CESM)
      if src.zs.ndim != 2 or not src.gdal or src.zs.units != 'm': raise VariableError
      # consider altitude of surrounding points as well
      zs = src.zs.getArray(unmask=True,fillValue=-300)
      if template.hasVariable('zs'): stn_zs = template.zs.getArray(unmask=True,fillValue=-300)
      else: stn_zs = template.stn_zs.getArray(unmask=True,fillValue=-300)
      if src.zs.axisIndex(xlon.name) == 0: zs = zs.transpose() # assuming lat,lon or y,x order is more common
      ye,xe = zs.shape # assuming order lat,lon or y,x
      xe -= 1; ye -= 1 # last valid index, not length
      for n,lon,lat in zip(xrange(len(stnax)),lons,lats):
        ip = xlon.getIndex(lon, mode='left', outOfBounds=True)
        jp = ylat.getIndex(lat, mode='left', outOfBounds=True)
        if ip is not None and jp is not None:
          # find neighboring point with smallest altitude error 
#           ip = im+1 if im < xe else im  
#           jp = jm+1 if jm < ye else jm
          im = ip-1 if ip > 0 else ip  
          jm = jp-1 if jp > 0 else jp
          zdiff = np.Infinity # initialize, so that it triggers at least once
          # check four closest grid points
          for i in im,ip:
            for j in jm,jp:
              ze = zs[j,i]-stn_zs[n]
              zd = np.abs(ze) # compute elevation error
              if zd < zdiff: ii,jj,zdiff,zerr = i,j,zd,ze # preliminary selection, triggers at least once               
          ixlon.append(ii); iylat.append(jj); istn.append(n); zs_err.append(zerr) # final selection          
    else: 
      # just choose horizontally closest point 
      for n,lon,lat in zip(xrange(len(stnax)),lons,lats):
        i = xlon.getIndex(lon, mode='closest', outOfBounds=True)
        j = ylat.getIndex(lat, mode='closest', outOfBounds=True)
        if i is not None and j is not None: 
          if lzs: # compute elevation error
            zs_err.append(zs[j,i]-stn_zs[n])          
          ixlon.append(i); iylat.append(j); istn.append(n)
    # N.B.: it is necessary to append, because we don't know the number of valid points
    ixlon = np.array(ixlon); iylat = np.array(iylat); istn = np.array(istn); zs_err = np.array(zs_err)
    # prepare target dataset
    # N.B.: attributes should already be set in target dataset (by caller module)
    #       we are also assuming the new dataset has no axes yet
    assert len(tgt.axes) == 0
    # add axes from source data
    for axname,ax in src.axes.iteritems():
      if axname not in (xlon.name,ylat.name):
        tgt.addAxis(ax, asNC=True, copy=True)
    # add station axis (trim to valid coordinates)
    newstnax = stnax.copy(coord=stnax.coord[istn]) # same but with trimmed coordinate array
    tgt.addAxis(newstnax, asNC=True, copy=True) # already new copy
    # create variable for elevation error
    if lzs:
      assert len(zs_err) > 0
      zs_err = Variable(name='zs_err', units='m', data=zs_err, axes=(newstnax,),
                        atts=dict(long_name='Station Elevation Error'))
      tgt.addVariable(zs_err, asNC=True, copy=True); del zs_err # need to copy to make NC var
    # add a bunch of other variables with station meta data
    for var in template.variables.itervalues():
      if var.ndim == 1 and var.hasAxis(stnax): # station attributes
        if var.name[-4:] != '_len' or var.name == 'stn_rec_len': # exclude certain attributes
          newvar = var.copy(data=var.getArray()[istn], axes=(newstnax,))
          if newvar.name[:4] != 'stn_' and newvar.name[:8] != 'station_' and newvar.name[:8] != 'cluster_': 
            newvar.name = 'stn_'+newvar.name # copy cluster_* as they are!
          # N.B.: we need to rename, or name collisions will happen! 
          tgt.addVariable(newvar, asNC=True, copy=True); del newvar # need to copy to make NC var
    # save all the meta data
    tgt.sync()
    # prepare function call    
    function = functools.partial(self.processExtract, ixlon=ixlon, iylat=iylat, ylat=ylat, xlon=xlon, stnax=stnax) # already set parameters
    # start process
    if self.feedback: print('\n   +++   processing point-data extraction   +++   ') 
    self.process(function, **kwargs) # currently 'flush' is the only kwarg
    if self.feedback: print('\n')
    if self.tmp: self.tmpput = self.target
    if ltmptoo: assert self.tmpput.name == 'tmptoo' # set above, when temp. dataset is created    
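
A hedged sketch of a minimal station template that satisfies the checks above: a 'station' axis plus lat/lon variables, and optionally zs for the altitude correction. Constructor signatures follow the other examples on this page; the coordinates are made up.

import numpy as np
from geodata.base import Dataset, Axis, Variable  # as imported in example #6

# hypothetical three-station template
stations = Axis(name='station', units='', coord=np.arange(1, 4))
template = Dataset(name='stations', title='Toy Station Template', varlist=[
    Variable(name='lat', units='deg', data=np.array([43.7, 44.2, 45.0]), axes=(stations,)),
    Variable(name='lon', units='deg', data=np.array([-79.4, -78.9, -77.5]), axes=(stations,)),
    Variable(name='zs', units='m', data=np.array([120., 240., 310.]), axes=(stations,))])
CPU.Extract(template=template, laltcorr=True, flush=True)  # pick nearest grid points per station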
Example #11
class CentralProcessingUnit(object):
  
  def __init__(self, source, target=None, varlist=None, ignorelist=None, tmp=True, feedback=True):
    ''' Initialize processor and pass input and output datasets. '''
    # check varlist
    if varlist is None: varlist = source.variables.keys() # all source variables
    elif not isinstance(varlist,(list,tuple)): raise TypeError
    self.varlist = varlist # list of variables to be processed
    # ignore list (e.g. variables that will cause errors)
    if ignorelist is None: ignorelist = [] # an empty list
    elif not isinstance(ignorelist,(list,tuple)): raise TypeError
    self.ignorelist = ignorelist # list of variables *not* to be processed
    # check input
    if not isinstance(source,Dataset): raise TypeError
    if isinstance(source,DatasetNetCDF) and not 'r' in source.mode: raise PermissionError
    self.input = source
    self.source = source
    # check output
    if target is not None:       
      if not isinstance(target,Dataset): raise TypeError
      if isinstance(target,DatasetNetCDF) and not 'w' in target.mode: raise PermissionError
    else:
      if not tmp: raise DatasetError, "Need target location, if temporary storage is disabled (tmp=False)." 
    self.output = target
    # temporary dataset
    self.tmp = tmp
    if tmp: self.tmpput = Dataset(name='tmp', title='Temporary Dataset', varlist=[], atts={})
    else: self.tmpput = None
    # determine if temporary storage is used and assign target dataset
    if self.tmp: self.target = self.tmpput
    else: self.target = self.output 
    # whether or not to print status output
    self.feedback = feedback
        
  def getTmp(self, asNC=False, filename=None, deepcopy=False, **kwargs):
    ''' Get a copy of the temporary data in dataset format. '''
    if not self.tmp: raise DatasetError
    # make new dataset (name and title should transfer in atts dict)
    if asNC:
      if not isinstance(filename,basestring): raise TypeError
      writeData = kwargs.pop('writeData',False)
      ncformat = kwargs.pop('ncformat','NETCDF4')
      zlib = kwargs.pop('zlib',True)
      dataset = asDatasetNC(self.tmpput, ncfile=filename, mode='wr', deepcopy=deepcopy, 
                            writeData=writeData, ncformat=ncformat, zlib=zlib, **kwargs)
    else:
      dataset = self.tmpput.copy(varsdeep=deepcopy, atts=self.input.atts.copy(), **kwargs)
    # return dataset
    return dataset
  
  def sync(self, varlist=None, flush=False, gdal=True, copydata=True):
    ''' Transfer contents of temporary storage to output/target dataset. '''
    if not isinstance(self.output,Dataset): raise DatasetError, "Cannot sync without target Dataset!"
    if self.tmp:
      if varlist is None: varlist = self.tmpput.variables.keys()  
      for varname in varlist:
        if varname in self.tmpput.variables:
          var = self.tmpput.variables[varname]
          self.output.addVariable(var, loverwrite=True, deepcopy=copydata)
          # N.B.: without copydata/deepcopy, only the variable header is created but no data is written
          if flush: var.unload() # remove unnecessary references (unlink data)
      if gdal and 'gdal' in self.tmpput.__dict__: 
        if self.tmpput.gdal: 
          projection = self.tmpput.projection; geotransform = self.tmpput.geotransform
          #xlon = self.tmpput.xlon; ylat = self.tmpput.ylat 
        else: 
          projection=None; geotransform=None; #xlon = None; ylat = None 
        self.output = addGDALtoDataset(self.output, projection=projection, geotransform=geotransform)
#           self.source = self.output # future operations will write to the output dataset directly
#           self.target = self.output # future operations will write to the output dataset directly                     
        
  def writeNetCDF(self, filename=None, folder=None, ncformat='NETCDF4', zlib=True, writeData=True, close=False, flush=False):
    ''' Write current temporary storage to a NetCDF file. '''
    if self.tmp:
      if not isinstance(filename,basestring): raise TypeError
      if folder is not None: filename = folder + filename       
      output = writeNetCDF(self.tmpput, filename, ncformat=ncformat, zlib=zlib, writeData=writeData, close=False)
      if flush: self.tmpput.unload()
      if self.feedback: print('\nOutput written to {0:s}\n'.format(filename))
    else: 
      self.output.sync()
      output = self.output.dataset # get (primary) NetCDF file
      if self.feedback: print('\nSynchronized dataset {0:s} with temporary storage.\n'.format(output.name))
    # flush?
    if flush: self.output.unload()      
    # close file or return file handle
    if close: output.close()
    else: return output

  def process(self, function, flush=False):
    ''' This method applies the desired operation/function to each variable in varlist. '''
    if flush: # this function is to save RAM by flushing results to disk immediately
      if not isinstance(self.output,DatasetNetCDF):
        raise ProcessError, "Flush can only be used with NetCDF Datasets (and not with temporary storage)."
      if self.tmp: # flush requires output to be target
        if self.source.gdal and not self.tmpput.gdal:
          self.tmpput = addGDALtoDataset(self.tmpput, projection=self.source.projection, geotransform=self.source.geotransform)
        self.source = self.tmpput
        self.target = self.output
        self.tmp = False # not using temporary storage anymore
    # loop over input variables
    for varname in self.varlist:
      # check against ignore list
      if varname not in self.ignorelist: 
        # check if variable already exists
        if self.target.hasVariable(varname):
          # "in-place" operations
          var = self.target.variables[varname]         
          newvar = function(var) # perform actual processing
          if newvar.ndim != var.ndim or newvar.shape != var.shape: raise VariableError
          if newvar is not var: self.target.replaceVariable(var,newvar)
        elif self.source.hasVariable(varname):        
          var = self.source.variables[varname]
          ldata = var.data # whether data was pre-loaded 
          # perform operation from source and copy results to target
          newvar = function(var) # perform actual processing
          if not ldata: var.unload() # if it was already loaded, don't unload        
          self.target.addVariable(newvar, copy=True) # copy=True allows recasting as, e.g., a NC variable
        else:
          raise DatasetError, "Variable '%s' not found in input dataset."%varname
        assert varname == newvar.name
        # flush data to disk immediately      
        if flush: 
          self.output.variables[varname].unload() # again, free memory
        newvar.unload(); del var, newvar # free space; already added to new dataset
    # after everything is said and done:
    self.source = self.target # set target to source for next time
    
    
  ## functions (or function pairs, rather) that perform operations on the data
  # every function pair needs to have a setup function and a processing function
  # the former sets up the target dataset and the latter operates on the variables
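  # A hedged usage sketch of this pattern (dataset names are hypothetical, not from the source):
  #   cpu = CentralProcessingUnit(indata, outdata, tmp=True)  # indata/outdata are Dataset instances
  #   cpu.Climatology(period=30, offset=0)  # setup method binds processClimatology via functools.partial
  #   cpu.sync()  # transfer the results from temporary storage to the output dataset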
  
  # function pair to average data over a given collection of shapes      
  def ShapeAverage(self, shape_dict=None, shape_name=None, shpax=None, xlon=None, ylat=None, 
                   memory=500, **kwargs):
    ''' Average over a limited area of a gridded dataset; calls processShapeAverage. 
        A dictionary of NamedShape objects is expected to define the averaging areas. 
        'memory' controls the garbage collection interval and approximately corresponds 
        to MB of temporary storage (it does not include loading the variable into RAM, though). '''
    if not self.source.gdal: raise DatasetError, "Source dataset must be GDAL enabled! {:s} is not.".format(self.source.name)
    if not isinstance(shape_dict,OrderedDict): raise TypeError
    if not all(isinstance(shape,NamedShape) for shape in shape_dict.itervalues()): raise TypeError
    # make temporary dataset
    if self.source is self.target:
      if self.tmp: assert self.source == self.tmpput and self.target == self.tmpput
      # the operation can not be performed "in-place"!
      self.target = Dataset(name='tmptoo', title='Temporary target dataset for non-in-place operations', varlist=[], atts={})
      ltmptoo = True
    else: ltmptoo = False
    src = self.source; tgt = self.target # short-cuts 
    # determine source dataset grid definition
    if src.griddef is None:  
      srcgrd = GridDefinition(projection=self.source.projection, geotransform=self.source.geotransform, 
                              size=self.source.mapSize, xlon=self.source.xlon, ylat=self.source.ylat)
    else: srcgrd = src.griddef
    # figure out horizontal axes (will be replaced with the shape axis)
    if isinstance(xlon,Axis): 
      if not src.hasAxis(xlon, check=True): raise DatasetError
    elif isinstance(xlon,basestring): xlon = src.getAxis(xlon)
    else: xlon = src.x if srcgrd.isProjected else src.lon
    if isinstance(ylat,Axis):
      if not src.hasAxis(ylat, check=True): raise DatasetError
    elif isinstance(ylat,basestring): ylat = src.getAxis(ylat)
    else: ylat = src.y if srcgrd.isProjected else src.lat
    # check/create shapes axis
    if shpax: # not in source dataset!
      # if shape axis supplied
      if src.hasAxis(shpax, check=True): raise DatasetError, "Source dataset must not have a 'shape' axis!"
      if len(shpax) != len(shape_dict): raise AxisError
    else:
      # create shape axis, if not supplied
      shpatts = dict(name='shape', long_name='Ordinal Number of Shape', units='#')
      shpax = Axis(coord=np.arange(1,len(shape_dict)+1), atts=shpatts) # starting at 1
    assert isinstance(xlon,Axis) and isinstance(ylat,Axis) and isinstance(shpax,Axis)
    # prepare target dataset
    # N.B.: attributes should already be set in target dataset (by caller module)
    #       we are also assuming the new dataset has no axes yet
    assert len(tgt.axes) == 0
    # add shape axis
    tgt.addAxis(shpax, asNC=True, copy=True) # already new copy
    # add axes from source data
    for axname,ax in src.axes.iteritems():
      if axname not in (xlon.name,ylat.name):
        tgt.addAxis(ax, asNC=True, copy=True)
    # add shape names
    shape_names = [shape.name for shape in shape_dict.itervalues()] # can construct Variable from list!
    atts = dict(name='shape_name', long_name='Name of Shape', units='')
    tgt.addVariable(Variable(data=shape_names, axes=(shpax,), atts=atts), asNC=True, copy=True)
    # add proper names
    shape_long_names = [shape.long_name for shape in shape_dict.itervalues()] # can construct Variable from list!
    atts = dict(name='shp_long_name', long_name='Proper Name of Shape', units='')
    tgt.addVariable(Variable(data=shape_long_names, axes=(shpax,), atts=atts), asNC=True, copy=True)    
    # add shape category
    shape_type = [shape.shapetype for shape in shape_dict.itervalues()] # can construct Variable from list!
    atts = dict(name='shp_type', long_name='Type of Shape', units='')
    tgt.addVariable(Variable(data=shape_type, axes=(shpax,), atts=atts), asNC=True, copy=True)    
    # collect rasterized masks from shape files 
    mask_array = np.zeros((len(shpax),)+srcgrd.size[::-1], dtype=np.bool) 
    # N.B.: rasterize() returns mask in (y,x) shape, size is ordered as (x,y)
    shape_masks = []; shp_full = []; shp_empty = []; shp_encl = []
    for i,shape in enumerate(shape_dict.itervalues()):
      mask = shape.rasterize(griddef=srcgrd, asVar=False)
      mask_array[i,:] = mask
      masksum = mask.sum() 
      lfull = masksum == 0; shp_full.append( lfull )
      lempty = masksum == mask.size; shp_empty.append( lempty )
      shape_masks.append( mask if not lempty else None )
      if lempty: shp_encl.append( False )
      else:
        shp_encl.append( np.all( mask[[0,-1],:] == True ) and np.all( mask[:,[0,-1]] == True ) )
        # i.e. if boundaries are masked
    # N.B.: shapes that have no overlap with grid will be skipped and filled with NaN
    # add rasterized masks to new dataset
    atts = dict(name='shp_mask', long_name='Rasterized Shape Mask', units='')
    tgt.addVariable(Variable(data=mask_array, atts=atts, axes=(shpax,srcgrd.ylat.copy(),srcgrd.xlon.copy())), 
                    asNC=True, copy=True)
    # add area enclosed by shape
    da = srcgrd.geotransform[1]*srcgrd.geotransform[5]
    mask_area = (1-mask_array).mean(axis=2).mean(axis=1)*da
    atts = dict(name='shp_area', long_name='Area Contained in the Shape', 
                units= 'm^2' if srcgrd.isProjected else 'deg^2' )
    tgt.addVariable(Variable(data=mask_area, axes=(shpax,), atts=atts), asNC=True, copy=True)
    # add flag to indicate if shape is fully enclosed by domain
    atts = dict(name='shp_encl', long_name='If Shape is fully included in Domain', units= '')
    tgt.addVariable(Variable(data=shp_encl, axes=(shpax,), atts=atts), asNC=True, copy=True)
    # add flag to indicate if shape fully covers domain
    atts = dict(name='shp_full', long_name='If Shape fully covers Domain', units= '')
    tgt.addVariable(Variable(data=shp_full, axes=(shpax,), atts=atts), asNC=True, copy=True)
    # add flag to indicate if shape and domain have no overlap
    atts = dict(name='shp_empty', long_name='If Shape and Domain have no Overlap', units= '')
    tgt.addVariable(Variable(data=shp_empty, axes=(shpax,), atts=atts), asNC=True, copy=True)
    # save all the meta data
    tgt.sync()
    # prepare function call    
    function = functools.partial(self.processShapeAverage, masks=shape_masks, ylat=ylat, xlon=xlon, 
                                 shpax=shpax, memory=memory) # already set parameters
    # start process
    if self.feedback: print('\n   +++   processing shape/area averaging   +++   ') 
    self.process(function, **kwargs) # currently 'flush' is the only kwarg
    if self.feedback: print('\n')
    if self.tmp: self.tmpput = self.target
    if ltmptoo: assert self.tmpput.name == 'tmptoo' # set above, when temp. dataset is created    
  # the previous method sets up the process, the next method performs the computation
  def processShapeAverage(self, var, masks=None, ylat=None, xlon=None, shpax=None, memory=500):
    ''' Compute masked area averages from variable data. 'memory' controls the garbage collection 
        interval and approximately corresponds to MB in RAM. '''
    # process gdal variables (if a variable has a horizontal grid, it should be GDAL enabled)
    if var.gdal and ( np.issubdtype(var.dtype,np.integer) or np.issubdtype(var.dtype,np.inexact) ):
      if self.feedback: print('\n'+var.name),
      assert var.hasAxis(xlon) and var.hasAxis(ylat)
      assert len(masks) == len(shpax)
      tgt = self.target
      assert tgt.hasAxis(shpax, strict=False) and shpax not in var.axes 
      # assemble new axes
      axes = [tgt.getAxis(shpax.name)]      
      for ax in var.axes:
        if ax not in (xlon,ylat) and ax.name != shpax.name: # these axes are just transferred 
          axes.append(tgt.getAxis(ax.name))
      # N.B.: the shape axis will be the outer axis
      axes = tuple(axes)
      # pre-allocate
      shape = tuple(len(ax) for ax in axes)
      tgtdata = np.zeros(shape, dtype=np.float32) 
      # now we loop over all shapes/masks      
      if self.feedback: 
        varname = var.name
        print '\n ... loading  ',varname 
      var.load()
      if self.feedback: 
        varname = var.name
        print '\n ... averaging ',varname 
      ## compute shape averages for each time step
      # Basically, for each shape the entire array is masked, using the shape and the broadcasting 
      # functionality for horizontal masks of the Variable class (which creates a lot of overhead);
      # then the masked average is taken and the process is repeated for the next shape/mask.
      # Using mapMean creates a lot of overhead and there are probably more efficient ways to do it. 
      if var.ndim == 2:
        for i,mask in enumerate(masks): 
          if mask is None: tgtdata[i] = np.NaN # NaN for missing values (i.e. no overlap)
          else: tgtdata[i] = var.mapMean(mask=mask, asVar=False, squeeze=True) # compute the averages
          if self.feedback: print varname, i
          # garbage collection is typically not necessary for 2D fields
      elif var.ndim > 2:
        cnt = float(0.); inc = float(var.data_array.nbytes) / (1024.*1024.) # counter/increment to estimate memory usage (in MB)
        for i,mask in enumerate(masks):
          if mask is None: tgtdata[i,:] = np.NaN # NaN for missing values (i.e. no overlap) 
          else:
            tgtdata[i,:] = var.mapMean(mask=mask, asVar=False, squeeze=True).filled(np.NaN) # mapMean returns a masked array
            # N.B.: this is necessary, because sometimes shapes only contain invalid values
            cnt += inc # keep track of memory
            # N.B.: mapMean creates a lot of temporary arrays that don't get garbage-collected
          if self.feedback: print varname, i, cnt
          if cnt > memory: 
            cnt = 0 # reset counter
            gc.collect() # collect garbage at regular intervals
            if self.feedback: print 'garbage collected'  
      else: raise AxisError 
      # create new Variable
      assert shape == tgtdata.shape
      newvar = var.copy(axes=axes, data=tgtdata) # new axes and data
      del tgtdata, mask # clean up (just to make sure)      
      gc.collect() # clean
    else:
      var.load() # need to load variables into memory to copy it (and we are not doing anything else...)
      newvar = var # just pass over the variable to the new dataset
    # return variable
    return newvar
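  # N.B.: a hedged invocation sketch (names are hypothetical): 'basins' is assumed to be an
  #       OrderedDict of NamedShape objects (e.g. river basins) defined elsewhere:
  #   cpu.ShapeAverage(shape_dict=basins, memory=500)  # sets up masks, then process() applies processShapeAverage
  #   cpu.sync()  # transfer the shape-averaged variables to the output dataset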
  # function pair to extract station data from a time-series (or climatology)      
  def Extract(self, template=None, stnax=None, xlon=None, ylat=None, laltcorr=True, **kwargs):
    ''' Extract station data points from gridded datasets; calls processExtract. 
        A station dataset can be passed as template (must have station coordinates). '''
    if not self.source.gdal: raise DatasetError, "Source dataset must be GDAL enabled! {:s} is not.".format(self.source.name)
    if template is None: raise NotImplementedError
    elif isinstance(template, Dataset):
      if not template.hasAxis('station'): raise DatasetError, "Template station dataset needs to have a station axis."
      if not ( (template.hasVariable('lat') or template.hasVariable('stn_lat')) and 
               (template.hasVariable('lon') or template.hasVariable('stn_lon')) ): 
        raise DatasetError, "Template station dataset needs to have lat/lon arrays for the stations."      
    else: raise TypeError
    # make temporary dataset
    if self.source is self.target:
      if self.tmp: assert self.source == self.tmpput and self.target == self.tmpput
      # the operation can not be performed "in-place"!
      self.target = Dataset(name='tmptoo', title='Temporary target dataset for non-in-place operations', varlist=[], atts={})
      ltmptoo = True
    else: ltmptoo = False
    src = self.source; tgt = self.target # short-cuts 
    # determine source dataset grid definition
    if src.griddef is None:  
      srcgrd = GridDefinition(projection=self.source.projection, geotransform=self.source.geotransform, 
                              size=self.source.mapSize, xlon=self.source.xlon, ylat=self.source.ylat)
    else: srcgrd = src.griddef
    # figure out horizontal axes (will be replaced with station axis)
    if isinstance(xlon,Axis): 
      if not src.hasAxis(xlon, check=True): raise DatasetError
    elif isinstance(xlon,basestring): xlon = src.getAxis(xlon)
    else: xlon = src.x if srcgrd.isProjected else src.lon
    if isinstance(ylat,Axis):
      if not src.hasAxis(ylat, check=True): raise DatasetError
    elif isinstance(ylat,basestring): ylat = src.getAxis(ylat)
    else: ylat = src.y if srcgrd.isProjected else src.lat
    if stnax: # not in source dataset!
      if src.hasAxis(stnax, check=True): raise DatasetError, "Source dataset must not have a 'station' axis!"
    elif template: stnax = template.station # station axis
    else: raise ArgumentError, "A station axis needs to be supplied." 
    assert isinstance(xlon,Axis) and isinstance(ylat,Axis) and isinstance(stnax,Axis)
    # transform to dataset-native coordinate system
    if template: 
      if template.hasVariable('lat'): lats = template.lat.getArray()
      else: lats = template.stn_lat.getArray()
      if template.hasVariable('lon'): lons = template.lon.getArray()
      else: lons = template.stn_lon.getArray()
    else: raise NotImplementedError, "Cannot extract station data without a station template Dataset"
    # adjust longitudes
    if srcgrd.isProjected:
      if lons.max() > 180.: lons = np.where(lons > 180., lons - 360., lons) # convert 0..360 to -180..180
      # reproject coordinate
      latlon = osr.SpatialReference() 
      latlon.SetWellKnownGeogCS('WGS84') # a normal lat/lon coordinate system
      tx = osr.CoordinateTransformation(latlon,srcgrd.projection)
      xs = []; ys = [] 
      for i in xrange(len(lons)):
        x,y,z = tx.TransformPoint(lons[i].astype(np.float64),lats[i].astype(np.float64))
        xs.append(x); ys.append(y); del z
      lons = np.array(xs); lats = np.array(ys)
      #lons,lats = tx.TransformPoints(lons,lats) # doesn't seem to work...
    else:
      if lons.min() < 0. and xlon.coord.max() > 180.: lons = np.where(lons < 0., lons + 360., lons)
      elif lons.max() > 180. and xlon.coord.min() < 0.: lons = np.where(lons > 180., lons - 360., lons)
      else: pass # source and template do not conflict
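    # e.g. a station at lon=-75 deg on a 0..360 deg grid is mapped to 285 deg, and a
    # station at lon=275 deg on a -180..180 deg grid is mapped to -85 deg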
    # generate index list
    ixlon = []; iylat = []; istn = []; zs_err = [] # also record elevation error
    lzs = src.hasVariable('zs')
    lstnzs = template.hasVariable('zs') or template.hasVariable('stn_zs')
    lzs = lzs and lstnzs # elevation errors require both grid and station elevations
    if lzs: # load elevation data (also needed for the nearest-point fallback below)
      if src.zs.ndim > 2: src.zs = src.zs(time=0, lidx=True) # first time-slice (for CESM)
      if src.zs.ndim != 2 or not src.gdal or src.zs.units != 'm': raise VariableError
      zs = src.zs.getArray(unmask=True,fillValue=-300)
      if template.hasVariable('zs'): stn_zs = template.zs.getArray(unmask=True,fillValue=-300)
      else: stn_zs = template.stn_zs.getArray(unmask=True,fillValue=-300)
      if src.zs.axisIndex(xlon.name) == 0: zs = zs.transpose() # assuming lat,lon or y,x order is more common
    if laltcorr and lzs:
      # consider altitude of surrounding points as well
      ye,xe = zs.shape # assuming order lat,lon or y,x
      xe -= 1; ye -= 1 # last valid index, not length
      for n,lon,lat in zip(xrange(len(stnax)),lons,lats):
        ip = xlon.getIndex(lon, mode='left', outOfBounds=True)
        jp = ylat.getIndex(lat, mode='left', outOfBounds=True)
        if ip is not None and jp is not None:
          # find neighboring point with smallest altitude error 
#           ip = im+1 if im < xe else im  
#           jp = jm+1 if jm < ye else jm
          im = ip-1 if ip > 0 else ip  
          jm = jp-1 if jp > 0 else jp
          zdiff = np.Infinity # initialize, so that it triggers at least once
          # check four closest grid points
          for i in im,ip:
            for j in jm,jp:
              ze = zs[j,i]-stn_zs[n]
              zd = np.abs(ze) # compute elevation error
              if zd < zdiff: ii,jj,zdiff,zerr = i,j,zd,ze # preliminary selection, triggers at least once               
          ixlon.append(ii); iylat.append(jj); istn.append(n); zs_err.append(zerr) # final selection          
    else: 
      # just choose horizontally closest point 
      for n,lon,lat in zip(xrange(len(stnax)),lons,lats):
        i = xlon.getIndex(lon, mode='closest', outOfBounds=True)
        j = ylat.getIndex(lat, mode='closest', outOfBounds=True)
        if i is not None and j is not None: 
          if lzs: # compute elevation error
            zs_err.append(zs[j,i]-stn_zs[n])          
          ixlon.append(i); iylat.append(j); istn.append(n)
    # N.B.: it is necessary to append, because we don't know the number of valid points
    ixlon = np.array(ixlon); iylat = np.array(iylat); istn = np.array(istn); zs_err = np.array(zs_err)
    # prepare target dataset
    # N.B.: attributes should already be set in target dataset (by caller module)
    #       we are also assuming the new dataset has no axes yet
    assert len(tgt.axes) == 0
    # add axes from source data
    for axname,ax in src.axes.iteritems():
      if axname not in (xlon.name,ylat.name):
        tgt.addAxis(ax, asNC=True, copy=True)
    # add station axis (trim to valid coordinates)
    newstnax = stnax.copy(coord=stnax.coord[istn]) # same but with trimmed coordinate array
    tgt.addAxis(newstnax, asNC=True, copy=True) # already new copy
    # create variable for elevation error
    if lzs:
      assert len(zs_err) > 0
      zs_err = Variable(name='zs_err', units='m', data=zs_err, axes=(newstnax,),
                        atts=dict(long_name='Station Elevation Error'))
      tgt.addVariable(zs_err, asNC=True, copy=True); del zs_err # need to copy to make NC var
    # add a bunch of other variables with station meta data
    for var in template.variables.itervalues():
      if var.ndim == 1 and var.hasAxis(stnax): # station attributes
        if var.name[-4:] != '_len' or var.name == 'stn_rec_len': # exclude certain attributes
          newvar = var.copy(data=var.getArray()[istn], axes=(newstnax,))
          if newvar.name[:4] != 'stn_' and newvar.name[:8] != 'station_' and newvar.name[:8] != 'cluster_': 
            newvar.name = 'stn_'+newvar.name # copy cluster_* as they are!
          # N.B.: we need to rename, or name collisions will happen! 
          tgt.addVariable(newvar, asNC=True, copy=True); del newvar # need to copy to make NC var
    # save all the meta data
    tgt.sync()
    # prepare function call    
    function = functools.partial(self.processExtract, ixlon=ixlon, iylat=iylat, ylat=ylat, xlon=xlon, stnax=stnax) # already set parameters
    # start process
    if self.feedback: print('\n   +++   processing point-data extraction   +++   ') 
    self.process(function, **kwargs) # currently 'flush' is the only kwarg
    if self.feedback: print('\n')
    if self.tmp: self.tmpput = self.target
    if ltmptoo: assert self.tmpput.name == 'tmptoo' # set above, when temp. dataset is created    
  # the previous method sets up the process, the next method performs the computation
  def processExtract(self, var, ixlon=None, iylat=None, ylat=None, xlon=None, stnax=None):
    ''' Extract grid points corresponding to stations. '''
    # process gdal variables (if a variable has a horizontal grid, it should be GDAL enabled)
    if var.gdal:
      if self.feedback: print('\n'+var.name),
      tgt = self.target
      assert xlon in var.axes and ylat in var.axes
      assert tgt.hasAxis(stnax, strict=False) and stnax not in var.axes 
      # assemble new axes
      axes = [tgt.getAxis(stnax.name)]      
      for ax in var.axes:
        if ax.name not in (xlon.name,ylat.name) and ax.name != stnax.name: # these axes are just transferred 
          axes.append(tgt.getAxis(ax.name))
      axes = tuple(axes)
      shape = tuple(len(ax) for ax in axes)
      srcdata = var.getArray(copy=False) # don't make extra copy
      # roll x & y axes to the front (xlon first, then ylat, then the rest)
      srcdata = np.rollaxis(srcdata, axis=var.axisIndex(ylat.name), start=0)
      srcdata = np.rollaxis(srcdata, axis=var.axisIndex(xlon.name), start=0)
      assert srcdata.shape == (len(xlon),len(ylat))+shape[1:]
      # here we extract the data points
      if srcdata.ndim == 2:
        tgtdata = srcdata[ixlon,iylat] # constructed above
      elif srcdata.ndim > 2:
        tgtdata = srcdata[ixlon,iylat,:] # constructed above
      else: raise AxisError
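      # N.B.: this is pairwise fancy indexing: with ixlon=[0,2] and iylat=[1,3] it picks the two
      #       elements (0,1) and (2,3), i.e. one grid point per station; a runnable toy example:
      #       np.arange(16).reshape(4,4)[[0,2],[1,3]] yields array([ 1, 11])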
      #try: except: print srcdata.shape, [slc.max() for slc in slices] 
      # create new Variable
      assert shape == tgtdata.shape
      newvar = var.copy(axes=axes, data=tgtdata) # new axes and data
      del srcdata, tgtdata # clean up (just to make sure)      
    else:
      var.load() # need to load variables into memory, because we are not doing anything else...
      newvar = var # just pass over the variable to the new dataset
    # return variable
    return newvar
    
  # function pair to regrid data to a different grid
  def Regrid(self, griddef=None, projection=None, geotransform=None, size=None, xlon=None, ylat=None, 
             lmask=True, int_interp=None, float_interp=None, **kwargs):
    ''' Set up regridding and start computation; calls processRegrid. '''
    # make temporary gdal dataset
    if self.source is self.target:
      if self.tmp: assert self.source == self.tmpput and self.target == self.tmpput
      # the operation can not be performed "in-place"!
      self.target = Dataset(name='tmptoo', title='Temporary target dataset for non-in-place operations', varlist=[], atts={})
      ltmptoo = True
    else: ltmptoo = False 
    # make sure the target dataset is a GDAL-enabled dataset
    if 'gdal' in self.target.__dict__: 
      # gdal info already present
      if griddef is not None or projection is not None or geotransform is not None: 
        raise AttributeError, "Target Dataset '%s' is already GDAL enabled - cannot overwrite settings!"%self.target.name
      if self.target.xlon is None: raise GDALError, "Map axis 'xlon' not found!"
      if self.target.ylat is None: raise GDALError, "Map axis 'ylat' not found!"
      xlon = self.target.xlon; ylat = self.target.ylat
    else:
      # need to set GDAL parameters
      if self.tmp and 'gdal' in self.output.__dict__:
        # transfer gdal settings from output to temporary dataset 
        assert self.target is not self.output 
        projection = self.output.projection; geotransform = self.output.geotransform
        xlon = self.output.xlon; ylat = self.output.ylat
      else:
        # figure out grid definition from input 
        if griddef is None: 
          griddef = GridDefinition(projection=projection, geotransform=geotransform, size=size, xlon=xlon, ylat=ylat)
        # pass arguments through GridDefinition, if not provided
        projection=griddef.projection; geotransform=griddef.geotransform
        xlon=griddef.xlon; ylat=griddef.ylat                     
      # apply GDAL settings to the target dataset 
      for ax in (xlon,ylat): self.target.addAxis(ax, loverwrite=True) # i.e. replace if already present
      self.target = addGDALtoDataset(self.target, projection=projection, geotransform=geotransform)
    # use these map axes
    xlon = self.target.xlon; ylat = self.target.ylat
    assert isinstance(xlon,Axis) and isinstance(ylat,Axis)
    # determine source dataset grid definition
    if self.source.griddef is None:  
      srcgrd = GridDefinition(projection=self.source.projection, geotransform=self.source.geotransform, 
                              size=self.source.mapSize, xlon=self.source.xlon, ylat=self.source.ylat)
    else: srcgrd = self.source.griddef
    srcres = srcgrd.scale; tgtres = griddef.scale
    # determine if shift is necessary to ensure correct wrapping
    if not srcgrd.isProjected and not griddef.isProjected:
      lwrapSrc = srcgrd.wrap360
      lwrapTgt = griddef.wrap360
      # check grids
      for grd in (srcgrd,griddef):
        if grd.wrap360:            
          assert grd.geotransform[0] + grd.geotransform[1]*(len(grd.xlon)-1) > 180        
          assert np.round(grd.geotransform[1]*len(grd.xlon), decimals=2) == 360 # require 360 deg. to some accuracy... 
          assert any( grd.xlon.getArray() > 180 ) # need to wrap around
          assert all( grd.xlon.getArray() >= 0 )
          assert all( grd.xlon.getArray() <= 360 )
        else:
          assert grd.geotransform[0] + grd.geotransform[1]*(len(grd.xlon)-1) < 180
          assert all( grd.xlon.getArray() >= -180 )
          assert all( grd.xlon.getArray() <= 180 )  
    else: 
      lwrapSrc = False # no need to shift, if a projected grid is involved!
      lwrapTgt = False # no need to shift, if a projected grid is involved!
    # determine GDAL interpolation
    if int_interp is None: int_interp = gdalInterp('nearest')
    else: int_interp = gdalInterp(int_interp)
    if float_interp is None:
      if srcres < tgtres: float_interp = gdalInterp('convolution') # down-sampling: 'convolution'
      else: float_interp = gdalInterp('cubicspline') # up-sampling
    else: float_interp = gdalInterp(float_interp)      
    # prepare function call    
    function = functools.partial(self.processRegrid, ylat=ylat, xlon=xlon, lwrapSrc=lwrapSrc, lwrapTgt=lwrapTgt, # already set parameters
                                 lmask=lmask, int_interp=int_interp, float_interp=float_interp)
    # start process
    if self.feedback: print('\n   +++   processing regridding   +++   ') 
    self.process(function, **kwargs) # currently 'flush' is the only kwarg
    # now make sure we have a GDAL dataset!
    self.target = addGDALtoDataset(self.target, griddef=griddef)
    if self.feedback: print('\n')
    if self.tmp: self.tmpput = self.target
    if ltmptoo: assert self.tmpput.name == 'tmptoo' # set above, when temp. dataset is created    
  # the previous method sets up the process, the next method performs the computation
  def processRegrid(self, var, ylat=None, xlon=None, lwrapSrc=False, lwrapTgt=False, lmask=True, int_interp=None, float_interp=None):
    ''' Regrid a variable to the target grid. '''
    # process gdal variables
    if var.gdal:
      if self.feedback: print('\n'+var.name),
      # replace axes
      axes = list(var.axes)
      axes[var.axisIndex(var.ylat)] = ylat
      axes[var.axisIndex(var.xlon)] = xlon
      # create new Variable
      var.load() # most robust way to determine the dtype (and we need the data later anyway)
      newvar = var.copy(axes=axes, data=None, projection=self.target.projection) # and, of course, load new data
      # if necessary, shift array back, to ensure proper wrapping of coordinates
      # prepare regridding
      # get GDAL dataset instances
      srcdata = var.getGDAL(load=True, wrap360=lwrapSrc)
      tgtdata = newvar.getGDAL(load=False, wrap360=lwrapTgt, allocate=True, fillValue=var.fillValue)
      # determine GDAL interpolation
      if 'gdal_interp' in var.__dict__: gdal_interp = var.gdal_interp
      elif 'gdal_interp' in var.atts: gdal_interp = var.atts['gdal_interp'] 
      else: # use default based on variable type
        if np.issubdtype(var.dtype, np.integer): gdal_interp = int_interp # can't process logicals anyway...
        else: gdal_interp = float_interp                          
      # perform regridding
      err = gdal.ReprojectImage(srcdata, tgtdata, var.projection.ExportToWkt(), newvar.projection.ExportToWkt(), gdal_interp)
      #print srcdata.ReadAsArray().std(), tgtdata.ReadAsArray().std()
      #print var.projection.ExportToWkt()
      #print newvar.projection.ExportToWkt()
      del srcdata # clean up (just to make sure)
      # N.B.: the target array should be allocated and prefilled with missing values, otherwise ReprojectImage
      #       will just fill missing values with zeros!  
      if err != 0: raise GDALError, 'ERROR CODE %i'%err
      #tgtdata.FlushCash()  
      # load data into new variable
      newvar.loadGDAL(tgtdata, mask=lmask, wrap360=lwrapTgt, fillValue=var.fillValue)      
      del tgtdata # clean up (just to make sure)
    else:
      var.load() # need to load variables into memory, because we are not doing anything else...
      newvar = var # just pass over the variable to the new dataset
    # return variable
    return newvar
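  # N.B.: a hedged usage sketch ('new_grid' is assumed to be a GridDefinition created elsewhere):
  #   cpu.Regrid(griddef=new_grid, lmask=True, flush=True)
  #   integer fields default to 'nearest' interpolation; floats use 'convolution' for
  #   down-sampling and 'cubicspline' for up-sampling (see int_interp/float_interp above)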
  
  # function pair to compute a climatology from a time-series      
  def Climatology(self, timeAxis='time', climAxis=None, period=None, offset=0, shift=0, timeSlice=None, **kwargs):
    ''' Setup climatology and start computation; calls processClimatology. '''
    if period is not None and not isinstance(period,(np.integer,int)): raise TypeError # period in years
    if not isinstance(offset,(np.integer,int)): raise TypeError # offset in years (from start of record)
    if not isinstance(shift,(np.integer,int)): raise TypeError # shift in month (if first month is not January)
    # construct new time axis for climatology
    if climAxis is None:        
      climAxis = Axis(name=timeAxis, units='month', length=12, coord=np.arange(1,13,1), dtype=dtype_int) # monthly climatology
    else: 
      if not isinstance(climAxis,Axis): raise TypeError
    # add axis to output dataset    
    if self.target.hasAxis(climAxis.name): 
      self.target.replaceAxis(climAxis, check=False) # will have different shape
    else: 
      self.target.addAxis(climAxis, copy=True) # copy=True allows recasting as, e.g., a NC variable
    climAxis = self.target.axes[timeAxis] # make sure we have exactly that instance
    # figure out time slice
    if period is not None:
      start = offset * len(climAxis); end = start + period * len(climAxis)
      timeSlice = slice(start,end,None)
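      # e.g. with a monthly climAxis (length 12), offset=2 and period=30 this selects
      # timeSlice = slice(24, 384), i.e. years 3 through 32 of the record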
    else: 
      if not isinstance(timeSlice,slice): raise TypeError, timeSlice
    # add variables that will cause errors to ignorelist (e.g. strings)
    for varname,var in self.source.variables.iteritems():
      if var.hasAxis(timeAxis) and var.dtype.kind == 'S': self.ignorelist.append(varname)
    # prepare function call
    function = functools.partial(self.processClimatology, # already set parameters
                                 timeAxis=timeAxis, climAxis=climAxis, timeSlice=timeSlice, shift=shift)
    # start process
    if self.feedback: print('\n   +++   processing climatology   +++   ')     
    if self.source.gdal: griddef = self.source.griddef
    else: griddef = None 
    self.process(function, **kwargs) # currently 'flush' is the only kwarg    
    # add GDAL to target
    if griddef is not None:
      self.target = addGDALtoDataset(self.target, griddef=griddef)
    # N.B.: if the dataset is empty, it won't do anything, hence we do it now
    if self.feedback: print('\n')    
  # the previous method sets up the process, the next method performs the computation
  def processClimatology(self, var, timeAxis='time', climAxis=None, timeSlice=None, shift=0):
    ''' Compute a climatology from a variable time-series. '''
    # process variable that have a time axis
    if var.hasAxis(timeAxis):
      if self.feedback: print('\n'+var.name),
      # prepare averaging
      tidx = var.axisIndex(timeAxis)
      interval = len(climAxis)
      newshape = list(var.shape)
      newshape[tidx] = interval # shape of the climatology field  
      if not (interval == 12): raise NotImplementedError
      # load data
      if timeSlice is not None:
        idx = tuple([timeSlice if ax.name == timeAxis else slice(None) for ax in var.axes])
      else: idx = None
      dataarray = var.getArray(idx=idx, unmask=False, copy=False)    
      if var.masked: avgdata = ma.zeros(newshape, dtype=var.dtype) # allocate array
      else: avgdata = np.zeros(newshape, dtype=var.dtype) # allocate array    
      # average data
      timelength = dataarray.shape[tidx]
      if timelength % interval == 0:
        # use array indexing
        climelts = np.arange(interval, dtype=dtype_int)
        for t in xrange(0,timelength,interval):
          if self.feedback: print('.'), # t/interval+1
          avgdata += dataarray.take(t+climelts, axis=tidx)
        del dataarray # clean up
        # normalize
        avgdata /= (timelength/interval) 
      else: 
        # simple indexing
        climcnt = np.zeros(interval, dtype=dtype_int)
        for t in xrange(timelength):
          if self.feedback and t%interval == 0: print('.'), # t/interval+1
          idx = int(t%interval)
          climcnt[idx] += 1
          if dataarray.ndim == 1:
            avgdata[idx] = avgdata[idx] + dataarray[t]
          else: 
            avgdata[idx,:] = avgdata[idx,:] + dataarray[t,:]
        del dataarray # clean up
        # normalize
        for i in xrange(interval):
          if avgdata.ndim == 1:
            if climcnt[i] > 0: avgdata[i] /= climcnt[i]
            else: avgdata[i] = 0 if np.issubdtype(var.dtype, np.integer) else np.NaN
          else:
            if climcnt[i] > 0: avgdata[i,:] /= climcnt[i]
            else: avgdata[i,:] = 0 if np.issubdtype(var.dtype, np.integer) else np.NaN
      # shift data (if first month was not January)
      if shift != 0: avgdata = np.roll(avgdata, shift, axis=tidx)
      # create new Variable
      axes = tuple([climAxis if ax.name == timeAxis else ax for ax in var.axes]) # exchange time axis
      newvar = var.copy(axes=axes, data=avgdata, dtype=var.dtype) # and, of course, load new data
      del avgdata # clean up - just to make sure
      #     print newvar.name, newvar.masked
      #     print newvar.fillValue
      #     print newvar.data_array.__class__
    else:
      var.load() # need to load variables into memory, because we are not doing anything else...
      newvar = var.copy()
    # return variable
    return newvar
  
  def Shift(self, shift=0, axis=None, byteShift=False, **kwargs):
    ''' Method to initialize shift along a coordinate axis. '''
    # kwarg input
    if shift == 0 and axis is None:
      for key,value in kwargs.iteritems():
        if self.target.hasAxis(key) or self.input.hasAxis(key):
          if axis is None: axis = key; shift = value
          else: raise ProcessError, "Can only process one coordinate shift at a time."
      del kwargs[axis] # remove entry 
    # check input
    if isinstance(axis,basestring):
      if self.target.hasAxis(axis): axis = self.target.axes[axis]
      elif self.input.hasAxis(axis): axis = self.input.axes[axis].copy()
      else: raise AxisError, "Axis '%s' not found in Dataset."%axis
    else: 
      if not isinstance(axis,Axis): raise TypeError
    # apply shift to new axis
    if byteShift:
      # shift coordinate vector like data
      coord = np.roll(axis.getArray(unmask=False), shift) # 1-D      
    else:              
      coord = axis.getArray(unmask=False) + shift # shift coordinates
      # transform coordinate shifts into index shifts (linear scaling)
      shift = int( shift / (axis[1] - axis[0]) )    
    axis.coord = coord
    # add axis to output dataset      
    if self.target.hasAxis(axis, strict=True): pass
    elif self.target.hasAxis(axis.name): self.target.replaceAxis(axis)
    else: self.target.addAxis(axis, copy=True) # copy=True allows recasting as, e.g., a NC variable
    axis = self.target.axes[axis.name] # make sure we have the right version!
    # prepare function call
    function = functools.partial(self.processShift, # already set parameters
                                 shift=shift, axis=axis)
    # start process
    if self.feedback: print('\n   +++   processing shift/roll   +++   ')     
    self.process(function, **kwargs) # currently 'flush' is the only kwarg    
    if self.feedback: print('\n')
  # the previous method sets up the process, the next method performs the computation
  def processShift(self, var, shift=None, axis=None):
    ''' Method that shifts a data array along a given axis. '''
    # only process variables that have the specified axis
    if var.hasAxis(axis.name):
      if self.feedback: print('\n'+var.name), # put line break before the text, instead of after
      # shift data array
      newdata = np.roll(var.getArray(unmask=False), shift, axis=var.axisIndex(axis))
      # create new Variable
      axes = tuple([axis if ax.name == axis.name else ax for ax in var.axes]) # replace axis with shifted version
      newvar = var.copy(axes=axes, data=newdata) # and, of course, load new data
      var.unload(); del var, newdata
    else:
      var.load() # need to load variables into memory, because we are not doing anything else...
      newvar = var  
      var.unload(); del var
    # return variable
    return newvar
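A hedged end-to-end sketch of how this class is typically driven (file names and arguments are illustrative; the worker functions further below follow the same pattern):

# open a GDAL-enabled source for reading and a NetCDF sink for writing (names are illustrative)
source = DatasetNetCDF(filelist=['input_monthly.nc'], mode='r')
sink = DatasetNetCDF(filelist=['output_clim.nc'], atts=source.atts.copy(), mode='w')
CPU = CentralProcessingUnit(source, sink, varlist=None, tmp=True, feedback=True)
CPU.Shift(time=-6)                    # e.g. roll the time coordinate by half a year
CPU.Climatology(period=30, offset=0)  # compute a monthly climatology over 30 years
CPU.sync(flush=True)                  # transfer results from temporary storage to the sink
sink.sync(); sink.close()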
Example No. 12
0
 def ShapeAverage(self, shape_dict=None, shape_name=None, shpax=None, xlon=None, ylat=None, 
                  memory=500, **kwargs):
   ''' Average over a limited area of a gridded dataset; calls processShapeAverage. 
       A dictionary of NamedShape objects is expected to define the averaging areas. 
       'memory' controls the garbage collection interval and approximately corresponds 
       to MB of temporary storage (it does not include loading the variable into RAM, though). '''
   if not self.source.gdal: raise DatasetError, "Source dataset must be GDAL enabled! {:s} is not.".format(self.source.name)
   if not isinstance(shape_dict,OrderedDict): raise TypeError
   if not all(isinstance(shape,NamedShape) for shape in shape_dict.itervalues()): raise TypeError
   # make temporary dataset
   if self.source is self.target:
     if self.tmp: assert self.source == self.tmpput and self.target == self.tmpput
     # the operation can not be performed "in-place"!
     self.target = Dataset(name='tmptoo', title='Temporary target dataset for non-in-place operations', varlist=[], atts={})
     ltmptoo = True
   else: ltmptoo = False
   src = self.source; tgt = self.target # short-cuts 
   # determine source dataset grid definition
   if src.griddef is None:  
     srcgrd = GridDefinition(projection=self.source.projection, geotransform=self.source.geotransform, 
                             size=self.source.mapSize, xlon=self.source.xlon, ylat=self.source.ylat)
   else: srcgrd = src.griddef
   # figure out horizontal axes (will be replaced with the shape axis)
   if isinstance(xlon,Axis): 
     if not src.hasAxis(xlon, check=True): raise DatasetError
   elif isinstance(xlon,basestring): xlon = src.getAxis(xlon)
   else: xlon = src.x if srcgrd.isProjected else src.lon
   if isinstance(ylat,Axis):
     if not src.hasAxis(ylat, check=True): raise DatasetError
   elif isinstance(ylat,basestring): ylat = src.getAxis(ylat)
   else: ylat = src.y if srcgrd.isProjected else src.lat
   # check/create shapes axis
   if shpax: # not in source dataset!
     # if shape axis supplied
     if src.hasAxis(shpax, check=True): raise DatasetError, "Source dataset must not have a 'shape' axis!"
     if len(shpax) != len(shape_dict): raise AxisError
   else:
     # create shape axis, if not supplied
     shpatts = dict(name='shape', long_name='Ordinal Number of Shape', units='#')
     shpax = Axis(coord=np.arange(1,len(shape_dict)+1), atts=shpatts) # starting at 1
   assert isinstance(xlon,Axis) and isinstance(ylat,Axis) and isinstance(shpax,Axis)
   # prepare target dataset
   # N.B.: attributes should already be set in target dataset (by caller module)
   #       we are also assuming the new dataset has no axes yet
   assert len(tgt.axes) == 0
   # add shape axis
   tgt.addAxis(shpax, asNC=True, copy=True) # already new copy
   # add axes from source data
   for axname,ax in src.axes.iteritems():
     if axname not in (xlon.name,ylat.name):
       tgt.addAxis(ax, asNC=True, copy=True)
   # add shape names
   shape_names = [shape.name for shape in shape_dict.itervalues()] # can construct Variable from list!
   atts = dict(name='shape_name', long_name='Name of Shape', units='')
   tgt.addVariable(Variable(data=shape_names, axes=(shpax,), atts=atts), asNC=True, copy=True)
   # add proper names
   shape_long_names = [shape.long_name for shape in shape_dict.itervalues()] # can construct Variable from list!
   atts = dict(name='shp_long_name', long_name='Proper Name of Shape', units='')
   tgt.addVariable(Variable(data=shape_long_names, axes=(shpax,), atts=atts), asNC=True, copy=True)    
   # add shape category
   shape_type = [shape.shapetype for shape in shape_dict.itervalues()] # can construct Variable from list!
   atts = dict(name='shp_type', long_name='Type of Shape', units='')
   tgt.addVariable(Variable(data=shape_type, axes=(shpax,), atts=atts), asNC=True, copy=True)    
   # collect rasterized masks from shape files 
   mask_array = np.zeros((len(shpax),)+srcgrd.size[::-1], dtype=np.bool) 
   # N.B.: rasterize() returns mask in (y,x) shape, size is ordered as (x,y)
   shape_masks = []; shp_full = []; shp_empty = []; shp_encl = []
   for i,shape in enumerate(shape_dict.itervalues()):
     mask = shape.rasterize(griddef=srcgrd, asVar=False)
     mask_array[i,:] = mask
     masksum = mask.sum() 
     lfull = masksum == 0; shp_full.append( lfull )
     lempty = masksum == mask.size; shp_empty.append( lempty )
     shape_masks.append( mask if not lempty else None )
     if lempty: shp_encl.append( False )
     else:
       shp_encl.append( np.all( mask[[0,-1],:] == True ) and np.all( mask[:,[0,-1]] == True ) )
       # i.e. if boundaries are masked
   # N.B.: shapes that have no overlap with grid will be skipped and filled with NaN
   # add rasterized masks to new dataset
   atts = dict(name='shp_mask', long_name='Rasterized Shape Mask', units='')
   tgt.addVariable(Variable(data=mask_array, atts=atts, axes=(shpax,srcgrd.ylat.copy(),srcgrd.xlon.copy())), 
                   asNC=True, copy=True)
   # add area enclosed by shape
   da = srcgrd.geotransform[1]*srcgrd.geotransform[5]
   mask_area = (1-mask_array).mean(axis=2).mean(axis=1)*da
   atts = dict(name='shp_area', long_name='Area Contained in the Shape', 
               units= 'm^2' if srcgrd.isProjected else 'deg^2' )
   tgt.addVariable(Variable(data=mask_area, axes=(shpax,), atts=atts), asNC=True, copy=True)
   # add flag to indicate if shape is fully enclosed by domain
   atts = dict(name='shp_encl', long_name='If Shape is fully included in Domain', units= '')
   tgt.addVariable(Variable(data=shp_encl, axes=(shpax,), atts=atts), asNC=True, copy=True)
   # add flag to indicate if shape fully covers domain
   atts = dict(name='shp_full', long_name='If Shape fully covers Domain', units= '')
   tgt.addVariable(Variable(data=shp_full, axes=(shpax,), atts=atts), asNC=True, copy=True)
   # add flag to indicate if shape and domain have no overlap
   atts = dict(name='shp_empty', long_name='If Shape and Domain have no Overlap', units= '')
   tgt.addVariable(Variable(data=shp_empty, axes=(shpax,), atts=atts), asNC=True, copy=True)
   # save all the meta data
   tgt.sync()
   # prepare function call    
   function = functools.partial(self.processShapeAverage, masks=shape_masks, ylat=ylat, xlon=xlon, 
                                shpax=shpax, memory=memory) # already set parameters
   # start process
   if self.feedback: print('\n   +++   processing shape/area averaging   +++   ') 
   self.process(function, **kwargs) # currently 'flush' is the only kwarg
   if self.feedback: print('\n')
   if self.tmp: self.tmpput = self.target
   if ltmptoo: assert self.tmpput.name == 'tmptoo' # set above, when temp. dataset is created    
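The mask bookkeeping above (lfull, lempty, shp_encl) can be illustrated with a small, self-contained example; the toy mask below is purely illustrative (True marks grid cells outside the shape, as returned by rasterize()):

import numpy as np
mask = np.array([[True, True,  True],
                 [True, False, True],
                 [True, True,  True]]) # only the centre cell lies inside the shape
masksum = mask.sum()                   # 8 masked cells
lfull  = (masksum == 0)                # False: the shape does not cover the whole domain
lempty = (masksum == mask.size)        # False: shape and domain do overlap
# all boundary rows/columns are masked, so the shape is fully enclosed by the domain
lencl  = np.all(mask[[0,-1],:]) and np.all(mask[:,[0,-1]])  # True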
Example No. 13
0
        # imports
        import os
        import numpy as np
        from glob import glob
        from geodata.base import Dataset, Axis, Variable
        from geodata.netcdf import writeNetCDF

        # load list of well files and generate list of wells
        well_files = glob(os.path.join(data_folder, 'W*.xlsx'))
        well_files.sort()
        wells = [os.path.basename(name[:-5]) for name in well_files]
        print(wells)

        # dataset
        time_ax = Axis(coord=np.arange(12 * (period[1] - period[0])) + 252,
                       **varatts['time'])  # origin: 1979-01
        well_ax = Axis(coord=np.arange(len(wells)) + 1, name='well', units='')
        dataset = Dataset(name=conservation_authority,
                          title=conservation_authority + ' Observation Wells')
        # add meta data
        meta_dicts = [
            loadMetadata(well, conservation_authority=conservation_authority)
            for well in wells
        ]
        for key in meta_dicts[0].keys():
            if key in varatts: atts = varatts[key]
            elif key.lower() in varatts: atts = varatts[key.lower()]
            else: atts = dict(name=key, units='')
            if atts['units']:
                data = np.asarray([wmd[key] for wmd in meta_dicts],
                                  dtype=np.float64)
            else:
                data = np.asarray([wmd[key] for wmd in meta_dicts])
            try:
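The snippet above is cut off inside the meta-data loop; what follows is only a hedged sketch of how such per-well meta data could be attached to the dataset, using the classes imported above ('data', 'atts', 'well_ax' and 'dataset' are the objects from the snippet; this is an assumption, not the original continuation):

# hedged sketch (not the original code): wrap the collected meta data in a Variable
# defined along the well axis and add it to the dataset
var = Variable(data=data, axes=(well_ax,), atts=atts)
dataset.addVariable(var, copy=True)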
Example No. 14
0
def performRegridding(dataset,
                      mode,
                      griddef,
                      dataargs,
                      loverwrite=False,
                      varlist=None,
                      lwrite=True,
                      lreturn=False,
                      ldebug=False,
                      lparallel=False,
                      pidstr='',
                      logger=None):
    ''' worker function to perform regridding for a given dataset and target grid '''
    # input checking
    if not isinstance(dataset, basestring): raise TypeError
    if not isinstance(dataargs, dict):
        raise TypeError  # all dataset arguments are kwargs
    if not isinstance(griddef, GridDefinition): raise TypeError
    if lparallel:
        if not lwrite:
            raise IOError, 'In parallel mode we can only write to disk (i.e. lwrite = True).'
        if lreturn:
            raise IOError, 'Can not return datasets in parallel mode (i.e. lreturn = False).'

    # logging
    if logger is None:  # make new logger
        logger = logging.getLogger()  # new logger
        logger.addHandler(logging.StreamHandler())
    else:
        if isinstance(logger, basestring):
            logger = logging.getLogger(name=logger)  # connect to existing one
        elif not isinstance(logger, logging.Logger):
            raise TypeError, 'Expected logger ID/handle in logger KW; got {}'.format(
                str(logger))

    ## extract meta data from arguments
    dataargs, loadfct, srcage, datamsgstr = getMetaData(
        dataset, mode, dataargs)
    dataset_name = dataargs.dataset_name
    periodstr = dataargs.periodstr
    avgfolder = dataargs.avgfolder

    # get filename for target dataset and do some checks
    filename = getTargetFile(
        dataset=dataset,
        mode=mode,
        dataargs=dataargs,
        lwrite=lwrite,
        grid=griddef.name.lower(),
    )

    # prepare target dataset
    if ldebug: filename = 'test_' + filename
    if not os.path.exists(avgfolder):
        raise IOError, "Dataset folder '{:s}' does not exist!".format(
            avgfolder)
    lskip = False  # else just go ahead
    if lwrite:
        if lreturn:
            tmpfilename = filename  # no temporary file if dataset is passed on (can't rename the file while it is open!)
        else:
            if lparallel: tmppfx = 'tmp_regrid_{:s}_'.format(pidstr[1:-1])
            else: tmppfx = 'tmp_regrid_'
            tmpfilename = tmppfx + filename
        filepath = avgfolder + filename
        tmpfilepath = avgfolder + tmpfilename
        if os.path.exists(filepath):
            if not loverwrite:
                age = datetime.fromtimestamp(os.path.getmtime(filepath))
                # if source file is newer than sink file or if sink file is a stub, recompute, otherwise skip
                if age > srcage and os.path.getsize(filepath) > 1e6:
                    lskip = True
                    if hasattr(griddef,
                               'filepath') and griddef.filepath is not None:
                        gridage = datetime.fromtimestamp(
                            os.path.getmtime(griddef.filepath))
                        if age < gridage: lskip = False
                # N.B.: NetCDF files smaller than 1MB are usually incomplete header fragments from a previous crash

    # depending on last modification time of file or overwrite setting, start computation, or skip
    if lskip:
        # print message
        skipmsg = "\n{:s}   >>>   Skipping: file '{:s}' in dataset '{:s}' already exists and is newer than source file.".format(
            pidstr, filename, dataset_name)
        skipmsg += "\n{:s}   >>>   ('{:s}')\n".format(pidstr, filepath)
        logger.info(skipmsg)
    else:

        ## actually load datasets
        source = loadfct()  # load source
        # check period
        if 'period' in source.atts and dataargs.periodstr != source.atts.period:  # a NetCDF attribute
            raise DateError, "Specifed period is inconsistent with netcdf records: '{:s}' != '{:s}'".format(
                periodstr, source.atts.period)

        # print message
        if mode == 'climatology':
            opmsgstr = 'Regridding Climatology ({:s}) to {:s} Grid'.format(
                periodstr, griddef.name)
        elif mode == 'time-series':
            opmsgstr = 'Regridding Time-series to {:s} Grid'.format(
                griddef.name)
        else:
            raise NotImplementedError, "Unrecognized Mode: '{:s}'".format(mode)
        # print feedback to logger
        logger.info(
            '\n{0:s}   ***   {1:^65s}   ***   \n{0:s}   ***   {2:^65s}   ***   \n'
            .format(pidstr, datamsgstr, opmsgstr))
        if not lparallel and ldebug: logger.info('\n' + str(source) + '\n')

        ## create new sink/target file
        # set attributes
        atts = source.atts.copy()
        atts['period'] = periodstr
        atts['name'] = dataset_name
        atts['grid'] = griddef.name
        if mode == 'climatology':
            atts['title'] = '{:s} Climatology on {:s} Grid'.format(
                dataset_name, griddef.name)
        elif mode == 'time-series':
            atts['title'] = '{:s} Time-series on {:s} Grid'.format(
                dataset_name, griddef.name)

        # make new dataset
        if lwrite:  # write to NetCDF file
            if os.path.exists(tmpfilepath):
                os.remove(tmpfilepath)  # remove old temp files
            sink = DatasetNetCDF(folder=avgfolder,
                                 filelist=[tmpfilename],
                                 atts=atts,
                                 mode='w')
        else:
            sink = Dataset(atts=atts)  # only create dataset in memory

        # initialize processing
        CPU = CentralProcessingUnit(source,
                                    sink,
                                    varlist=varlist,
                                    tmp=False,
                                    feedback=ldebug)

        # perform regridding (if target grid is different from native grid!)
        if griddef.name != dataset:
            # reproject and resample (regrid) dataset
            CPU.Regrid(griddef=griddef, flush=True)

        # get results
        CPU.sync(flush=True)

        # add geolocators
        sink = addGeoLocator(sink,
                             griddef=griddef,
                             lgdal=True,
                             lreplace=True,
                             lcheck=True)
        # N.B.: WRF datasets come with their own geolocator arrays - we need to replace those!

        # add length and names of month
        if mode == 'climatology' and not sink.hasVariable(
                'length_of_month') and sink.hasVariable('time'):
            addLengthAndNamesOfMonth(
                sink,
                noleap=True if dataset.upper() in ('WRF', 'CESM') else False)

        # print dataset
        if not lparallel and ldebug:
            logger.info('\n' + str(sink) + '\n')
        # write results to file
        if lwrite:
            sink.sync()
            writemsg = "\n{:s}   >>>   Writing to file '{:s}' in dataset {:s}".format(
                pidstr, filename, dataset_name)
            writemsg += "\n{:s}   >>>   ('{:s}')\n".format(pidstr, filepath)
            logger.info(writemsg)

            # rename file to proper name
            if not lreturn:
                sink.unload()
                sink.close()
                del sink  # destroy all references
                if os.path.exists(filepath):
                    os.remove(filepath)  # remove old file
                os.rename(
                    tmpfilepath,
                    filepath)  # this would also overwrite the old file...
            # N.B.: there is no temporary file if the dataset is returned, because an open file can't be renamed

        # clean up and return
        source.unload()
        del source, CPU
        if lreturn:
            return sink  # return dataset for further use (netcdf file still open!)
        else:
            return 0  # "exit code"
Ejemplo n.º 15
0
def performExport(dataset, mode, dataargs, expargs, bcargs, loverwrite=False, 
                  ldebug=False, lparallel=False, pidstr='', logger=None):
    ''' worker function to export ASCII rasters for a given dataset '''
    # input checking
    if not isinstance(dataset,basestring): raise TypeError
    if not isinstance(dataargs,dict): raise TypeError # all dataset arguments are kwargs 
    
    # logging
    if logger is None: # make new logger     
        logger = logging.getLogger() # new logger
        logger.addHandler(logging.StreamHandler())
    else:
        if isinstance(logger,basestring): 
            logger = logging.getLogger(name=logger) # connect to existing one
        elif not isinstance(logger,logging.Logger): 
            raise TypeError, 'Expected logger ID/handle in logger KW; got {}'.format(str(logger))
  
    ## extract meta data from arguments
    dataargs, loadfct, srcage, datamsgstr = getMetaData(dataset, mode, dataargs, lone=False)
    dataset_name = dataargs.dataset_name; periodstr = dataargs.periodstr; domain = dataargs.domain
    
    # figure out bias correction parameters
    if bcargs:
        bcargs = bcargs.copy() # first copy, then modify...
        bc_method = bcargs.pop('method',None)
        if bc_method is None: raise ArgumentError("Need to specify bias-correction method to use bias correction!")
        bc_obs = bcargs.pop('obs_dataset',None)
        if bc_obs is None: raise ArgumentError("Need to specify observational dataset to use bias correction!")
        bc_reference = bcargs.pop('reference',None)
        if bc_reference is None: # infer from experiment name
            if dataset_name[-5:] in ('-2050','-2100'): bc_reference = dataset_name[:-5] # cut off the period indicator and hope for the best
            else: bc_reference = dataset_name 
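        # e.g., for an experiment named 'X-2050', the bias-correction pickle of the historical run 'X' is used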
        bc_grid = bcargs.pop('grid',None)
        if bc_grid is None: bc_grid = dataargs.grid
        bc_domain = bcargs.pop('domain',None)
        if bc_domain is None: bc_domain = domain
        bc_varlist = bcargs.pop('varlist',None)
        bc_varmap = bcargs.pop('varmap',None)       
        bc_tag = bcargs.pop('tag',None) # an optional name extension/tag
        bc_pattern = bcargs.pop('file_pattern',None) # usually default in getPickleFile
        lgzip = bcargs.pop('lgzip',None) # if pickle is gzipped (None: auto-detect based on file name extension)
        # get name of pickle file (and folder)
        picklefolder = dataargs.avgfolder.replace(dataset_name,bc_reference)
        picklefile = getPickleFileName(method=bc_method, obs_name=bc_obs, gridstr=bc_grid, domain=bc_domain, 
                                       tag=bc_tag, pattern=bc_pattern)
        picklepath = '{:s}/{:s}'.format(picklefolder,picklefile)
        if lgzip:
            picklepath += '.gz' # add extension
            if not os.path.exists(picklepath): raise IOError(picklepath)
        elif lgzip is None:
            lgzip = False
            if not os.path.exists(picklepath):
                lgzip = True # assume gzipped file
                picklepath += '.gz' # try with extension...
                if not os.path.exists(picklepath): raise IOError(picklepath)
        elif not os.path.exists(picklepath): raise IOError(picklepath)
        pickleage = datetime.fromtimestamp(os.path.getmtime(picklepath))
        # determine age of pickle file and compare against source age
    else:
      bc_method = False 
      pickleage = srcage
    
    # parse export options
    expargs = expargs.copy() # first copy, then modify...
    lm3 = expargs.pop('lm3') # convert kg/m^2/s to m^3/m^2/s (water flux)
    expformat = expargs.pop('format') # needed to get FileFormat object
    exp_list = expargs.pop('exp_list') # this is handled outside of the export
    compute_list = expargs.pop('compute_list', []) # variables to be (re-)computed (default: none)
    # initialize FileFormat class instance
    fileFormat = getFileFormat(expformat, bc_method=bc_method, **expargs)
    # get folder for target dataset and do some checks
    expname = '{:s}_d{:02d}'.format(dataset_name,domain) if domain else dataset_name
    expfolder = fileFormat.defineDataset(dataset=dataset, mode=mode, dataargs=dataargs, lwrite=True, ldebug=ldebug)
  
    # prepare destination for new dataset
    lskip = fileFormat.prepareDestination(srcage=max(srcage,pickleage), loverwrite=loverwrite)
  
    # depending on last modification time of file or overwrite setting, start computation, or skip
    if lskip:        
        # print message
        skipmsg =  "\n{:s}   >>>   Skipping: Format '{:s} for dataset '{:s}' already exists and is newer than source file.".format(pidstr,expformat,dataset_name)
        skipmsg += "\n{:s}   >>>   ('{:s}')\n".format(pidstr,expfolder)
        logger.info(skipmsg)              
    else:
            
      ## actually load datasets
      source = loadfct() # load source data
      # check period
      if 'period' in source.atts and dataargs.periodstr != source.atts.period: # a NetCDF attribute
          raise DateError, "Specifed period is inconsistent with netcdf records: '{:s}' != '{:s}'".format(periodstr,source.atts.period)
      
      # load BiasCorrection object from pickle
      if bc_method:      
          op = gzip.open if lgzip else open
          with op(picklepath, 'rb') as filehandle:
              BC = pickle.load(filehandle) 
          # assemble logger entry
          bcmsgstr = "(performing bias-correction using {:s} from {:s} towards {:s})".format(BC.long_name,bc_reference,bc_obs)
      
      # print message
      if mode == 'climatology': opmsgstr = 'Exporting Climatology ({:s}) to {:s} Format'.format(periodstr, expformat)
      elif mode == 'time-series': opmsgstr = 'Exporting Time-series to {:s} Format'.format(expformat)
      elif mode[-5:] == '-mean': opmsgstr = 'Exporting {:s}-Mean ({:s}) to {:s} Format'.format(mode[:-5], periodstr, expformat)
      else: raise NotImplementedError, "Unrecognized Mode: '{:s}'".format(mode)        
      # print feedback to logger
      logmsg = '\n{0:s}   ***   {1:^65s}   ***   \n{0:s}   ***   {2:^65s}   ***   \n'.format(pidstr,datamsgstr,opmsgstr)
      if bc_method:
          logmsg += "{0:s}   ***   {1:^65s}   ***   \n".format(pidstr,bcmsgstr)
      logger.info(logmsg)
      if not lparallel and ldebug: logger.info('\n'+str(source)+'\n')
      
      # create GDAL-enabled target dataset
      sink = Dataset(axes=(source.xlon,source.ylat), name=expname, title=source.title, atts=source.atts.copy())
      addGDALtoDataset(dataset=sink, griddef=source.griddef)
      assert sink.gdal, sink
      
      # apply bias-correction
      if bc_method:
          source = BC.correct(source, asNC=False, varlist=bc_varlist, varmap=bc_varmap) # load bias-corrected variables into memory
        
      # N.B.: for variables that are not bias-corrected, data are not loaded immediately but on demand; this way 
      #       I/O and computing can be further disentangled and not all variables are always needed
      
      # compute intermediate variables, if necessary
      for varname in exp_list:
          variables = None # variable list
          var = None
          # (re-)compute variable, if desired...
          if varname in compute_list:
              if varname == 'precip': var = newvars.computeTotalPrecip(source)
              elif varname == 'waterflx': var = newvars.computeWaterFlux(source)
              elif varname == 'liqwatflx': var = newvars.computeLiquidWaterFlux(source)
              elif varname == 'netrad': var = newvars.computeNetRadiation(source, asVar=True)
              elif varname == 'netrad_bb': var = newvars.computeNetRadiation(source, asVar=True, lrad=False, name='netrad_bb')
              elif varname == 'netrad_bb0': var = newvars.computeNetRadiation(source, asVar=True, lrad=False, lA=False, name='netrad_bb0')
              elif varname == 'vapdef': var = newvars.computeVaporDeficit(source)
              elif varname in ('pet','pet_pm','petrad','petwnd') and 'pet' not in sink:
                  if 'petrad' in exp_list or 'petwnd' in exp_list:
                      variables = newvars.computePotEvapPM(source, lterms=True) # default; returns multiple PET terms
                  else: var = newvars.computePotEvapPM(source, lterms=False) # returns only PET
              elif varname == 'pet_th': var = None # skip for now
                  #var = computePotEvapTh(source) # simplified formula (less prerequisites)
          # ... otherwise load from source file
          if var is None and variables is None and varname in source:
              var = source[varname].load() # load data (may not have to load all)
          #else: raise VariableError, "Unsupported Variable '{:s}'.".format(varname)
          # for now, skip variables that are None
          if var or variables:
              # handle lists as well
              if var and variables: raise VariableError, (var,variables)
              elif var: variables = (var,)
              for var in variables:
                  addGDALtoVar(var=var, griddef=sink.griddef)
                  if not var.gdal and isinstance(fileFormat,ASCII_raster):
                      raise GDALError, "Exporting to ASCII_raster format requires GDAL-enabled variables."
                  # add to new dataset
                  sink += var
      # convert units
      if lm3:
          for var in sink:
              if var.units == 'kg/m^2/s':
                  var /= 1000. # divide by density of water (1000 kg/m^3) to get m^3/m^2/s
                  var.units = 'm^3/m^2/s' # update units
      
      # compute seasonal mean if we are in mean-mode
      if mode[-5:] == '-mean': 
          sink = sink.seasonalMean(season=mode[:-5], lclim=True)
          # N.B.: to remain consistent with other output modes, 
          #       we need to prevent renaming of the time axis
          sink = concatDatasets([sink,sink], axis='time', lensembleAxis=True)
          sink.squeeze() # we need the year-axis until now to distinguish constant fields; now remove
      
      # print dataset
      if not lparallel and ldebug:
          logger.info('\n'+str(sink)+'\n')
        
      # export new dataset to selected format
      fileFormat.exportDataset(sink)
        
      # write results to file
      writemsg =  "\n{:s}   >>>   Export of Dataset '{:s}' to Format '{:s}' complete.".format(pidstr,expname, expformat)
      writemsg += "\n{:s}   >>>   ('{:s}')\n".format(pidstr,expfolder)
      logger.info(writemsg)      
         
      # clean up and return
      source.unload(); #del source
      return 0 # "exit code"
Ejemplo n.º 16
0
def loadHGS_StnTS(station=None, varlist=None, varatts=None, folder=None, name=None, title=None,
                  start_date=None, end_date=None, run_period=15, period=None, lskipNaN=False, lcheckComplete=True,
                  basin=None, WSC_station=None, basin_list=None, filename=None, prefix=None, 
                  scalefactors=None, **kwargs):
  ''' Get a properly formatted HGS dataset with monthly time-series at station locations; as in
      the hgsrun module, the capitalized kwargs can be used to construct folders and/or names '''
  if folder is None or ( filename is None and station is None ): raise ArgumentError
  # try to find meta data for gage station from WSC
  HGS_station = station
  if basin is not None and basin_list is not None:
    station_name = station
    station = getGageStation(basin=basin, station=station if WSC_station is None else WSC_station, 
                             basin_list=basin_list) # only works with registered basins
    if station_name is None: station_name = station.name # backup, in case we don't have a HGS station name
    metadata = station.getMetaData() # load station meta data
    if metadata is None: raise GageStationError(name)
  else: 
    metadata = dict(); station = None; station_name =  None    
  # prepare name expansion arguments (all capitalized)
  expargs = dict(ROOT_FOLDER=root_folder, STATION=HGS_station, NAME=name, TITLE=title,
                 PREFIX=prefix, BASIN=basin, WSC_STATION=WSC_station)
  for key,value in metadata.items():
      if isinstance(value,basestring):
          expargs['WSC_'+key.upper()] = value # in particular, this includes WSC_ID
  if 'WSC_ID' in expargs: 
      if expargs['WSC_ID'][0] == '0': expargs['WSC_ID0'] = expargs['WSC_ID'][1:]
      else: raise DatasetError('Expected leading zero in WSC station ID: {}'.format(expargs['WSC_ID']))
  # exparg preset keys will get overwritten if capitalized versions are defined
  for key,value in kwargs.items():
    KEY = key.upper() # we only use capitalized keywords, and non-capitalized keywords are only used/converted
    if KEY == key or KEY not in kwargs: expargs[KEY] = value # if no capitalized version is defined
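  # e.g., a keyword argument domain=2 becomes expargs['DOMAIN'] = 2 and can be used as {DOMAIN} in folder/filename templates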
  # read folder and infer prefix, if necessary
  folder = folder.format(**expargs)
  if not os.path.exists(folder): raise IOError(folder)
  if expargs['PREFIX'] is None:
    with open('{}/{}'.format(folder,prefix_file), 'r') as pfx:
      expargs['PREFIX'] = prefix = ''.join(pfx.readlines()).strip()      
  # now assemble file name for station timeseries
  filename = filename.format(**expargs)
  filepath = '{}/{}'.format(folder,filename)
  if not os.path.exists(filepath): raise IOError(filepath)
  if station_name is None: 
      station_name = filename[filename.index('hydrograph.')+len('hydrograph.'):-4] if station is None else station
  # set meta data (and allow keyword expansion of name and title)
  metadata['problem'] = prefix
  metadata['station_name'] = metadata.get('long_name', station_name)
  if name is not None: name = name.format(**expargs) # name expansion with capitalized keyword arguments
  else: name = 'HGS_{:s}'.format(station_name)
  metadata['name'] = name; expargs['Name'] = name.title() # name in title format
  if title is None: title = '{{Name:s}} (HGS, {problem:s})'.format(**metadata)
  title = title.format(**expargs) # name expansion with capitalized keyword arguments
  metadata['long_name'] = metadata['title'] = title
  # now determine start data for date_parser
  if end_date is None: 
      if start_date and run_period: end_date = start_date + run_period 
      elif period: end_date = period[1]
      else: raise ArgumentError("Need to specify either 'start_date' & 'run_period' or 'period' to infer 'end_date'.")
  end_year,end_month,end_day = convertDate(end_date)
  if start_date is None: 
      if end_date and run_period: start_date = end_date - run_period 
      elif period: start_date = period[0]
      else: raise ArgumentError("Need to specify either 'end_date' & 'run_period' or 'period' to infer 'start_date'.")
  start_year,start_month,start_day = convertDate(start_date)
  if start_day != 1 or end_day != 1: 
    raise NotImplementedError('Currently only monthly data is supported.')
#   import functools
#   date_parser = functools.partial(date_parser, year=start_year, month=start_month, day=start_day)
#   # now load data using pandas ascii reader
#   data_frame = pd.read_table(filepath, sep='\s+', header=2, dtype=np.float64, index_col=['time'], 
#                              date_parser=date_parser, names=ascii_varlist)
#   # resample to monthly data
#   data_frame = data_frame.resample(resampling).agg(np.mean)
#       data = data_frame[flowvar].values
  # parse header
  if varlist is None: varlist = variable_list[:] # default list 
  with open(filepath, 'r') as f:
      line = f.readline(); lline = line.lower() # 1st line
      if not "hydrograph" in lline: raise GageStationError(line,filepath)
      # parse variables and determine columns
      line = f.readline(); lline = line.lower() # 2nd line
      if not "variables" in lline: raise GageStationError(line)
      variable_order = [v.strip('"').lower() for v in line[line.find('"'):].strip().split(',')]
  # figure out varlist and data columns
  if variable_order[0] == 'time': del variable_order[0] # only keep variables
  else: raise GageStationError(variable_order)
  variable_order = [hgs_variables[v] for v in variable_order] # replace HGS names with GeoPy names
  vardict = {v:i+1 for i,v in enumerate(variable_order)} # column mapping; +1 because time was removed
  variable_order = [v for v in variable_order if v in varlist or flow_to_flux[v] in varlist]
  usecols = tuple(vardict[v] for v in variable_order) # variable columns that need to be loaded (except time, which is col 0)
  assert 0 not in usecols, usecols
  # load data as tab separated values
  data = np.genfromtxt(filepath, dtype=np.float64, delimiter=None, skip_header=3, usecols = (0,)+usecols)
  assert data.shape[1] == len(usecols)+1, data.shape
  if lskipNaN:
      data = data[np.isnan(data).sum(axis=1)==0,:]
  elif np.any( np.isnan(data) ):
      raise DataError("Missing values (NaN) encountered in hydrograph file; use 'lskipNaN' to ignore.\n('{:s}')".format(filepath))    
  time_series = data[:,0]; flow_data = data[:,1:]
  assert flow_data.shape == (len(time_series),len(usecols)), flow_data.shape
  # original time deltas in seconds
  time_diff = time_series.copy(); time_diff[1:] = np.diff(time_series) # time period between time steps
  assert np.all( time_diff > 0 ), filepath
  time_diff = time_diff.reshape((len(time_diff),1)) # reshape to make sure broadcasting works
  # integrate flow over time steps before resampling
  flow_data[1:,:] -= np.diff(flow_data, axis=0)/2. # get average flow between time steps
  flow_data *= time_diff # integrate flow in time interval by multiplying average flow with time period
  flow_data = np.cumsum(flow_data, axis=0) # integrate by summing up total flow per time interval
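  # N.B.: this is the trapezoidal rule; e.g., for times [10,30] s and flows [2,4] kg/s: time_diff=[10,20],
  #       averaged flows=[2,3], interval volumes=[20,60], cumulative volumes=[20,80]; the first interval
  #       assumes constant flow from time zero up to the first record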
  # generate regular monthly time steps
  start_datetime = np.datetime64(dt.datetime(year=start_year, month=start_month, day=start_day), 'M')
  end_datetime = np.datetime64(dt.datetime(year=end_year, month=end_month, day=end_day), 'M')
  time_monthly = np.arange(start_datetime, end_datetime+np.timedelta64(1, 'M'), dtype='datetime64[M]')
  assert time_monthly[0] == start_datetime, time_monthly[0]
  assert time_monthly[-1] == end_datetime, time_monthly[-1] 
  # convert monthly time series to regular array of seconds since start date
  time_monthly = ( time_monthly.astype('datetime64[s]') - start_datetime.astype('datetime64[s]') ) / np.timedelta64(1,'s')
  assert time_monthly[0] == 0, time_monthly[0]
  # interpolate integrated flow to new time axis
  #flow_data = np.interp(time_monthly, xp=time_series[:,0], fp=flow_data[:,0],).reshape((len(time_monthly),1))
  time_series = np.concatenate(([0],time_series), axis=0) # integrated flow at time zero must be zero...
  flow_data = np.concatenate(([[0,]*len(usecols)],flow_data), axis=0) # ... this is probably better than interpolation
  # N.B.: we are adding zeros here so we don't have to extrapolate to the left; on the right we just fill in NaN's
  if ( time_monthly[-1] - time_series[-1] ) > 5*86400.: 
      msg = "Data record ends more than 5 days before end of period: {} days".format((time_monthly[-1]-time_series[-1])/86400.)
      if lcheckComplete: raise DataError(msg)
      else: warn(msg)
  elif ( time_monthly[-1] - time_series[-1] ) > 3*86400. and lcheckComplete: 
      warn("Data record ends more than 3 days before end of period: {} days".format((time_monthly[-1]-time_series[-1])/86400.))
  flow_interp = si.interp1d(x=time_series, y=flow_data, kind='linear', axis=0, copy=False, 
                            bounds_error=False, fill_value=np.NaN, assume_sorted=True) 
  flow_data = flow_interp(time_monthly) # evaluate with call
  # compute monthly flow rate from interpolated integrated flow
  flow_data = np.diff(flow_data, axis=0) / np.diff(time_monthly, axis=0).reshape((len(time_monthly)-1,1))
  flow_data *= 1000 # convert from m^3/s to kg/s (density of water: 1000 kg/m^3)
  # construct time axis
  start_time = 12*(start_year - 1979) + start_month -1
  end_time = 12*(end_year - 1979) + end_month -1
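  # e.g., start_date 1980-01 and end_date 1981-01 give start_time=12 and end_time=24, i.e. a 12-month time axis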
  time = Axis(name='time', units='month', atts=dict(long_name='Month since 1979-01'), 
              coord=np.arange(start_time, end_time)) # not including the last month, e.g. 1979-01 to 1980-01 is 12 months
  assert len(time_monthly) == end_time-start_time+1
  assert flow_data.shape == (len(time),len(variable_order)), (flow_data.shape,len(time),len(variable_order))
  # construct dataset
  dataset = Dataset(atts=metadata)
  dataset.station = station # add gage station object, if available (else None)
  for i,flowvar in enumerate(variable_order):
      data = flow_data[:,i]
      fluxvar = flow_to_flux[flowvar]
      if flowvar in varlist:
        flowatts = variable_attributes[flowvar]
        # convert variables and put into dataset (monthly time series)
        if flowatts['units'] != 'kg/s': 
          raise VariableError("Hydrograph data is read as kg/s; flow variable does not match.\n{}".format(flowatts))
        dataset += Variable(data=data, axes=(time,), **flowatts)
      if fluxvar in varlist and 'shp_area' in metadata:
        # compute surface flux variable based on drainage area
        fluxatts = variable_attributes[fluxvar]
        if fluxatts['units'] != 'kg/m^2/s': raise VariableError(fluxatts)
        data = data / metadata['shp_area'] # need to make a copy
        dataset += Variable(data=data, axes=(time,), **fluxatts)
  # apply analysis period
  if period is not None:
      dataset = dataset(years=period)
  # adjust scalefactors, if necessary
  if scalefactors:
      if isinstance(scalefactors,dict):
          dataset = updateScalefactor(dataset, varlist=scalefactors, scalefactor=None)
      elif isNumber(scalefactors):
          scalelist = ('discharge','seepage','flow')
          dataset = updateScalefactor(dataset, varlist=scalelist, scalefactor=scalefactors)
      else: 
          raise TypeError(scalefactors) 
  # return completed dataset
  return dataset
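
The conversion above integrates the irregular HGS record to cumulative volume, interpolates that to regular month boundaries, and differentiates again to obtain monthly mean rates. A self-contained sketch of the same integrate-interpolate-differentiate idea on synthetic data (all names and numbers are hypothetical; regular 30 s intervals stand in for month boundaries):

import numpy as np
import scipy.interpolate as si

# synthetic irregular record: times in seconds, flow rates in kg/s (one column)
time_series = np.array([5., 20., 40., 90.])
flow = np.array([[1.], [3.], [2.], [2.]])

# trapezoidal integration to cumulative volume (kg)
dt = time_series.copy(); dt[1:] = np.diff(time_series)
avg = flow.copy(); avg[1:,:] = (flow[1:,:] + flow[:-1,:]) / 2.
volume = np.cumsum(avg * dt.reshape((-1,1)), axis=0)

# prepend zero, so no extrapolation is needed at the start
time_series = np.concatenate(([0.], time_series))
volume = np.concatenate(([[0.]], volume), axis=0)

# interpolate cumulative volume to regular target times
target = np.arange(0., 91., 30.)
interp = si.interp1d(x=time_series, y=volume, kind='linear', axis=0,
                     bounds_error=False, fill_value=np.nan, assume_sorted=True)
cumvol = interp(target)

# mean rate in each interval = volume difference divided by interval length
rates = np.diff(cumvol, axis=0) / np.diff(target).reshape((-1,1))
print(rates)  # average flow (kg/s) in each 30 s interval
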
Ejemplo n.º 17
0
def loadGageStation(basin=None, station=None, varlist=None, varatts=None, mode='climatology', 
                    aggregation=None, filetype='monthly', folder=None, name=None, period=None,
                    basin_list=None, lcheck=True, lexpand=True, lfill=True, lflatten=True,
                    lkgs=True, scalefactors=None, title=None):
  ''' function to load hydrograph climatologies and timeseries for a given basin '''
  ## resolve input
  if mode == 'timeseries' and aggregation: 
    raise ArgumentError('Timeseries does not support aggregation.')
  # get GageStation instance
  station = getGageStation(basin=basin, station=station, name=name, folder=folder, 
                           river=None, basin_list=basin_list, lcheck=True)
  # variable attributes
  if varlist is None: varlist = variable_list
  elif not isinstance(varlist,(list,tuple)): raise TypeError  
  varlist = list(varlist) # make copy of varlist to avoid interference
  if varatts is None: 
    if aggregation is None: varatts = variable_attributes_kgs if lkgs else variable_attributes_mms
    else: varatts = agg_varatts_kgs if lkgs else agg_varatts_mms
  elif not isinstance(varatts,dict): raise TypeError
  
  ## read csv data
  # time series data and time coordinates
  lexpand = True; lfill = True
  if mode == 'climatology': lexpand = False; lfill = False; lflatten = False
  data, time = station.getTimeseriesData(units='kg/s' if lkgs else 'm^3/s', lcheck=True, lexpand=lexpand, 
                                         lfill=lfill, period=period, lflatten=lflatten)
  # station meta data
  metadata = station.getMetaData(lcheck=True)
  den = metadata['shp_area'] if lkgs else ( metadata['shp_area'] / 1000. )
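  # denominator for converting discharge into runoff per unit drainage area (shp_area, divided by 1000 for m^3/s discharge)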
  ## create dataset for station
  dataset = Dataset(name='WSC', title=title or metadata['Station Name'], varlist=[], atts=metadata,) 
  if mode.lower() in ('timeseries','time-series'): 
    time = time.flatten(); data = data.flatten() # just to make sure...
    # make time axis based on time coordinate from csv file
    timeAxis = Axis(name='time', units='month', coord=time, # time series centered at 1979-01
                    atts=dict(long_name='Month since 1979-01'))
    dataset += timeAxis
    # load mean discharge
    dataset += Variable(axes=[timeAxis], data=data, atts=varatts['discharge'])
    # load mean runoff
    doa = data / den 
    dataset += Variable(axes=[timeAxis], data=doa, atts=varatts['runoff'])
  elif mode == 'climatology': 
    # N.B.: this is primarily for backwards compatibility; it should not be used anymore...
    # make common time axis for climatology
    te = 12 # length of time axis: 12 months
    climAxis = Axis(name='time', units='month', length=12, coord=np.arange(1,te+1,1)) # monthly climatology
    dataset.addAxis(climAxis, copy=False)
    # extract variables (min/max/mean are separate variables)
    # N.B.: this is mainly for backwards compatibility
    doa = data / den
    if aggregation is None or aggregation.lower() == 'mean':
      # load mean discharge
      tmpdata = nf.nanmean(data, axis=0)
      tmpvar = Variable(axes=[climAxis], data=tmpdata, atts=varatts['discharge'])
      dataset.addVariable(tmpvar, copy=False)
      # load mean runoff
      tmpdata = nf.nanmean(doa, axis=0)
      tmpvar = Variable(axes=[climAxis], data=tmpdata, atts=varatts['runoff'])
      dataset.addVariable(tmpvar, copy=False)
    if aggregation is None or aggregation.lower() == 'std':
      # load  discharge standard deviation
      tmpdata = nf.nanstd(data, axis=0, ddof=1) # very few values means large uncertainty!
      tmpvar = Variable(axes=[climAxis], data=tmpdata, atts=varatts['discstd'])
      dataset.addVariable(tmpvar, copy=False)
      # load  runoff standard deviation
      tmpdata = nf.nanstd(doa, axis=0, ddof=1)
      tmpvar = Variable(axes=[climAxis], data=tmpdata, atts=varatts['roff_std'])
      dataset.addVariable(tmpvar, copy=False)
    if aggregation is None or aggregation.lower() == 'sem':
      # load discharge standard error of the mean
      tmpdata = nf.nansem(data, axis=0, ddof=1) # very few values means large uncertainty!
      tmpvar = Variable(axes=[climAxis], data=tmpdata, atts=varatts['discsem'])
      dataset.addVariable(tmpvar, copy=False)
      # load runoff standard error of the mean
      tmpdata = nf.nansem(doa, axis=0, ddof=1)
      tmpvar = Variable(axes=[climAxis], data=tmpdata, atts=varatts['roff_sem'])
      dataset.addVariable(tmpvar, copy=False)
    if aggregation is None or aggregation.lower() == 'max':
      # load maximum discharge
      tmpdata = nf.nanmax(data, axis=0)
      tmpvar = Variable(axes=[climAxis], data=tmpdata, atts=varatts['discmax'])
      dataset.addVariable(tmpvar, copy=False)
      # load maximum runoff
      tmpdata = nf.nanmax(doa, axis=0)
      tmpvar = Variable(axes=[climAxis], data=tmpdata, atts=varatts['roff_max'])
      dataset.addVariable(tmpvar, copy=False)
    if aggregation is None or aggregation.lower() == 'min':
      # load minimum discharge
      tmpdata = nf.nanmin(data, axis=0)
      tmpvar = Variable(axes=[climAxis], data=tmpdata, atts=varatts['discmin'])
      dataset.addVariable(tmpvar, copy=False)
      # load minimum runoff
      tmpdata = nf.nanmin(doa, axis=0)
      tmpvar = Variable(axes=[climAxis], data=tmpdata, atts=varatts['roff_min'])
      dataset.addVariable(tmpvar, copy=False)
  else: 
    raise NotImplementedError, "Time axis mode '{}' is not supported.".format(mode)
  # adjust scalefactors, if necessary
  if scalefactors:
      if isinstance(scalefactors,dict):
          dataset = updateScalefactor(dataset, varlist=scalefactors, scalefactor=None)
      elif isNumber(scalefactors):
          scalelist = ('discharge','StdDisc','SEMDisc','MaxDisc','MinDisc',)
          dataset = updateScalefactor(dataset, varlist=scalelist, scalefactor=scalefactors)
      else: 
          raise TypeError(scalefactors) 
  # return station dataset
  return dataset
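
The climatology branch above reduces a (years x 12) discharge array with nan-aware statistics; the nf module presumably wraps similar nan-functions. A small sketch of equivalent aggregation with plain numpy on hypothetical data:

import numpy as np

# hypothetical monthly mean discharge (kg/s) for three years, with two missing months
data = np.array([[100., 120., 150., 200., 300., 250., 180., 160., 140., 130., 110., 105.],
                 [ 90., 110., 160., 210., 310., 240., 170., 150., 135., 125., 105., 100.],
                 [ 95., np.nan, 155., 205., 305., 245., 175., 155., np.nan, 128., 108., 102.]])

clim_mean = np.nanmean(data, axis=0)          # climatological monthly mean
clim_std  = np.nanstd(data, axis=0, ddof=1)   # inter-annual standard deviation (few values mean large uncertainty)
n_valid   = np.sum(~np.isnan(data), axis=0)   # number of valid years per month
clim_sem  = clim_std / np.sqrt(n_valid)       # standard error of the climatological mean
clim_max  = np.nanmax(data, axis=0)           # climatological maximum
clim_min  = np.nanmin(data, axis=0)           # climatological minimum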