Пример #1
0
def age_vs_plot(track, infile, ycol='logl', ax=None, annotate=True, xlabels=True,
                save_plot=True, ylabels=True):
    """Plot one track column against stellar age.

    The whole track is drawn as a black line. The points indexed by
    ``track.Qs`` are overplotted as green circles, and any ``track.addpt``
    points as purple circles.

    Parameters
    ----------
    track : object
        Must provide ``get_col(name)``, ``addpt``, ``Qs``, ``mass`` and
        ``name``. The columns returned by ``get_col`` are indexed with integer
        lists, so they are presumably numpy arrays -- TODO confirm.
    infile : object
        Must provide ``agb_mix``, ``set_name`` and, when saving,
        ``diagnostic_dir``.
    ycol : {'logl', 'logt', 'C/O'}
        Quantity for the y axis. It is read from the ``L_star``, ``T_star``
        or ``CO`` column respectively.
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure and axes are created when None.
    annotate : bool
        Write the AGB mix, set name and track mass in the upper-left corner.
    xlabels, ylabels : bool
        Set the x / y axis labels.
    save_plot : bool
        Save the figure that owns ``ax`` as
        ``<diagnostic_dir>/age_v/diag_<track file name without .dat>_age_v.png``
        at 300 dpi, then close that figure.

    Returns
    -------
    None
        An unsupported ``ycol`` prints a message and returns before drawing.
    """
    agb_mix = infile.agb_mix
    set_name = infile.set_name

    # Each ylab already carries its own $...$ mathtext delimiters.
    if ycol == 'logl':
        ydata = track.get_col('L_star')
        majL = MultipleLocator(.2)
        minL = MultipleLocator(.1)
        ylab = r'$\log\ L_{\odot}$'
    elif ycol == 'logt':
        ydata = track.get_col('T_star')
        majL = MultipleLocator(.1)
        minL = MultipleLocator(.05)
        ylab = r'$\log\ Te$'
    elif ycol == 'C/O':
        ydata = track.get_col('CO')
        majL = MaxNLocator(4)
        minL = MaxNLocator(2)
        ylab = '$C/O$'
    else:
        print('logl, logt, C/O only choices for y.')
        return

    age = track.get_col('ageyr')
    addpt = track.addpt
    Qs = list(track.Qs)

    if ax is None:
        _, ax = plt.subplots()
    ax.plot(age, ydata, color='black')
    ax.plot(age[Qs], ydata[Qs], 'o', color='green')
    # len() rather than truthiness: addpt may be an array, whose truth value
    # is ambiguous.
    if len(addpt) > 0:
        ax.plot(age[addpt], ydata[addpt], 'o', color='purple')
    ax.yaxis.set_major_locator(majL)
    ax.yaxis.set_minor_locator(minL)
    # Keep the age tick labels in plain notation between 1e-3 and 1e4 and
    # switch to scientific notation outside that range.
    majorFormatter = ScalarFormatter()
    majorFormatter.set_powerlimits((-3, 4))
    ax.xaxis.set_major_formatter(majorFormatter)

    if annotate:
        # An underscore would start a mathtext subscript, so show it as an
        # escaped space instead.
        ax.text(0.06, 0.87, '${\\rm %s}$' % agb_mix.replace('_', '\\ '),
                transform=ax.transAxes)
        ax.text(0.06, 0.77, '${\\rm %s}$' % set_name.replace('_', '\\ '),
                transform=ax.transAxes)
        ax.text(0.06, 0.67, '$M=%.2f$' % track.mass,
                transform=ax.transAxes)
    if ylabels:
        ax.set_ylabel(ylab, fontsize=20)
    if xlabels:
        # Raw string: in a plain literal, '\r' would be a carriage return.
        ax.set_xlabel(r'$\rm{Age\ (yr)}$', fontsize=20)

    if save_plot:
        plotpath = os.path.join(infile.diagnostic_dir, 'age_v/')
        fileIO.ensure_dir(plotpath)
        fname = os.path.split(track.name)[1].replace('.dat', '')
        fig_name = os.path.join(plotpath, '_'.join(('diag', fname)))
        # Save and close the figure that owns `ax`, not whichever figure
        # happens to be current.
        fig = ax.figure
        fig.savefig('%s_age_v.png' % fig_name, dpi=300)
        plt.close(fig)
    return
Пример #2
0
def diag_plots(track, infile):
    """Write the diagnostic plots for one track.

    Two figures are produced:

    * an HR diagram (log Te vs log L) saved as
      ``<diagnostic_dir>/HRD/diag_<track file name without .dat>.png``. The
      ``track.Qs`` points are drawn as a thick green line with markers and
      any ``track.addpt`` points as purple markers;
    * a three-panel figure (log L, log Te and C/O vs age) drawn with
      ``age_vs_plot``. The last panel carries the annotation and the x
      label, and saves the figure.

    Parameters
    ----------
    track : object
        Must provide ``get_col``, ``addpt``, ``Qs``, ``mass`` and ``name``.
    infile : object
        Must provide ``agb_mix``, ``set_name`` and ``diagnostic_dir``.
    """
    agb_mix = infile.agb_mix
    set_name = infile.set_name
    ext = '.png'
    # Te limits are given high-to-low so the x axis is reversed, as usual
    # for an HR diagram.
    logt_lim = (3.75, 3.35)
    logl_lim = (2.4, 4.8)

    logl = track.get_col('L_star')
    logt = track.get_col('T_star')
    addpt = track.addpt
    Qs = list(track.Qs)

    # --- HRD ---
    plotpath = os.path.join(infile.diagnostic_dir, 'HRD/')
    fileIO.ensure_dir(plotpath)
    fig, ax = plt.subplots()
    # An underscore would start a mathtext subscript, so show it as an
    # escaped space (as age_vs_plot does).
    ax.annotate(r'$%s$' % agb_mix.replace('_', '\\ '), xy=(3.43, 2.8))
    ax.annotate(r'$%s$' % set_name.replace('_', '\\ '), xy=(3.43, 2.7))
    ax.annotate(r'$M=%.2f$' % track.mass, xy=(3.43, 2.6))
    ax.plot(logt, logl, color='black')

    ax.plot(logt[Qs], logl[Qs], color='green', lw=2)
    ax.plot(logt[Qs], logl[Qs], 'o', color='green')
    # len() rather than truthiness: addpt may be an array.
    if len(addpt) > 0:
        ax.plot(logt[addpt], logl[addpt], 'o', color='purple')
    ax.set_xlim(logt_lim)
    ax.set_ylim(logl_lim)
    ax.set_xlabel(r'$\log\ Te$')
    ax.set_ylabel(r'$\log\ L_{\odot}$')
    fname = os.path.split(track.name)[1].replace('.dat', '')
    fig_name = os.path.join(plotpath, '_'.join(('diag', fname)))
    fig.savefig('%s%s' % (fig_name, ext))
    plt.close(fig)

    # --- Quantities vs age ---
    _, axs = plt.subplots(nrows=3)
    ycols = ['logl', 'logt', 'C/O']
    last = len(ycols) - 1
    for i, (panel, ycol) in enumerate(zip(axs, ycols)):
        # Only the bottom panel is annotated and x-labelled. Saving from it
        # writes (and closes) the whole three-panel figure.
        is_last = i == last
        age_vs_plot(track, infile, ycol=ycol, ax=panel, annotate=is_last,
                    xlabels=is_last, ylabels=True, save_plot=is_last)
  def attach_experiment(self,experiment_dir,error_key='dev_mae',verbose=True):
    """
    Aggregate the best result of every sub-experiment under ``experiment_dir``.

    Override this to suit your purposes! For each immediate sub-directory of
    ``<experiment_dir>/var``, this method:

    1. reads ``logs/runtime.txt``;
    2. keeps the line with the smallest ``error_key`` value;
    3. merges in the hyper-parameters encoded in the sub-directory name;
    4. collects everything into one dictionary.

    If your experiment structure is different, override this method in your
    own implementation. Alternatively, if an aggregation can be supplied to
    ``attach_aggregation``, your life will be easier.

    :param experiment_dir: String pointing to a valid experiment directory as defined by experimentGenerator.
    :param error_key: Key (as parsed by ``self.parse_key_values`` from each runtime line) whose float value is minimised.
    :param verbose: When True, print a warning for each experiment in which no best result could be found.
    :returns: 0 on success, -1 if there is no ``var`` directory, -2 if the ``results`` directory is missing and could not be created.
    :raises ValueError: if an ``error_key`` value cannot be converted to float.

    .. side-effects:: Upon success, ``self.results`` is set to a dict that maps every key to
       ``{'values': [...one value per aggregated experiment...], '__longest_sequence__': <int>}``.
       Values are converted to float where possible and left unchanged otherwise.
    .. note:: A valid experiment directory (if using this as is) should have sufficient permissions set and contain a var
       directory. The var directory should contain a sub folder for each experiment, labeled like: alpha_1__gamma_2/.
       Each sub folder should contain a logs folder with a file runtime.txt, i.e. logs/runtime.txt.
    """
    results_dir = '%s/%s' %(experiment_dir,'results')
    var_dir     = '%s/%s' %(experiment_dir,'var')
    if not os.path.isdir(var_dir):
      print('There is no var directory. This directory should have all of the sub experiments beneath it.')
      return -1
    if not os.path.isdir(results_dir):
      if ensure_dir(results_dir) < 0:
        print('There is no results directory. Failed attempt to create directory here: %s.' % (results_dir))
        return -2

    print('Building list of directories...')
    # Only the immediate sub-directories of var/ are experiments. The first
    # os.walk entry lists exactly those, without descending further. An
    # unreadable var/ yields nothing, hence the empty default.
    dirs = next(os.walk(var_dir), (var_dir, [], []))[1]

    num_experiments = len(dirs)
    results         = {}

    # ---------------------------------
    # Go over each experiment directory
    # -----------------------
    print('Aggregating results.')
    param_count = None
    for p,d in enumerate(dirs):
      runtime_fn = '%s/var/%s/logs/runtime.txt' % (experiment_dir,d)
      if not os.path.isfile(runtime_fn):
        continue

      # Hyper parameters are encoded in the directory name, e.g. alpha_1__gamma_2.
      d_end = d.split('/')[-1]
      hyper_params = self.parse_key_values(d_end,assignment='_',delimeter='__')

      with open(runtime_fn,'r') as runtime_file:
        runtime = runtime_file.read().split('\n')

      # -----------------------------------------
      # Find the runtime line with the lowest error for this experiment
      # -------------------
      min_error = float('inf')
      min_error_exp = None
      for line in runtime:
        if not line:
          continue
        rt = self.parse_key_values(line)
        try:
          error = float(rt[error_key])
        except ValueError as err:
          raise ValueError('Expecting error key %s to evaluate to a float.' % (error_key)) from err
        except KeyError:
          # Deliberate: stop reading this file but keep the best line found so far.
          print('ERROR: Parsing file:',runtime_fn)
          print('- Error key %s does not apear on all lines.' % error_key)
          break
        if error < min_error:
          min_error = error
          min_error_exp = copy.deepcopy(rt)

      if min_error_exp is None:
        if verbose:
          print('WARNING: Unable to find best result in experiment %s.' % (runtime_fn))
        continue

      # Add all the hyper params to the best line found. Non-numeric values
      # (e.g. opt_adam) are kept as-is instead of aborting the aggregation.
      for key in hyper_params:
        try:
          min_error_exp[key] = float(hyper_params[key])
        except (TypeError, ValueError, OverflowError):
          min_error_exp[key] = hyper_params[key]

      # Every experiment must contribute the same number of keys as the first
      # one accepted. Otherwise the per-key value lists would fall out of step.
      if param_count is None:
        param_count = len(min_error_exp)
      elif len(min_error_exp) != param_count:
        print("Error in file:",runtime_fn)
        print("Irregular number of outputs per line.")
        continue

      for key, raw_val in min_error_exp.items():
        if key not in results:
          results[key] = {'values': [], '__longest_sequence__': -1}
        try:
          val = float(raw_val)
        except (TypeError, ValueError, OverflowError):
          val = raw_val
        entry = results[key]
        # Widest printed width seen for this column (value or header).
        entry['__longest_sequence__'] = max(len(str(val)),len(key),entry['__longest_sequence__'])
        entry['values'].append(val)

      if p % 200 == 0:
        progress(p,num_experiments)

    progress(num_experiments,num_experiments)
    # NOTE(review): len(results) is the number of distinct keys (columns),
    # not the number of aggregated experiments -- confirm which is intended.
    print('(%d) results' % (len(results)))
    self.results = results
    return 0