Ejemplo n.º 1
0
def generateStatistics(varname, ens, fit, scl=None, reference=None, mode='Ratio', plot_labels=None, 
                       nsamples=None, bootstrap_axis='bootstrap', lflatten=False, sample_axis='time', 
                       lcrossval=True):
  ''' Perform goodness-of-fit / K-S tests and compare means; return the results as a formatted table.

      The returned string contains up to two tables:
        1. a "Fit" table (one row per member of 'fit'), listing the p-value returned by the member's
           'fittest' method (for VarRV members) and the ratio or difference of the sample mean
           relative to the reference member;
        2. if 'scl' is given and there is more than one sample, a "Fit to"/"Rescaled to" table with
           a two-sample K-S test of every (rescaled) sample against the reference sample and the
           ratio or difference of the rescaled means.

      Arguments:
        varname        : name of the variable that is extracted from every member of 'ens', 'fit', 'scl'
        ens            : ensemble of datasets holding the samples (must provide 'basetype')
        fit            : ensemble holding the fitted distributions; required (see Raises)
        scl            : optional ensemble holding the rescaled distributions; the attributes
                         'loc_factor', 'scale_factor' and 'shape_factor' of its members are read
        reference      : None (first member), an integer index, a member ID string, or a list/tuple
                         of those whose first element is the initial reference; the remaining list is
                         consumed in order and the reference switches whenever a row name equals the
                         next list entry
        mode           : 'Ratio' or 'Shift' (case-insensitive); any other value only produces headers
                         and 'N/A' rows, because no value column can be formatted
        plot_labels    : optional dict mapping member IDs to display names
        nsamples       : passed through to 'fittest'
        bootstrap_axis : name of an axis of which only the first element is used, if present
        lflatten       : passed through to 'ks_2samp' (rescaled table only)
        sample_axis    : axis name that is resolved to 'axis_idx' for 'ks_2samp' (rescaled table only)
        lcrossval      : passed through to 'fittest'

      Returns:
        the formatted table(s) as a single string (may be empty)

      Raises:
        TypeError           : if the sample data type is neither floating point nor integer
        ArgumentError       : if 'reference' has an unsupported type
        NotImplementedError : if 'fit' is empty/None, or if a reference sample carries both a
                              non-trivial 'scale_factor' and a non-trivial 'shape_factor'
        AxisError           : if 'scl' and 'ens' yield a different number of members
  '''

  def formatRow(namestr, name, pval, mvar, mvars, iref):
    ''' Format one table row: name, p-value and ratio/shift of 'mvar' relative to 'mvars[iref]'. '''
    if isinstance(mvar,np.ma.core.MaskedConstant) or isinstance(mvars[iref],np.ma.core.MaskedConstant):
      # NOTE(review): ' N/A\n' fills the p-value slot, so the row ends in "\n" followed by the two
      #               spaces that terminate 'namestr'; kept as-is to preserve the existing output
      return namestr.format(name,' N/A\n')
    if lratio: return (namestr+'{:3.2f}\n').format(name,pval,(mvar/mvars[iref]).mean())
    if lshift: return (namestr+'{:+2.1f}\n').format(name,pval,(mvar-mvars[iref]).mean())
    return '' # unknown mode: no value column can be formatted

  # extract samples and bring all members to the same number of dimensions
  idkey = 'dataset_name' if ens.basetype is Dataset else 'name'
  varlist = Ensemble(*[ds[varname] for ds in ens if ds is not None and varname in ds], idkey=idkey)
  if not all(varlist[0].ndim==ndim for ndim in varlist.ndim):
    new_axes = varlist[np.argmax(varlist.ndim)].axes
    varlist = varlist.insertAxes(new_axes=new_axes, lcheckAxis=False)
  mvars = varlist.mean() # sample means; basis of the ratio/shift column in the fit table
  lratio = mode.lower() == 'ratio'
  lshift = mode.lower() == 'shift'
  if plot_labels is None: plot_labels = dict()
  # sample/exp header; defined up-front, because both tables need its width (the rescaled table
  # used to fail with a NameError, when the fit table was skipped)
  headline = 'Sample'; lhead = len(headline)
  # figure out fillValue
  if np.issubdtype(varlist[0].dtype, np.floating): fillValue = np.nan # np.NaN was removed in NumPy 2.0
  elif np.issubdtype(varlist[0].dtype, np.integer): fillValue = 0
  else: raise TypeError(varlist[0].dtype)
  # define reference
  if isinstance(reference,(list,tuple)):
    reflist0 = list(reference); reference = reference[0]
  else: reflist0 = [] # dummy list
  if reference is None: iref0 = 0
  elif isinstance(reference,(int,np.integer)): iref0 = reference
  elif isinstance(reference,str): iref0 = varlist.idkeys.index(reference)
  else: raise ArgumentError
  # goodness of fit, reported on plot panels
  if fit:
    fitlist = Ensemble(*[ds[varname] for ds in fit if ds is not None and varname in ds], idkey=idkey)
    if any(fitlist.hasAxis(bootstrap_axis)): fitlist = fitlist(**{bootstrap_axis:0, 'lcheckAxis':False})
    if not all(fitlist[0].ndim==ndim for ndim in fitlist.ndim):
      new_axes = fitlist[np.argmax(fitlist.ndim)].axes
      fitlist = fitlist.insertAxes(new_axes=new_axes, lcheckAxis=False)
    assert not isinstance(reference,str) or iref0 == fitlist.idkeys.index(reference), reference
    if any(isinstance(dist,VarRV) for dist in fitlist) or not scl:
      names = [plot_labels.get(getattr(dist,idkey),getattr(dist,idkey)) for dist in fitlist]
      lnames = max([len(name) for name in names]) # allocate line space
      headline += ' '*max(lnames-lhead,0) # pad header to the width of the name column
      string = '{:s}  Fit  {:s}\n'.format(headline,mode.title())
      namestr = '{{:>{:d}s}}  {{:s}}  '.format(max(lhead,lnames))
      iref = iref0; reflist = list(reflist0) # copy list, because entries are consumed below
      for i,(dist,var,name,mvar) in enumerate(zip(fitlist,varlist,names,mvars)):
        if isinstance(dist,VarRV) or not scl:
          if isinstance(dist,VarRV):
            pval = dist.fittest(var, nsamples=nsamples, asVar=False, lcrossval=lcrossval)
            # the median collapses the test result into a single number for display
            pval = '{:3.2f}'.format(float(np.median(pval)))
          else: pval = '  - '
          # NOTE(review): 'name' is the display label here, whereas the rescaled table below matches
          #               the raw member ID against the reference list - confirm which one is intended
          if reflist and name == reflist[0]: # assign new reference
            iref = i; del reflist[0] # pop element
          string += formatRow(namestr, name, pval, mvar, mvars, iref)
    else: string = ''
  else: raise NotImplementedError
  if scl:
    scllist = Ensemble(*[ds[varname] for ds in scl if ds is not None and varname in ds], idkey=idkey)
    bs_axes = scllist.axisIndex(bootstrap_axis, lcheck=False) # return None, if not present
    if bs_axes is None: bs_axes = [None]*len(scllist)
    scllist = scllist(**{bootstrap_axis:0, 'lcheckAxis':False})
    if not all(scllist[0].ndim==ndim for ndim in scllist.ndim):
      new_axes = scllist[np.argmax(scllist.ndim)].axes
      scllist = scllist.insertAxes(new_axes=new_axes, lcheckAxis=False)
    assert not isinstance(reference,str) or iref0 == scllist.idkeys.index(reference), reference
    if len(scllist) != len(varlist): raise AxisError(scllist)
    # compute means (this replaces the plain sample means used for the fit table)
    mvars = []
    for svr,var in zip(scllist,varlist):
      if isinstance(svr,VarRV): mvar = svr.stats(moments='mv', asVar=False)[...,0] # only first moment
      else: mvar = var.mean()*svr.atts.get('loc_factor',1.)
      mvars.append(mvar)
    # figure out label width and prepare header
    if len(varlist) > 1: # otherwise no comparison...
      names = [plot_labels.get(getattr(dist,idkey),getattr(dist,idkey)) for dist in scllist]
      lnames = max([len(name) for name in names]) # allocate line space
      namestr = '{{:>{:d}s}}  {{:s}}  '.format(max(lhead,lnames))
      tmphead = 'Fit to {:s}:' if scl == fit else 'Rescaled to {:s}:' # new heading
      tmphead += ' '*(max(lnames-len(names[iref0]),0)+5)+'\n'
      string += tmphead.format(names[iref0])
      # prepare first reference sample for K-S test
      scale,shape = scllist[iref0].atts.get('scale_factor', 1),scllist[iref0].atts.get('shape_factor', 1)
      # NOTE(review): this only raises if scale AND shape are both non-trivial - confirm that a
      #               single non-trivial factor is really meant to pass
      if not (scale is None or scale == 1) and not (shape is None or shape == 1):
        raise NotImplementedError("Cannot rescale scale/variance and shape parameters of reference sample!")
      refsmpl = varlist[iref0].getArray(unmask=True, fillValue=fillValue) # only once
      loc0 = scllist[iref0].atts.get('loc_factor', 1)
      # NOTE(review): unlike for later references, the 'rescaled' attribute is not checked here
      refsmpl = _rescaleSample(refsmpl, loc0, bs_axis=bs_axes[iref0]) # apply rescaling (varies, depending on loc-type)
      # start loop
      iref = iref0; reflist = list(reflist0) # copy list, because entries are consumed below
      for i,(dist,varsmpl,mvar,bs_axis) in enumerate(zip(scllist,varlist,mvars,bs_axes)):
        name = getattr(dist,idkey)
        if reflist and name == reflist[0]: # assign new reference (no row is printed for it)
          iref = i; del reflist[0] # pop element
          # prepare subsequent reference sample for K-S test
          scale,shape = dist.atts.get('scale_factor', 1),dist.atts.get('shape_factor', 1)
          if not (scale is None or scale == 1) and not (shape is None or shape == 1):
            raise NotImplementedError("Cannot rescale scale/variance and shape parameters of reference sample!")
          refsmpl = varsmpl.getArray(unmask=True, fillValue=fillValue) # only once
          if not varsmpl.atts.get('rescaled',False):
            refsmpl = _rescaleSample(refsmpl, dist.atts.get('loc_factor', 1), bs_axis=bs_axis) # apply rescaling (varies, depending on loc-type)
        elif i != iref:
          scale,shape = dist.atts.get('scale_factor', 1),dist.atts.get('shape_factor', 1)
          # perform K-S test
          if (scale is None or scale == 1) and (shape is None or shape == 1):
            # K-S test between actual samples is more realistic, and rescaling of mean is simple
            smpl = varsmpl.getArray(unmask=True, fillValue=fillValue) # only once
            if not varsmpl.atts.get('rescaled',False):
              smpl = _rescaleSample(smpl, dist.atts.get('loc_factor', 1), bs_axis=bs_axis) # apply rescaling (varies, depending on loc-type)
            pval = ks_2samp(refsmpl, smpl, asVar=False, lflatten=lflatten, 
                            axis_idx=varsmpl.axisIndex(sample_axis, lcheck=False))
            pval = '{:3.2f}'.format(float(np.median(pval)))
          else:
            # no straight-forward way to rescale samples, so have to compare distribution with 
            # reference sample, which means more noise (since the distribution will be randomly sampled)
            # NOTE(review): this path reports the mean p-value, the others report the median
            if isinstance(dist,VarRV): pval = '{:3.2f}'.format(float(dist.kstest(refsmpl).mean()))
            else: pval = '  - '
          # add column with ratio/difference of means after rescaling
          if name in plot_labels: name = plot_labels[name]
          string += formatRow(namestr, name, pval, mvar, mvars, iref)
  # return formatted table in string
  return string