def summary(models, path='.', pars=('R', 'M', 'Mni', 'E'), sep=' ', is_verbose=False):
    """
    Print, for each requested parameter, the sorted distinct values found across the models.

    :param models: list of model names
    :param path: working directory. Default: ./
    :param pars: the parameters for summary. Default: pars=('R', 'M', 'Mni', 'E')
    :param sep: string separation between the printed values. Default: space
    :param is_verbose: if True, also report models without tt-data and the case
                       when nothing was collected at all
    :return: None
    """
    data = {k: [] for k in pars}
    for mdl in models:
        stella = Stella(mdl, path=path)
        if not stella.is_tt:
            if is_verbose:
                print("{0} No tt.".format(stella))
            continue

        tt_info = stella.get_tt().Info
        try:
            for k, values in data.items():
                v = getattr(tt_info, k)
                # keep only distinct values, in order of first appearance
                if v not in values:
                    values.append(v)
        except KeyError as ex:
            print(" KeyError: %s. %s" % (tt_info.Name, ex))

    # any() is safe for an empty `pars` (no IndexError) and does not depend on
    # which parameter happens to be listed first.
    if any(data.values()):
        print('Summary in {}'.format(path))
        for k, values in data.items():
            print("{:5s} {}".format(k, sep.join(map(str, sorted(values)))))
    elif is_verbose:
        print('No tt-files in {}'.format(path))
def compare_ttVSubv(mname, path, bands=('U', 'B', 'V', 'R', 'I'), t_cut=1., is_plot_time_points=False):
    """
    Plot the magnitudes stored in the tt-data next to those computed from the spectra.

    :param mname: model name
    :param path: directory with the model files
    :param bands: band names passed to the spectrum-based computation and to the plot
    :param t_cut: only tt rows with time > t_cut are kept
    :param is_plot_time_points: forwarded to plot_all as ``is_time_points``
    :return: None
    """

    def band_name(column):
        # Two-letter columns starting with 'M' (e.g. 'MU') are renamed to the
        # bare band letter. Slicing, unlike replace('M', ''), cannot produce an
        # empty name for a column such as 'MM'.
        if len(column) == 2 and column.startswith('M'):
            return column[1:]
        return column

    model = Stella(mname, path=path)

    # tt-data: assumes load() returns a structured array with a 'time' field -- TODO confirm
    tt = model.get_tt().load()
    mags_tt = tt[tt['time'] > t_cut]  # time cut days
    # dtype.names must be assigned a real sequence; a lazy map object
    # (Python 3) is rejected by NumPy.
    mags_tt.dtype.names = tuple(band_name(n) for n in mags_tt.dtype.names)

    # ubv: magnitudes computed from the series of spectra
    serial_spec = model.read_series_spectrum(t_diff=1.05)
    mags_ubv = serial_spec.mags_bands(bands)

    dic_results = {'tt': mags_tt, 'ubv': mags_ubv}
    plot_all(dic_results, bands, title=mname, is_time_points=is_plot_time_points)
def long(models, path='.', pars=('R', 'M', 'Mni', 'E'), sep=' |', is_verbose=False):
    """
    Print model list with parameters
    TODO filter by parameters

    :param models: list of model names, or a dict {model name: comment}; the dict
                   values are printed in the last ("comment") column
    :param path: working directory. Default: ./
    :param pars: the list of printed parameters
    :param sep: column separator
    :param is_verbose: if True, also print a row for models without tt-data
    :return: dict {model name: {parameter: value}} for the models with tt-data
    """
    units = {'R': 'Rsun', 'M': 'Msun', 'Mni': 'Msun', 'E': 'FOE'}
    is_dict = isinstance(models, dict)
    mnames = list(models.keys()) if is_dict else models

    # print header; a parameter without a known unit gets an empty unit column
    # instead of aborting the whole listing with a KeyError
    header = ['{:>40s} '.format('Name')]
    ruler = ['{:40s} '.format('-' * 40)]
    for k in pars:
        header.append('{} {:>3s}({:4s})'.format(sep, k, units.get(k, '')))
        ruler.append('{} {:8s}'.format(sep, '-' * 8))
    header.append('{} {}'.format(sep, 'comment'))
    ruler.append('{} {}'.format(sep, '-' * 8))
    print(''.join(header))
    print(''.join(ruler))

    models_data = {}
    for mdl in mnames:
        stella = Stella(mdl, path=path)
        exts = models[mdl] if is_dict else ''

        if not stella.is_tt:
            if is_verbose:
                print("| %40s | %26s | %s" % (stella.name, ' ', exts))
            continue

        dat = {}
        info = stella.get_tt().Info
        try:
            row = ['{:>40s} '.format(info.Name)]
            for k in pars:
                v = getattr(info, k)
                row.append('{} {:8.3f}'.format(sep, v))
                dat[k] = v
            row.append('{} {}'.format(sep, exts))
            print(''.join(row))
        except KeyError as ex:
            print("| %40s | %8s | %6s | %6s | %s | KeyError: %s" % (info.Name, '', '', '', exts, ex))
        except Exception as ex:
            # Exception, not a bare except: Ctrl-C / SystemExit must still propagate.
            print("| %40s | %8s | %6s | %6s | %s | Unexpected error: %s"
                  % (info.Name, '', '', '', exts, type(ex)))
        # values read before a failure are still returned, as before
        models_data[mdl] = dat
    return models_data
def plot_chi_par(res_sorted, path='./', p=('R', 'M', 'E'), **kwargs):
    """
    Draw one chi^2-versus-parameter panel (log y-axis) for every parameter in ``p``.

    :param res_sorted: dict {model name: result}; each result must expose ``.measure``
    :param path: directory with the model files
    :param p: names of the tt-info attributes to plot against
    :param kwargs: ``is_show`` (default False) calls plt.show() at the end;
                   ``font_size`` (default 10) sets the matplotlib font size
    :return: None
    """
    from pystella.model.stella import Stella

    show_plot = kwargs.get('is_show', False)

    # gather the parameter values and chi of every model that has tt-data
    chi_by_model = {}
    for name, res_chi in res_sorted.items():
        model = Stella(name, path=path)
        if not model.is_tt:
            continue
        try:
            tt_info = model.get_tt().Info
            values = [getattr(tt_info, par) for par in p]
            chi_by_model[name] = {"v": values, 'chi': res_chi.measure}
        except KeyError as ex:
            print("Error for model {}. Message: {} ".format(name, ex))

    if not chi_by_model:
        print('There are no tt-data for any models.')
        return

    import matplotlib.pyplot as plt

    font_size = kwargs.get('font_size', 10)
    marker_table = {
        u'D': u'diamond',
        6: u'caretup',
        u's': u'square',
        u'x': u'x'
    }
    markers_cycler = cycle(marker_table.keys())

    # two panels per row (a single column when there is only one parameter)
    n_par = len(p)
    nrow = int(n_par / 2.1) + 1
    ncol = 1 if n_par <= 1 else 2
    fig = plt.figure(figsize=(12, nrow * 6))
    plt.matplotlib.rcParams.update({'font.size': font_size})

    for idx, par_name in enumerate(p):
        panel = fig.add_subplot(nrow, ncol, idx + 1)
        par_values = [entry['v'][idx] for entry in chi_by_model.values()]
        chi_values = [entry['chi'] for entry in chi_by_model.values()]
        panel.semilogy(par_values, chi_values, marker=next(markers_cycler), ls="")
        panel.set_xlabel(par_name)
        panel.set_ylabel(r'$\chi^2$')

    if show_plot:
        plt.show()
def info(path, cond=lambda i: True):
    """Show the tt-info of every model that has a ``.tt`` file in ``path``.

    :param path: working directory
    :param cond: predicate applied to the tt-info object; only models for which it
                 returns a true value are shown,
                 like lambda i: 30 < i.M < 1000 and i.R > 100
    :return: None
    """
    from os import listdir
    from os.path import isfile, join

    # regular files with the .tt extension, in listdir order
    tt_files = []
    for entry in listdir(path):
        if isfile(join(path, entry)) and entry.endswith('.tt'):
            tt_files.append(entry)

    for fname in tt_files:
        model_name = os.path.splitext(fname)[0]
        ttinfo = Stella(model_name, path=path).get_tt().Info
        if not cond(ttinfo):
            continue
        ttinfo.show()
def plot_squared_3d(ax, res_sorted, path='./', p=('R', 'M', 'E'), is_rbf=True, **kwargs):
    """
    Plot chi^2 of the models against three model parameters.

    Models sharing the same parameter triple are merged into one point that keeps
    the smallest chi^2. Two representations are available:
      * default: a 3D scatter (x=p[0], y=p[1], z=p[2]) coloured by chi / max(chi);
      * ``is_polar=True``: polar bars with angle ~ p[0], bar width ~ p[1],
        bar length ~ chi and colour ~ p[2].

    :param ax: axes to draw on. If None, a figure with a suitable axes (3D or polar)
               is created and shown at the end. A caller-supplied axes must be a 3D
               axes in the default mode.
    :param res_sorted: dict {model name: result}; each result must expose ``.measure``
    :param path: directory with the model files
    :param p: names of three tt-info attributes
    :param is_rbf: not used by this function; kept for signature compatibility
    :param kwargs: ``is_not_quiet`` (default False) prints the collected points;
                   ``is_polar`` (default False) selects the polar representation
    :return: None
    :raises ValueError: if there is data to plot but fewer than three parameters in ``p``
    """
    from matplotlib import pyplot as plt
    from pystella.model.stella import Stella

    is_not_quiet = kwargs.get('is_not_quiet', False)
    is_polar = kwargs.get('is_polar', False)
    is_show = False

    def fmt_values(values):
        return ' '.join("{0:6.2f}".format(vv) for vv in values)

    def scale01(a):
        # Linear map onto [0, 1]; a constant array maps to zeros instead of
        # producing NaN through a division by zero.
        lo = np.min(a)
        span = np.max(a) - lo
        if span == 0:
            return np.zeros(len(a))
        return (a - lo) / span

    # find parameters: unique parameter vectors and the best chi for each
    data = []
    chi = []
    for i, (name, res_chi) in enumerate(res_sorted.items(), start=1):
        stella = Stella(name, path=path)
        if not stella.is_tt:
            continue
        try:
            info = stella.get_tt().Info
            v = [getattr(info, pp) for pp in p]
            if is_not_quiet:
                if i == 1:
                    print("| %40s | %7s | %6s" % ('Model', p[0], p[1]))
                print("| {:40s} | ".format(info.Name) + fmt_values(v))

            for k, vo in enumerate(data):
                if not np.array_equal(v, vo):
                    continue
                # the same point was already stored: keep the smaller chi
                if is_not_quiet:
                    print("| | " + fmt_values(v)
                          + " | {:40s} | chi_saved={:6.2f} chi_new={:6.2f}".format(
                              'This is not a unique point', chi[k], res_chi.measure))
                if res_chi.measure < chi[k]:
                    if is_not_quiet:
                        print("| %40s | k = %5d Chi [%7.2f] => [%6.2f]"
                              % (info.Name, k, res_chi.measure, chi[k]))
                    chi[k] = res_chi.measure
                break
            else:
                data.append(v)
                chi.append(res_chi.measure)
        except KeyError as ex:
            print("Error for model {}. Message: {} ".format(name, ex))

    if len(data) == 0:
        print('There are no tt-data for any models.')
        return

    if len(p) < 3:
        raise ValueError("plot_squared_3d needs three parameters, got: {}".format(p))

    # plot (points are drawn in reverse collection order, as before)
    x = np.array([v[0] for v in data][::-1])
    y = np.array([v[1] for v in data][::-1])
    z = np.array([v[2] for v in data][::-1])
    chi = np.array(chi[::-1])

    if ax is None:
        fig = plt.figure(figsize=(12, 8))
        if is_polar:
            ax = fig.add_subplot(1, 1, 1, projection='polar')
        else:
            # The default mode calls set_zlabel and a 3-argument scatter,
            # which exist only on a 3D axes.
            from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on old matplotlib)
            ax = fig.add_subplot(1, 1, 1, projection='3d')
        is_show = True

    if is_polar:
        C = 1.9  # the x-range is spread over C*pi radians, leaving a gap before 2*pi
        theta = scale01(x) * C * np.pi  # angle <- p[0]
        width = y / np.max(y) * np.pi / 8  # bar width <- p[1]
        radii = 10 + scale01(chi) * 100  # bar length <- chi
        bars = ax.bar(theta, radii, width=width, bottom=0.0)

        # Use custom colors and opacity: colour <- p[2]
        for colour_value, bar, label in zip(z / np.max(z), bars, z):
            bar.set_facecolor(plt.cm.viridis(colour_value))
            bar.set_alpha(0.5)
            bar.set_label(label)

        # Label the 8 angular ticks (0, pi/4, ..., 7pi/4) with the x-value that is
        # mapped to each angle. Tick positions are fixed explicitly so that the
        # labels cannot drift away from the ticks they describe.
        x_min = np.min(x)
        xlabel_max = x_min + 7. / 4. / C * (np.max(x) - x_min)
        ax.set_xticks(np.linspace(0., 7. / 4. * np.pi, 8))
        ax.set_xticklabels(["{:6.0f} {}".format(xx, p[0])
                            for xx in np.linspace(x_min, xlabel_max, 8)])
        ax.legend(bbox_to_anchor=(1.3, 1.05))
    else:
        from matplotlib.ticker import FormatStrFormatter
        from matplotlib.ticker import LinearLocator

        ax.set_xlabel(p[0])
        ax.set_ylabel(p[1])
        ax.set_zlabel(p[2])
        surf = ax.scatter(x, y, z, c=chi / np.max(chi), cmap=plt.cm.viridis)
        ax.yaxis.set_major_locator(LinearLocator(10))
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.0f'))
        plt.colorbar(surf, ax=ax, shrink=0.5, aspect=5)
        ax.grid(True)

    if is_show:
        plt.show()
def plot_squared(ax, res_sorted, path='./', p=('R', 'M'), **kwargs):
    """
    Plot chi^2 of the models on the plane of two model parameters.

    Models sharing the same parameter pair are merged into one point that keeps the
    smallest chi^2 and the name of the model that produced it.

    :param ax: matplotlib axes to draw on (required). With ``is_surface=False`` and
               ``is_scatter=False`` it must be a 3D axes (plot_trisurf is called on it).
    :param res_sorted: dict {model name: result}; each result must expose ``.measure``
    :param path: directory with the model files
    :param p: names of two tt-info attributes: x-axis and y-axis
    :param kwargs:
        ``is_surface`` (default True): interpolated chi^2 map with contours and pickable points;
        ``is_scatter`` (default False): used when ``is_surface`` is False, bubble scatter plot;
        otherwise a triangulated 3D surface is drawn;
        ``is_rbf`` (default True): interpolate the map with Rbf instead of griddata;
        ``is_not_quiet`` (default False): print the collected points
    :return: None
    """
    from matplotlib import pyplot as plt
    from pystella.model.stella import Stella

    is_rbf = kwargs.get('is_rbf', True)
    is_surface = kwargs.get('is_surface', True)
    is_scatter = kwargs.get('is_scatter', False)
    is_not_quiet = kwargs.get('is_not_quiet', False)

    def fmt_values(values):
        return ' '.join("{0:6.2f}".format(vv) for vv in values)

    ax.set_xlabel(p[0])
    ax.set_ylabel(p[1])

    # find parameters: unique points, their best chi and the model that gave it
    data = []
    chi = []
    models = []
    for i, (name, res_chi) in enumerate(res_sorted.items(), start=1):
        stella = Stella(name, path=path)
        if not stella.is_tt:
            continue
        try:
            info = stella.get_tt().Info
            v = [getattr(info, pp) for pp in p]
            if is_not_quiet:
                if i == 1:
                    print("| %40s | %7s | %6s" % ('Model', p[0], p[1]))
                print("| {:40s} | ".format(info.Name) + fmt_values(v))

            for k, vo in enumerate(data):
                if not np.array_equal(v, vo):
                    continue
                # the same point was already stored: keep the smaller chi
                if is_not_quiet:
                    print("| | " + fmt_values(v)
                          + " | {:40s} | chi_saved={:6.2f} chi_new={:6.2f}".format(
                              'This is not a unique point', chi[k], res_chi.measure))
                if res_chi.measure < chi[k]:
                    if is_not_quiet:
                        print("| %50s | k = %5d Chi [%7.2f] < [%6.2f]"
                              % (info.Name + fmt_values(v), k, res_chi.measure, chi[k]))
                    chi[k] = res_chi.measure
                    # keep the name in step with the chi, otherwise the pick
                    # callback reports a model that did not produce this value
                    models[k] = name
                break
            else:
                models.append(name)
                data.append(v)
                chi.append(res_chi.measure)
        except KeyError as ex:
            print("Error for model {}. Message: {} ".format(name, ex))

    if len(models) == 0:
        print('There are no tt-data for any models.')
        return

    # plot
    x = np.array([v[0] for v in data])
    y = np.array([v[1] for v in data])
    chi = np.array(chi)

    if is_surface:
        # Set up a regular grid of interpolation points
        xi, yi = np.meshgrid(np.linspace(x.min(), x.max(), 100),
                             np.linspace(y.min(), y.max(), 100))
        # Interpolate
        if is_rbf:
            zi = sci.interpolate.Rbf(x, y, chi, function='linear')(xi, yi)
        else:
            zi = sci.interpolate.griddata((x, y), chi, (xi, yi), method="linear")

        im = ax.imshow(zi, cmap=plt.cm.viridis, vmin=chi.min(), vmax=chi.max(),
                       origin='lower', extent=[x.min(), x.max(), y.min(), y.max()],
                       interpolation='none', aspect='auto', alpha=0.5)
        cset = ax.contour(xi, yi, zi, linewidths=1., cmap=plt.cm.bone)
        ax.clabel(cset, inline=True, fmt='%1.1f', fontsize=9)
        cbar = plt.colorbar(im, ax=ax)
        cbar.ax.set_ylabel(r'$\chi^2$')

        # the model points themselves; clicking one prints its model and values
        ax.scatter(x, y, c=chi / np.max(chi), cmap=plt.cm.bone, picker=True)

        def on_pick(event):
            try:
                ind = event.ind[0]
                print('{} {}: {}={} {}={} chi^2={:.2f}'.format(
                    ind, models[ind], p[0], x[ind], p[1], y[ind], chi[ind]))
            except AttributeError:
                # pick events from artists without point indices are ignored
                pass

        ax.figure.canvas.mpl_connect('pick_event', on_pick)
    elif is_scatter:
        # Draw the largest chi first so that better (smaller chi) points end up on top
        idx = chi.argsort()[::-1]
        x, y, chi = x[idx], y[idx], chi[idx]
        area = np.pi * (100 * np.log10(chi)) ** 2
        # edgecolor must be 'none'; an empty string is not a valid colour.
        # The points go to the axes that was passed in, not the current one.
        points = ax.scatter(x, y, s=area, c=chi, cmap='gray', edgecolor='none', alpha=0.5)
        plt.colorbar(points, ax=ax)
    else:
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on old matplotlib)

        # The triangulated surface is built directly from the model points,
        # so no grid interpolation is needed in this branch.
        surf = ax.plot_trisurf(x, y, chi, linewidth=0.2, antialiased=True, cmap='gray')
        # Add a color bar which maps values to colors.
        plt.colorbar(surf, ax=ax, shrink=0.5, aspect=5)

    ax.figure.subplots_adjust(left=0.07, right=0.96, top=0.97, bottom=0.06)
def bad_compute_vel_res_tt(name, path, z=0., t_beg=1., t_end=None, t_diff=1.05, line_header=80):
    """
    Compute the velocity at the photospheric radius on a thinned grid of tt times.

    For each selected tt time the radius ``Rph`` is taken from a spline through the
    tt-data, the res-block for that time is read, and the velocity is linearly
    interpolated in the block at that radius.

    NOTE(review): the ``bad_`` prefix suggests this variant is superseded by
    compute_vel_res_tt -- confirm before relying on it.

    :param name: model name
    :param path: directory with the model files
    :param z: redshift; returned times are multiplied by (1 + z)
    :param t_beg: only tt rows with time >= t_beg are used
    :param t_end: stop at the first tt time > t_end. Default: no limit
    :param t_diff: a time is used only if its ratio to the previously used time
                   (initially t_beg) is at least t_diff
    :param line_header: forwarded to the tt loader
    :return: structured array with float fields 'time', 'vel', 'r'
    :raises ValueError: if the model has no res-file or no tt-file
    """
    from scipy import interpolate

    model = Stella(name, path=path)
    # check data
    if not model.is_res:
        raise ValueError("There are no res-file for %s in the directory: %s " % (name, path))
    if not model.is_tt:
        raise ValueError(("There are no tt-file for %s in the directory: %s " % (name, path)))

    if t_end is None:
        t_end = float('inf')

    res = model.get_res()
    tt = model.get_tt().load(line_header=line_header)
    tt = tt[tt['time'] >= t_beg]  # time cut days
    tt_time, tt_rph = tt['time'], tt['Rph']
    rph_spline = interpolate.splrep(tt_time, tt_rph, s=0)

    radii, vels, times = [], [], []
    t_last = t_beg  # last accepted time; the thinning ratio is measured against it
    for t in tt_time:
        if t > t_end:
            break
        # abs() belongs around the ratio, not around the boolean comparison
        if t < t_last or np.abs(t / t_last) < t_diff:
            continue
        t_last = t

        radius = interpolate.splev(t, rph_spline)
        if np.isnan(radius):
            # fall back to one-dimensional linear interpolation
            radius = np.interp(t, tt_time, tt_rph, 0, 0)

        block = res.read_at_time(time=t)
        if block is None:
            break

        # One-dimensional linear interpolation; 0 outside the block's radius range.
        # The 1e14 / 1e8 factors rescale the 'R14' / 'V8' columns.
        vel = np.interp(radius, block['R14'] * 1e14, block['V8'], 0, 0)
        vels.append(vel * 1e8)
        radii.append(radius)
        times.append(t * (1. + z))  # redshifted time

    # np.float was removed from NumPy; the builtin float is the same float64 type
    result = np.zeros(len(vels), dtype=np.dtype({
        'names': ['time', 'vel', 'r'],
        'formats': [float] * 3
    }))
    result['time'] = times
    result['vel'] = vels
    result['r'] = radii
    return result
def compute_vel_res_tt(name, path, z=0., t_beg=0.1, t_end=None, line_header=80, is_info=False, is_new_std=False):
    """
    Compute the velocity at the photospheric radius for every block of the res-file.

    For each res-block whose time lies inside the tt time range, the radius ``Rph``
    is linearly interpolated from the tt-data at the block time, and the velocity
    is linearly interpolated in the block at that radius.

    :param name: model name
    :param path: directory with the model files
    :param z: redshift; returned times are multiplied by (1 + z)
    :param t_beg: only tt rows with time >= t_beg are used
    :param t_end: blocks with time > t_end are skipped. Default: no limit
    :param line_header: forwarded to the tt loader
    :param is_info: print progress and skipped blocks
    :param is_new_std: forwarded to the res-block reader
    :return: structured array with float fields 'time', 'vel', 'r'
    :raises ValueError: if the model has no res-file or no tt-file, or if no tt-data
                        remain after the t_beg cut
    """
    if is_info:
        print(f'Run model: {name} in dir: {path} z= {z}')

    model = Stella(name, path=path)
    # check data
    if not model.is_res:
        raise ValueError("There are no res-file for %s in the directory: %s " % (name, path))
    if not model.is_tt:
        raise ValueError(("There are no tt-file for %s in the directory: %s " % (name, path)))

    if t_end is None:
        t_end = float('inf')

    res = model.get_res()
    tt = model.get_tt().load(line_header=line_header)
    tt = tt[tt['time'] >= t_beg]  # time cut days
    tt_time, tt_rph = tt['time'], tt['Rph']
    if len(tt_time) == 0:
        raise ValueError("There are no tt-data at time >= %s for %s in the directory: %s "
                         % (t_beg, name, path))
    # the tt time range does not change inside the loop: compute it once
    t_min, t_max = tt_time.min(), tt_time.max()

    radii, vels, times = [], [], []
    for i, (t, start, end) in enumerate(res.blocks()):
        if t < t_min or t > t_max:
            if is_info:
                print('Error: nblock= {}: t_res[={:e}] not in range time_tt: {:e}, {:e}'
                      .format(i, t, t_min, t_max))
            continue
        if t > t_end:
            # t_end used to be accepted but never applied
            continue

        r_ph = np.interp(t, tt_time, tt_rph)  # One-dimensional linear interpolation.
        block = res.read_res_block(start, end, is_new_std=is_new_std)
        if block is None:
            break

        # One-dimensional linear interpolation; 0 outside the block's radius range.
        # The 1e14 / 1e8 factors rescale the 'R14' / 'V8' columns.
        vel = np.interp(r_ph, block['R14'] * 1e14, block['V8'] * 1e8, 0, 0)
        if is_info:
            print('nblock= {} [{}:{}]: t= {:e} r_ph= {:e} vel= {:e}'.
                  format(i, start, end, t, r_ph, vel))
        vels.append(vel)
        radii.append(r_ph)
        times.append(t * (1. + z))  # redshifted time

    # np.float was removed from NumPy; the builtin float is the same float64 type
    result = np.zeros(len(vels), dtype=np.dtype({
        'names': ['time', 'vel', 'r'],
        'formats': [float] * 3
    }))
    result['time'] = times
    result['vel'] = vels
    result['r'] = radii
    return result
class TestStellaTt(unittest.TestCase):
    """Tests of the tt-data of a Stella model: header info and light curves.

    The curve tests are visual: they open a matplotlib window and make no assertions.
    """

    def setUp(self):
        # name = 'ccsn2007bi1dNi6smE23bRlDC'
        name = 'cat_R500_M15_Ni006_E12'
        # name = 'cat_R1000_M15_Ni007_E15'
        path = join(dirname(abspath(__file__)), 'data', 'stella')
        self.stella = Stella(name, path=path)
        self.tt = self.stella.get_tt()

    def test_info_header_regex(self):
        """Run the 'key = number' regular expression over a sample tt header.

        This method used to be named test_info_parse as well and was therefore
        silently replaced by the later definition; it now has its own name.
        """
        tt_header = """
 <===== HYDRODYNAMIC RUN OF MODEL cat_R1000_M15_Ni007_E15.prf =====>
 MASS(SOLAR)=15.000 RADIUS(SOLAR)= 1000.000
 EXPLOSION ENERGY(10**50 ERG)= 1.50000E+01
 <===== =====>
 INPUT PARAMETERS
 EPS = 0.00300 Rce = 1.00000E-02 SOL.Rad.
 HMIN = 1.00000E-11 S AMht = 1.00000E-02 SOL.MASS
 HMAX = 5.00000E+04 S Tburst= 1.00000E-01 S
 THRMAT= 1.00000E-30 Ebstht= 1.50000E+01 1e50 ergs
 METH = 3 CONV = F
 JSTART= 0 EDTM = T
 MAXORD= 4 CHNCND= T
 NSTA = -4500 FitTau= 3.00000E+02
 NSTB = -1 TauTol= 1.30000E+00
 NOUT = 100 IOUT = -1
 TcurA = 0.00000 Rvis = 1.00000
 TcurB = 200.00000 BQ = 1.00000E+00
 NSTMAX= 360000 DRT = 1.00000E-01
 XMNI = 7.00000E-02 SOL.MASS NRT = 1
 XNifor= 1.16561E-01
 MNicor= 1.16999E-01 SOL.MASS SCAT = T
        """
        lines = (line.strip() for line in tt_header.splitlines() if len(line) > 0)
        p = r"(.*?)\s*=\s+([-+]?\d*\.\d+|\d+)"
        for line in lines:
            for k, v in re.findall(p, line):
                print("key: %s v: %f " % (k, float(v)))

    def test_info_parse(self):
        """The parsed header must give the expected radius, mass and energy."""
        # NOTE(review): the expected values (R=1000, M=15, E=15) match the
        # commented-out model 'cat_R1000_M15_Ni007_E15' rather than the name
        # used in setUp -- confirm against the data file.
        info = self.tt.Info.parse()
        info.show()

        tmp = 1000.
        self.assertEqual(info.R, tmp, "Radius [%f] should be %f" % (info.R, tmp))

        tmp = 15.
        self.assertEqual(info.M, tmp, "Mass [%f] should be %f" % (info.M, tmp))

        tmp = 15.
        self.assertEqual(info.E, tmp, "Ebstht [%f] should be %f" % (info.E, tmp))

    def test_tt_vs_ph(self):
        """Plot the tt light curves against the curves computed from the spectra."""
        curves_tt = self.tt.read_curves_tt()
        bands = curves_tt.BandNames
        serial_spec = self.stella.read_series_spectrum(t_diff=1.00001)
        curves_ph = serial_spec.flux_to_curves(bands)

        models_dic = {'tt': curves_tt, 'ph': curves_ph}
        lc_types = {'tt': '--', 'ph': '-'}

        plt.matplotlib.rcParams.update({'font.size': 14})
        fig, ax = plt.subplots(1, 1)
        lcf.plot_models_curves_fixed_bands(ax, models_dic, bands, lc_types=lc_types, ylim=(-10, -24))
        plt.legend()
        plt.show()

    def test_gri_vs_ph(self):
        """Plot the gri light curves against the curves computed from the spectra."""
        curves_gri = self.tt.read_curves_gri()
        bands = curves_gri.BandNames
        serial_spec = self.stella.read_series_spectrum(t_diff=1.)
        curves_ph = serial_spec.flux_to_curves(bands)

        models_dic = {'gri': curves_gri, 'ph': curves_ph}
        lc_types = {'gri': '--', 'ph': ':'}

        plt.matplotlib.rcParams.update({'font.size': 14})
        fig, ax = plt.subplots(1, 1)
        lcf.plot_models_curves(ax, models_dic, lc_types=lc_types, ylim=(-10, -23), lw=3)
        plt.legend()
        plt.show()

    def test_tt_vs_gri_vs_ph(self):
        """Plot the tt and gri light curves together with the spectrum-based curves."""
        curves_tt = self.tt.read_curves_tt()
        bands_tt = curves_tt.BandNames

        curves_gri = self.tt.read_curves_gri()
        bands_gri = curves_gri.BandNames

        # every band present in either data set, without duplicates
        bands = np.unique(np.array(bands_tt + bands_gri))

        serial_spec = self.stella.read_series_spectrum(t_diff=1.00001)
        curves_ph = serial_spec.flux_to_curves(bands)

        models_dic = {'tt': curves_tt, 'gri': curves_gri, 'ph': curves_ph}
        lc_types = {'tt': ':', 'gri': '--', 'ph': '-'}

        plt.matplotlib.rcParams.update({'font.size': 14})
        fig, ax = plt.subplots(1, 1)
        lcf.plot_models_curves(ax, models_dic, lc_types=lc_types, ylim=(-10, -19), lw=3)
        plt.legend()
        plt.show()