def simulation(parameter_filename, simulation_type, find_outputs=False):
    """Load a simulation time-series object of the given simulation type.

    Parameters
    ----------
    parameter_filename : str
        Path to the simulation parameter file.  If it does not exist as
        given, it is looked up relative to yt's ``test_data_dir`` setting.
    simulation_type : str
        Key into ``simulation_time_series_registry``.
    find_outputs : bool, optional
        Forwarded to the simulation class constructor.

    Raises
    ------
    YTSimulationNotIdentified
        If ``simulation_type`` is not registered.
    YTOutputNotIdentified
        If the parameter file cannot be found in either location.
    """
    if simulation_type not in simulation_time_series_registry:
        raise YTSimulationNotIdentified(simulation_type)

    if not os.path.exists(parameter_filename):
        fallback = os.path.join(ytcfg.get("yt", "test_data_dir"),
                                parameter_filename)
        if not os.path.exists(fallback):
            raise YTOutputNotIdentified((parameter_filename, simulation_type),
                                        dict(find_outputs=find_outputs))
        parameter_filename = fallback

    sim_class = simulation_time_series_registry[simulation_type]
    return sim_class(parameter_filename, find_outputs=find_outputs)
def upload(self):
    """Upload this object's metadata and binary chunks to the yt hub.

    Reads ``hub_api_key`` and ``hub_url`` from the yt config.  Metadata and
    chunk sizes are POSTed first.  Each chunk returned by
    ``self._generate_post()`` is then saved with ``np.save`` and POSTed to
    the per-upload handler URL.  Finally a ``status=FINAL`` message is sent.

    Returns
    -------
    dict or None
        The decoded JSON response of the final request.  Returns None if the
        server rejects the API key with HTTP 401.

    Raises
    ------
    YTHubRegisterError
        If no API key is configured.
    urllib.error.HTTPError
        For any HTTP error other than 401 on the initial request.
    """
    import contextlib

    def _fetch(request):
        # Read the whole response and always close the connection.
        with contextlib.closing(urllib.request.urlopen(request)) as response:
            return response.read()

    api_key = ytcfg.get("yt", "hub_api_key")
    url = ytcfg.get("yt", "hub_url")
    if api_key == '':
        raise YTHubRegisterError
    metadata, (final_name, chunks) = self._generate_post()
    if hasattr(self, "_ds_mrep"):
        self._ds_mrep.upload()
    # Make the metadata JSON-serialisable.  Only values are replaced, so
    # mutating the dict while iterating its items is safe.
    for key, value in metadata.items():
        if isinstance(value, np.ndarray):
            metadata[key] = value.tolist()
        elif hasattr(value, 'dtype'):
            # np.asscalar was removed in NumPy 1.23; .item() is its replacement.
            metadata[key] = value.item()
    metadata['obj_type'] = self.type
    if len(chunks) == 0:
        chunk_info = {'chunks': []}
    else:
        chunk_info = {'final_name': final_name, 'chunks': []}
    for cn, cv in chunks:
        chunk_info['chunks'].append((cn, cv.size * cv.itemsize))
    metadata = json.dumps(metadata)
    chunk_info = json.dumps(chunk_info)
    datagen, headers = multipart_encode({'metadata': metadata,
                                         'chunk_info': chunk_info,
                                         'api_key': api_key})
    request = urllib.request.Request(url, datagen, headers)
    # Actually do the request, and get the response
    try:
        rv = _fetch(request)
    except urllib.error.HTTPError as ex:
        if ex.code == 401:
            mylog.error("You must create an API key before uploading.")
            mylog.error("https://data.yt-project.org/getting_started.html")
            return
        raise
    uploader_info = json.loads(rv)
    new_url = url + "/handler/%s" % uploader_info['handler_uuid']
    for i, (cn, cv) in enumerate(chunks):
        # The multipart generator reads the file lazily during the request,
        # so the temporary file must stay open until the upload completes.
        with TemporaryFile() as f:
            np.save(f, cv)
            f.seek(0)
            pbar = UploaderBar("%s, % 2i/% 2i" % (self.type, i + 1, len(chunks)))
            datagen, headers = multipart_encode({'chunk_data': f}, cb=pbar)
            request = urllib.request.Request(new_url, datagen, headers)
            _fetch(request)

    datagen, headers = multipart_encode({'status': 'FINAL'})
    request = urllib.request.Request(new_url, datagen, headers)
    rv = json.loads(_fetch(request))
    mylog.info("Upload succeeded! View here: %s", rv['url'])
    return rv
def __init__(self, ds, bcdir="", model="chabrier", time_now=None,
             star_filter=None):
    """Set up the object for dataset ``ds`` and read its tables.

    Parameters
    ----------
    ds : Dataset
        Source of the cosmology parameters and of the default "now" time.
    bcdir : str, optional
        Directory holding the tables loaded by ``read_bclib``.  If it is not a
        directory as given, it is looked up relative to yt's
        ``test_data_dir`` setting.
    model : {"chabrier", "salpeter"}, optional
        Selects the module-level ``CHABRIER`` or ``SALPETER`` constant.
    time_now : optional
        Time treated as "now"; defaults to ``ds.current_time``.
    star_filter : optional
        Stored as ``self._filter``; ``filter_provided`` records whether one
        was given.

    Raises
    ------
    RuntimeError
        If ``bcdir`` cannot be located.
    ValueError
        If ``model`` is not one of the supported names.  Previously
        ``self.model`` was silently left unset in that case.
    """
    self._ds = ds
    if not os.path.isdir(bcdir):
        bcdir = os.path.join(ytcfg.get("yt", "test_data_dir"), bcdir)
        if not os.path.isdir(bcdir):
            raise RuntimeError("Failed to locate %s" % bcdir)
    self.bcdir = bcdir
    self._filter = star_filter
    self.filter_provided = self._filter is not None
    if model == "chabrier":
        self.model = CHABRIER
    elif model == "salpeter":
        self.model = SALPETER
    else:
        raise ValueError("Unknown model %r; expected 'chabrier' or "
                         "'salpeter'." % (model,))
    # Set up for time conversion.
    self.cosm = Cosmology(
        hubble_constant=self._ds.hubble_constant,
        omega_matter=self._ds.omega_matter,
        omega_lambda=self._ds.omega_lambda)
    # Find the time right now.
    if time_now is None:
        self.time_now = self._ds.current_time
    else:
        self.time_now = time_now
    # Read the tables.
    self.read_bclib()
def __init__(self, path=None):
    """Open the Enzo run database.

    Parameters
    ----------
    path : str, optional
        Path to the SQLite database.  When omitted, the yt config value
        ``enzo_db`` is used.

    Raises
    ------
    RuntimeError
        If ``path`` is omitted and ``enzo_db`` is not configured.
    """
    if path is None:
        path = ytcfg.get("yt", "enzo_db")
        if len(path) == 0:
            raise RuntimeError(
                "No Enzo run database path given and the 'enzo_db' "
                "configuration option is empty.")
    import sqlite3
    self.conn = sqlite3.connect(path)
def __init__(self, outputs, indices, fields=None, suppress_logging=False):
    """Collect the trajectories of the particles in ``indices`` over ``outputs``.

    Parameters
    ----------
    outputs : DatasetSeries or sequence
        Wrapped in a ``DatasetSeries`` unless it already is one.
    indices : array-like
        Particle indices to follow.  Sorted in place.
    fields : str or iterable of str, optional
        Extra fields to load.  The three ``particle_position_*`` fields are
        always added.  The caller's container is not modified.
    suppress_logging : bool, optional
        If True, the logger level is set to 40 while data is loaded.  It is
        reset from the ``loglevel`` config value afterwards, even on error.
    """
    indices.sort()  # Just in case the caller wasn't careful
    self.field_data = YTFieldData()
    if isinstance(outputs, DatasetSeries):
        self.data_series = outputs
    else:
        self.data_series = DatasetSeries(outputs)
    self.masks = []
    self.sorts = []
    self.array_indices = []
    self.indices = indices
    self.num_indices = len(indices)
    self.num_steps = len(outputs)
    self.times = []
    self.suppress_logging = suppress_logging

    # Default fields.  Work on a copy so the caller's list is not extended
    # with the position fields as a side effect.
    if fields is None:
        fields = []
    elif isinstance(fields, str):
        fields = [fields]
    else:
        fields = list(fields)
    fields.extend(["particle_position_x",
                   "particle_position_y",
                   "particle_position_z"])
    # Drop duplicates while keeping first-seen order.
    fields = list(OrderedDict.fromkeys(fields))

    if self.suppress_logging:
        old_level = int(ytcfg.get("yt", "loglevel"))
        mylog.setLevel(40)
    my_storage = {}
    try:
        pbar = get_pbar("Constructing trajectory information",
                        len(self.data_series))
        for i, (sto, ds) in enumerate(self.data_series.piter(storage=my_storage)):
            dd = ds.all_data()
            idx_field = dd._determine_fields("particle_index")[0]
            newtags = dd[idx_field].ndarray_view().astype("int64")
            # Which particles in this output are tracked, and how to order
            # them to match the sorted ``indices``.
            mask = np.isin(newtags, indices, assume_unique=True)
            sorts = np.argsort(newtags[mask])
            self.array_indices.append(
                np.where(np.isin(indices, newtags, assume_unique=True))[0])
            self.masks.append(mask)
            self.sorts.append(sorts)
            sto.result_id = ds.parameter_filename
            sto.result = ds.current_time
            pbar.update(i)
        pbar.finish()
    finally:
        # Restore the log level even if loading fails part-way.
        if self.suppress_logging:
            mylog.setLevel(old_level)

    times = [time for _fn, time in sorted(my_storage.items())]
    self.times = self.data_series[0].arr(times, times[0].units)

    self.particle_fields = []
    # Instantiate fields the caller requested
    for field in fields:
        self._get_data(field)
def _get_data(self, field):
    """
    Get a field to include in the trajectory collection.
    The trajectory collection itself is a dict of 2D numpy arrays,
    with shape (num_indices, num_steps)

    The result is cached in ``self.field_data``; later calls for the same
    ``field`` return the cached array without reloading anything.
    Entries for particles absent from an output stay NaN.
    """
    if field not in self.field_data:
        if self.suppress_logging:
            old_level = int(ytcfg.get("yt","loglevel"))
            mylog.setLevel(40)
        ds_first = self.data_series[0]
        dd_first = ds_first.all_data()
        fd = dd_first._determine_fields(field)[0]
        # Remember whether this is a particle field; that decides below
        # between direct particle lookup and sampling from grids.
        if field not in self.particle_fields:
            if self.data_series[0].field_info[fd].particle_type:
                self.particle_fields.append(field)
        particles = np.empty((self.num_indices,self.num_steps))
        particles[:] = np.nan
        # NOTE(review): ``step`` counts iterations in this process; if piter
        # distributes outputs over several processors it may not equal the
        # global output index -- confirm before relying on it in parallel.
        step = int(0)
        pbar = get_pbar("Generating field %s in trajectories." % (field), self.num_steps)
        my_storage={}
        for i, (sto, ds) in enumerate(self.data_series.piter(storage=my_storage)):
            mask = self.masks[i]
            sort = self.sorts[i]
            if field in self.particle_fields:
                # This is easy... just get the particle fields
                dd = ds.all_data()
                pfield = dd[fd].ndarray_view()[mask][sort]
            else:
                # This is hard... must loop over grids
                # NOTE(review): here pfield has length num_indices, but it is
                # stored below with self.array_indices[i], which only covers the
                # particles present in this output; the assignment
                # ``particles[indices,i] = pfield`` would then raise a shape
                # mismatch if any particle is missing -- confirm.
                pfield = np.zeros((self.num_indices))
                x = self["particle_position_x"][:,step].ndarray_view()
                y = self["particle_position_y"][:,step].ndarray_view()
                z = self["particle_position_z"][:,step].ndarray_view()
                # This will fail for non-grid index objects
                particle_grids, particle_grid_inds = ds.index._find_points(x,y,z)
                for grid in particle_grids:
                    cube = grid.retrieve_ghost_zones(1, [fd])
                    # Cloud-in-cell sampling of the grid field at the
                    # particle positions, accumulated into pfield.
                    CICSample_3(x,y,z,pfield,
                                self.num_indices,
                                cube[fd],
                                np.array(grid.LeftEdge).astype(np.float64),
                                np.array(grid.ActiveDimensions).astype(np.int32),
                                grid.dds[0])
            sto.result_id = ds.parameter_filename
            sto.result = (self.array_indices[i], pfield)
            pbar.update(step)
            step += 1
        pbar.finish()
        # Outputs are sorted by parameter filename, so column i corresponds
        # to the i-th output in that order.
        for i, (fn, (indices, pfield)) in enumerate(sorted(my_storage.items())):
            particles[indices,i] = pfield
        self.field_data[field] = array_like_field(dd_first, particles, fd)
        if self.suppress_logging:
            mylog.setLevel(old_level)
    return self.field_data[field]
def requires_file(req_file):
    """Return a test decorator that disables the test if ``req_file`` is missing.

    The file is looked for as given and then relative to yt's
    ``test_data_dir``.  If it is found, the decorator returns the test
    unchanged.  Otherwise it replaces the test with a no-argument function
    that does nothing.
    """
    data_dir = ytcfg.get("yt", "test_data_dir")

    def keep_test(func):
        return func

    def drop_test(func):
        return lambda: None

    found = (os.path.exists(req_file)
             or os.path.exists(os.path.join(data_dir, req_file)))
    return keep_test if found else drop_test
def data_dir_load(ds_fn, cls=None, args=None, kwargs=None):
    """Load ``ds_fn`` from inside yt's ``test_data_dir``.

    Parameters
    ----------
    ds_fn : str or Dataset
        Dataset path relative to the test data directory.  An already-loaded
        ``Dataset`` is returned unchanged.
    cls : callable, optional
        Constructor to use instead of ``load``.
    args, kwargs : optional
        Extra positional / keyword arguments for the loader.

    Returns
    -------
    Dataset or False
        The loaded dataset, or False if the test data directory does not exist.
    """
    args = args or ()
    kwargs = kwargs or {}
    data_dir = ytcfg.get("yt", "test_data_dir")
    if isinstance(ds_fn, Dataset):
        return ds_fn
    if not os.path.isdir(data_dir):
        return False
    loader = load if cls is None else cls
    with temp_cwd(data_dir):
        dataset = loader(ds_fn, *args, **kwargs)
        # Accessed only for its side effect; presumably this builds the
        # index while still inside the data directory.
        dataset.index
        return dataset
def can_run_sim(sim_fn, sim_type, file_check=False):
    """Report whether an answer test for simulation ``sim_fn`` can run.

    A ``SimulationTimeSeries`` instance is always considered loadable.
    Otherwise yt's ``test_data_dir`` must exist.  Inside it, ``sim_fn`` must
    be an existing file (``file_check=True``) or loadable via ``simulation``
    without raising ``YTOutputNotIdentified``.  In every passing case, answer
    result storage must also be configured.
    """
    if not isinstance(sim_fn, SimulationTimeSeries):
        data_dir = ytcfg.get("yt", "test_data_dir")
        if not os.path.isdir(data_dir):
            return False
        with temp_cwd(data_dir):
            if file_check:
                if not os.path.isfile(sim_fn):
                    return False
            else:
                try:
                    simulation(sim_fn, sim_type)
                except YTOutputNotIdentified:
                    return False
    return AnswerTestingTest.result_storage is not None
def can_run_ds(ds_fn, file_check=False):
    """Report whether an answer test for dataset ``ds_fn`` can run.

    A ``Dataset`` instance is always considered loadable.  Otherwise yt's
    ``test_data_dir`` must exist.  Inside it, ``ds_fn`` must be an existing
    file (``file_check=True``) or loadable via ``load`` without raising
    ``YTOutputNotIdentified``.  In every passing case, answer result storage
    must also be configured.
    """
    if not isinstance(ds_fn, Dataset):
        data_dir = ytcfg.get("yt", "test_data_dir")
        if not os.path.isdir(data_dir):
            return False
        with temp_cwd(data_dir):
            if file_check:
                if not os.path.isfile(ds_fn):
                    return False
            else:
                try:
                    load(ds_fn)
                except YTOutputNotIdentified:
                    return False
    return AnswerTestingTest.result_storage is not None
def __call__(self, args):
    """Start an IPython notebook server.

    The password comes from the ``notebook_password`` config value.  When
    that is empty and ``args.no_password`` is false, the user is prompted
    via ``IPython.lib.passwd()``.  ``args.no_password`` disables the password
    entirely.  ``args.port`` (if non-zero), ``args.profile`` (if not None)
    and ``args.open_browser`` are forwarded to ``NotebookApp``.
    """
    try:
        # IPython 1.0+
        from IPython.html.notebookapp import NotebookApp
    except ImportError:
        # pre-IPython v1.0
        from IPython.frontend.html.notebook.notebookapp import NotebookApp

    print("You must choose a password so that others cannot connect to "
          "your notebook.")
    password = ytcfg.get("yt", "notebook_password")
    if len(password) == 0 and not args.no_password:
        import IPython.lib
        password = IPython.lib.passwd()
        hint_lines = (
            "If you would like to use this password in the future,",
            "place a line like this inside the [yt] section in your",
            "yt configuration file at ~/.yt/config",
            "",
            "notebook_password = %s" % password,
            "",
        )
        for line in hint_lines:
            print(line)
    elif args.no_password:
        password = None

    app_kwargs = {}
    if args.port != 0:
        app_kwargs['port'] = int(args.port)
    if args.profile is not None:
        app_kwargs['profile'] = args.profile
    if password is not None:
        app_kwargs['password'] = password
    app = NotebookApp(open_browser=args.open_browser, **app_kwargs)
    app.initialize(argv=[])

    rule = "***************************************************************"
    port = app.port
    banner = (
        "",
        rule,
        "",
        "The notebook is now live at:",
        "",
        " http://127.0.0.1:%s/" % port,
        "",
        "Recall you can create a new SSH tunnel dynamically by pressing",
        "~C and then typing -L%s:localhost:%s" % (port, port),
        "where the first number is the port on your local machine. ",
        "",
        "If you are using %s on your machine already, try "
        "-L8889:localhost:%s" % (port, port),
        "",
        rule,
        "",
    )
    for line in banner:
        print(line)
    app.start()
def enable_plugins():
    """Execute the user's yt plugin file, if one exists.

    The file name comes from the ``pluginfilename`` config value.  If that
    is not an existing file as given, it is taken relative to ``~/.yt``.
    The plugin runs with a copy of the ``yt`` namespace, with ``add_field``
    bound to ``my_plugins_fields.add_field``.  Nothing happens if no plugin
    file is found.
    """
    import yt
    from yt.fields.my_plugin_fields import my_plugins_fields
    from yt.config import ytcfg
    plugin_name = ytcfg.get("yt", "pluginfilename")
    # We assume that it is with respect to the $HOME/.yt directory
    plugin_path = (plugin_name if os.path.isfile(plugin_name)
                   else os.path.expanduser("~/.yt/%s" % plugin_name))
    if not os.path.isfile(plugin_path):
        return
    mylog.info("Loading plugins from %s", plugin_path)
    namespace = dict(yt.__dict__)
    namespace['add_field'] = my_plugins_fields.add_field
    with open(plugin_path) as plugin_file:
        compiled = compile(plugin_file.read(), plugin_path, 'exec')
        exec(compiled, namespace)
def _get_owls_ion_data_dir(self):
    """Return the path of the ``owls_ion_data`` directory, downloading it if needed.

    The data lives under yt's ``test_data_dir``, or the current directory
    when that setting is the placeholder ``"/does/not/exist"``.  If the
    directory is missing, ``owls_ion_data.tar.gz`` is downloaded there and
    unpacked with ``tarfile``.  Earlier code used a shell ``cd ...; tar xf``
    command, which broke on paths containing spaces or shell metacharacters.

    Raises
    ------
    RuntimeError
        If the directory still does not exist afterwards, or the downloaded
        archive cannot be read.
    """
    import tarfile

    txt = "Attempting to download ~ 30 Mb of owls ion data from %s to %s."
    data_file = "owls_ion_data.tar.gz"
    data_url = "http://yt-project.org/data"

    # get test_data_dir from yt config (ytcfg)
    tdir = ytcfg.get("yt", "test_data_dir")

    # set download destination to tdir or ./ if tdir isnt defined
    if tdir == "/does/not/exist":
        data_dir = "./"
    else:
        data_dir = tdir

    # check for owls_ion_data directory in data_dir
    # if not there download the tarball and untar it
    owls_ion_path = os.path.join(data_dir, "owls_ion_data")

    if not os.path.exists(owls_ion_path):
        mylog.info(txt, data_url, data_dir)
        fname = os.path.join(data_dir, data_file)
        # URLs always use "/", whatever the host OS's path separator is.
        download_file(data_url + "/" + data_file, fname)
        try:
            with tarfile.open(fname) as archive:
                # The archive comes from the network; use the safe "data"
                # extraction filter where this Python supports it.
                if hasattr(tarfile, "data_filter"):
                    archive.extractall(data_dir, filter="data")
                else:
                    archive.extractall(data_dir)
        except tarfile.TarError as err:
            raise RuntimeError("Failed to download owls ion data.") from err

    if not os.path.exists(owls_ion_path):
        raise RuntimeError("Failed to download owls ion data.")

    return owls_ion_path
def prep_dirs():
    """Symlink every entry of yt's ``test_data_dir`` into the current directory.

    Entries whose link name already exists here are skipped, so repeated
    calls no longer fail with FileExistsError.
    """
    test_data_dir = ytcfg.get("yt", "test_data_dir")
    # Escape the directory so glob metacharacters in it ("[", "*", "?") are
    # matched literally; only the trailing "*" acts as a wildcard.
    pattern = os.path.join(glob.escape(test_data_dir), "*")
    for directory in glob.glob(pattern):
        link_name = os.path.basename(directory)
        if os.path.lexists(link_name):
            continue
        os.symlink(directory, link_name)
# # The full license is in the file COPYING.txt, distributed with this software. #----------------------------------------------------------------------------- from yt.testing import * from yt.config import ytcfg from yt.analysis_modules.photon_simulator.api import * from yt.utilities.answer_testing.framework import requires_ds, \ GenericArrayTest, data_dir_load import numpy as np def setup(): from yt.config import ytcfg ytcfg["yt", "__withintesting"] = "True" test_dir = ytcfg.get("yt", "test_data_dir") ETC = test_dir+"/enzo_tiny_cosmology/DD0046/DD0046" APEC = test_dir+"/xray_data/atomdb_v2.0.2" TBABS = test_dir+"/xray_data/tbabs_table.h5" ARF = test_dir+"/xray_data/chandra_ACIS-S3_onaxis_arf.fits" RMF = test_dir+"/xray_data/chandra_ACIS-S3_onaxis_rmf.fits" @requires_ds(ETC) @requires_file(APEC) @requires_file(TBABS) @requires_file(ARF) @requires_file(RMF) def test_etc(): np.random.seed(seed=0x4d3d3d3)
from yt.convenience import load, simulation from yt.config import ytcfg from yt.data_objects.static_output import Dataset from yt.data_objects.time_series import SimulationTimeSeries from yt.utilities.logger import disable_stream_logging from yt.utilities.command_line import get_yt_version import matplotlib.image as mpimg import yt.visualization.plot_window as pw import yt.extern.progressbar as progressbar mylog = logging.getLogger('nose.plugins.answer-testing') run_big_data = False # Set the latest gold and local standard filenames _latest = ytcfg.get("yt", "gold_standard_filename") _latest_local = ytcfg.get("yt", "local_standard_filename") _url_path = ytcfg.get("yt", "answer_tests_url") class AnswerTesting(Plugin): name = "answer-testing" _my_version = None def options(self, parser, env=os.environ): super(AnswerTesting, self).options(parser, env=env) parser.add_option("--answer-name", dest="answer_name", metavar='str', default=None, help="The name of the standard to store/compare against") parser.add_option("--answer-store", dest="store_results", metavar='bool', default=False, action="store_true", help="Should we store this result instead of comparing?") parser.add_option("--local", dest="local_results",
def _get_db_name(self):
    """Return the path of the parameter-file store.

    The base name comes from the ``ParameterFileStore`` config value.  When
    the home directory is writable the store lives in ``~/.yt``; otherwise
    the name is resolved against the current working directory.
    """
    store_name = ytcfg.get("yt", "ParameterFileStore")
    home_is_writable = os.access(os.path.expanduser("~/"), os.W_OK)
    if home_is_writable:
        return os.path.expanduser("~/.yt/%s" % store_name)
    return os.path.abspath(store_name)
def __call__(self, args):
    """Interactively register a new user on the yt hub.

    Prompts for username, name, email, password (entered twice) and an
    optional URL.  The data is then POSTed to the hub's ``create_user``
    endpoint.  Exits with status 1 if a required answer is empty or the hub
    rejects the registration with HTTP 400.  Any other HTTP error is
    re-raised rather than reported as success.
    """
    # We need these pieces of information:
    #   1. Name
    #   2. Email
    #   3. Username
    #   4. Password (and password2)
    #   5. (optional) URL
    #   6. "Secret" key to make it epsilon harder for spammers
    if ytcfg.get("yt", "hub_api_key") != "":
        print("You seem to already have an API key for the hub in")
        print("~/.yt/config . Delete this if you want to force a")
        print("new user registration.")
    print("Awesome! Let's start by registering a new user for you.")
    print("Here's the URL, for reference: http://hub.yt-project.org/ ")
    print()
    print("As always, bail out with Ctrl-C at any time.")
    print()
    print("What username would you like to go by?")
    print()
    username = raw_input("Username? ")
    if len(username) == 0:
        sys.exit(1)
    print()
    print("To start out, what's your name?")
    print()
    name = raw_input("Name? ")
    if len(name) == 0:
        sys.exit(1)
    print()
    print("And your email address?")
    print()
    email = raw_input("Email? ")
    if len(email) == 0:
        sys.exit(1)
    print()
    print("Please choose a password:")
    print()
    # Keep asking until a non-empty password is entered identically twice.
    while True:
        password1 = getpass.getpass("Password? ")
        password2 = getpass.getpass("Confirm? ")
        if len(password1) == 0:
            continue
        if password1 == password2:
            break
        print("Sorry, they didn't match! Let's try again.")
        print()
    print()
    print("Would you like a URL displayed for your user?")
    print("Leave blank if no.")
    print()
    url = raw_input("URL? ")
    print()
    print("Okay, press enter to register. You should receive a welcome")
    print("message at %s when this is complete." % email)
    print()
    raw_input()  # wait for the user to press enter
    data = dict(name=name, email=email, username=username,
                password=password1, password2=password2,
                url=url, zap="rowsdower")
    # POST bodies must be bytes for urllib.request on Python 3.
    data = urllib.parse.urlencode(data).encode("utf-8")
    hub_url = "https://hub.yt-project.org/create_user"
    req = urllib.request.Request(hub_url, data)
    try:
        urllib.request.urlopen(req).read()
    except urllib.error.HTTPError as exc:
        if exc.code == 400:
            print("Sorry, the Hub couldn't create your user.")
            print("You can't register duplicate users, which is the most")
            print("common cause of this error. All values for username,")
            print("name, and email must be unique in our system.")
            sys.exit(1)
        # Any other HTTP failure must not fall through to "SUCCESS!".
        raise
    except urllib.error.URLError:
        print("Something has gone wrong. Here's the error message.")
        raise
    print()
    print("SUCCESS!")
    print()
# '{mol}_surf'.format(mol=mol, level=level, # transparency=transparency)) #surf.export_obj(filename, transparency=transparency, # color_field=ytcubes[mol].dataset.field_list[0], # #color_map=colors[mol], # plot_index=ii) filename = os.path.join('IsoSurfs', 'all_surfs') surf.export_obj(filename, transparency=transparency, color_field='ones', #ytcubes[mol].dataset.field_list[0], color_map=colornames[mol], plot_index=jj+ii*len(surfaces)) import zipfile zfn = 'IsoSurfs/all_surfs12.zip' zf = zipfile.ZipFile(zfn, mode='w') zf.write('IsoSurfs/all_surfs.obj') zf.write('IsoSurfs/all_surfs.mtl') zf.close() from yt.config import ytcfg api_key = ytcfg.get("yt","sketchfab_api_key") import requests import os data = {'title': 'Sgr B2 meshes colored (try 5)', 'token': api_key, 'fileModel': zfn, 'filenameModel': os.path.basename(zfn)} response = requests.post('https://api.sketchfab.com/v1/models', files=data) response.raise_for_status()
def load(*args, **kwargs):
    """Load a dataset, guessing its frontend from the arguments.

    String arguments have ``~`` expanded.  A string that is not an existing
    path or an ``http`` URL is looked up relative to yt's ``test_data_dir``.
    Every registered output type's ``_is_valid`` is then called.  Of the
    matching classes, only the most specialised (lowest) subclasses are kept.
    When called with no arguments, a Tk file dialog is shown to choose a file.

    Returns
    -------
    Dataset or DatasetSeries
        An instance of the single matching frontend class.  If no argument is
        a valid file, the result of ``DatasetSeries.from_filenames`` is
        returned when that succeeds.

    Raises
    ------
    YTOutputNotIdentified
        If no file is chosen/valid, or zero or several frontends match.
    """
    if len(args) == 0:
        try:
            # six maps these to Tkinter/tkFileDialog on Python 2 and to
            # tkinter/tkinter.filedialog on Python 3.  The old bare
            # ``import tkFileDialog`` always failed on Python 3.
            from yt.extern.six.moves import tkinter
            from yt.extern.six.moves import tkinter_tkfiledialog as filedialog
        except ImportError:
            raise YTOutputNotIdentified(args, kwargs)
        root = tkinter.Tk()
        try:
            filename = filedialog.askopenfilename(parent=root,
                                                  title='Choose a file')
        finally:
            root.destroy()
        # A cancelled dialog returns an empty value, not None.
        if filename:
            return load(filename)
        raise YTOutputNotIdentified(args, kwargs)

    args = [os.path.expanduser(arg) if isinstance(arg, str) else arg
            for arg in args]
    test_data_dir = ytcfg.get("yt", "test_data_dir")
    valid_file = []
    for argno, arg in enumerate(args):
        if not isinstance(arg, str):
            valid_file.append(False)
        elif os.path.exists(arg) or arg.startswith("http"):
            valid_file.append(True)
        else:
            fallback = os.path.join(test_data_dir, arg)
            if os.path.exists(fallback):
                valid_file.append(True)
                args[argno] = fallback
            else:
                valid_file.append(False)

    if not any(valid_file):
        try:
            from yt.data_objects.time_series import DatasetSeries
            ts = DatasetSeries.from_filenames(*args, **kwargs)
            return ts
        except YTOutputNotIdentified:
            pass
        mylog.error("None of the arguments provided to load() is a valid file")
        mylog.error("Please check that you have used a correct path")
        raise YTOutputNotIdentified(args, kwargs)

    candidates = [cls for name, cls in output_type_registry.items()
                  if name is not None and cls._is_valid(*args, **kwargs)]
    # Find only the lowest subclasses, i.e. most specialised front ends
    candidates = find_lowest_subclasses(candidates)
    if len(candidates) == 1:
        return candidates[0](*args, **kwargs)
    if len(candidates) == 0:
        if ytcfg.get("yt", "enzo_db") != '' \
           and len(args) == 1 \
           and isinstance(args[0], str):
            erdb = EnzoRunDatabase()
            fn = erdb.find_uuid(args[0])
            n = "EnzoDataset"
            if n in output_type_registry \
               and output_type_registry[n]._is_valid(fn):
                return output_type_registry[n](fn)
        mylog.error("Couldn't figure out output type for %s", args[0])
        raise YTOutputNotIdentified(args, kwargs)

    mylog.error("Multiple output type candidates for %s:", args[0])
    for c in candidates:
        mylog.error("    Possible: %s", c)
    raise YTOutputNotIdentified(args, kwargs)
def write_image(self, filename, color_bounds=None, channel=None,
                cmap_name=None, func=lambda x: x):
    r"""Write the ImageArray, or one channel of it, to a png file.

    The array is transposed (axes 0 and 1 swapped) before being handed to
    the module-level ``write_image`` function.

    Parameters
    ----------
    filename : string
        Output file name.  ``.png`` is appended when the name does not
        already end with it.  A value of None is passed through unchanged.

    Other Parameters
    ----------------
    channel : int, optional
        Index along the last axis to write out.  When None (the default),
        the whole array is written.
    color_bounds : tuple of floats, optional
        The min and max to scale between.  Outlying values will be clipped.
    cmap_name : string, optional
        An acceptable colormap.  Defaults to the ``default_colormap`` yt
        config value.  See either yt.visualization.color_maps or
        http://www.scipy.org/Cookbook/Matplotlib/Show_colormaps .
    func : function, optional
        A function to transform the buffer before applying a colormap.

    Returns
    -------
    scaled_image : uint8 image that has been saved
        Whatever the module-level ``write_image`` returns.

    Examples
    --------

    >>> im = np.zeros([64,128])
    >>> for i in range(im.shape[0]):
    ...     im[i,:] = np.linspace(0.,0.3*i, im.shape[1])

    >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]),
    ...     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),
    ...     'width':0.245, 'units':'cm', 'type':'rendering'}

    >>> im_arr = ImageArray(im, info=myinfo)
    >>> im_arr.write_image('test_ImageArray.png')

    """
    if cmap_name is None:
        cmap_name = ytcfg.get("yt", "default_colormap")
    if filename is not None and not filename.endswith('.png'):
        filename += '.png'

    #TODO: Write info dict as png metadata
    image = self.swapaxes(0, 1)
    if channel is not None:
        image = image[:, :, channel]
    return write_image(image.to_ndarray(), filename,
                       color_bounds=color_bounds, cmap_name=cmap_name,
                       func=func)
def __init__(self, filename,
             dataset_type='fits',
             auxiliary_files=None,
             nprocs=None,
             storage_filename=None,
             nan_mask=None,
             spectral_factor=1.0,
             z_axis_decomp=False,
             suppress_astropy_warnings=True,
             parameters=None,
             units_override=None):
    """Open a FITS source (plus optional auxiliary sources) as a dataset.

    Parameters
    ----------
    filename : str, HDUList or image HDU
        Primary FITS source.  In-memory HDU objects get a generated
        ``InMemoryFITSFile_<hex>`` name for the dataset.
    dataset_type : str, optional
        Forwarded to ``Dataset.__init__``.
    auxiliary_files : str, HDU, HDUList or list of these, optional
        Extra FITS sources.  String paths that do not exist as given are
        looked up under yt's ``test_data_dir``.
    nprocs : optional
        Stored as ``self.specified_parameters["nprocs"]``.
    storage_filename : optional
        Stored on the instance after ``Dataset.__init__`` runs.
    nan_mask : None, real number or dict, optional
        None gives ``{}``; a number gives ``{"all": nan_mask}``; a dict is
        used as-is.
    spectral_factor, z_axis_decomp : optional
        Stored on the instance.
    suppress_astropy_warnings : bool, optional
        If True, add a process-wide "ignore" warnings filter for astropy.
    parameters : dict, optional
        Copied, so the caller's dict is not modified.  ``"reblock"`` is
        honoured for event files.
    units_override : optional
        Forwarded to ``Dataset.__init__``.

    Raises
    ------
    TypeError
        If ``nan_mask`` is not None, a real number, or a dict.
    """
    import numbers

    # Copy so the "nprocs" entry does not leak into the caller's dict.
    parameters = {} if parameters is None else dict(parameters)
    parameters["nprocs"] = nprocs
    self.specified_parameters = parameters

    self.z_axis_decomp = z_axis_decomp
    self.spectral_factor = spectral_factor

    if suppress_astropy_warnings:
        warnings.filterwarnings('ignore', module="astropy", append=True)
    if auxiliary_files is None:
        auxiliary_files = []
    auxiliary_files = ensure_list(auxiliary_files)
    self.filenames = [filename] + auxiliary_files
    self.num_files = len(self.filenames)
    self.fluid_types += ("fits",)
    if nan_mask is None:
        self.nan_mask = {}
    elif isinstance(nan_mask, dict):
        self.nan_mask = nan_mask
    elif isinstance(nan_mask, numbers.Real):
        self.nan_mask = {"all": nan_mask}
    else:
        raise TypeError("nan_mask must be None, a number or a dict, "
                        "not %r" % (type(nan_mask).__name__,))

    self._handle = FITSFileHandler(self.filenames[0])
    if isinstance(self.filenames[0],
                  (_astropy.pyfits.hdu.image._ImageBaseHDU,
                   _astropy.pyfits.HDUList)):
        fn = "InMemoryFITSFile_%s" % uuid.uuid4().hex
    else:
        fn = self.filenames[0]
    # NOTE(review): appends the handler to its own file list; presumably
    # FITSFileHandler relies on this -- confirm.
    self._handle._fits_files.append(self._handle)
    if self.num_files > 1:
        for fits_file in auxiliary_files:
            if isinstance(fits_file, _astropy.pyfits.hdu.image._ImageBaseHDU):
                f = _astropy.pyfits.HDUList([fits_file])
            elif isinstance(fits_file, _astropy.pyfits.HDUList):
                f = fits_file
            else:
                # Use a separate name: reusing ``fn`` here used to make the
                # dataset take the last auxiliary file's name.
                if os.path.exists(fits_file):
                    aux_fn = fits_file
                else:
                    aux_fn = os.path.join(ytcfg.get("yt", "test_data_dir"),
                                          fits_file)
                f = _astropy.pyfits.open(aux_fn, memmap=True,
                                         do_not_scale_image_data=True,
                                         ignore_blank=True)
            self._handle._fits_files.append(f)

    if len(self._handle) > 1 and self._handle[1].name == "EVENTS":
        # Event-list data: build a 2D WCS from the column (TTYPE*) keywords.
        self.events_data = True
        self.first_image = 1
        self.primary_header = self._handle[self.first_image].header
        self.naxis = 2
        self.wcs = _astropy.pywcs.WCS(naxis=2)
        self.events_info = {}
        for k, v in self.primary_header.items():
            if k.startswith("TTYP"):
                if v.lower() in ["x", "y"]:
                    num = k.strip("TTYPE")
                    self.events_info[v.lower()] = (self.primary_header["TLMIN"+num],
                                                   self.primary_header["TLMAX"+num],
                                                   self.primary_header["TCTYP"+num],
                                                   self.primary_header["TCRVL"+num],
                                                   self.primary_header["TCDLT"+num],
                                                   self.primary_header["TCRPX"+num])
                elif v.lower() in ["energy", "time"]:
                    num = k.strip("TTYPE")
                    unit = self.primary_header["TUNIT"+num].lower()
                    if unit.endswith("ev"):
                        unit = unit.replace("ev", "eV")
                    self.events_info[v.lower()] = unit
        self.axis_names = [self.events_info[ax][2] for ax in ["x", "y"]]
        self.reblock = self.specified_parameters.get("reblock", 1)
        self.wcs.wcs.cdelt = [self.events_info["x"][4]*self.reblock,
                              self.events_info["y"][4]*self.reblock]
        self.wcs.wcs.crpix = [(self.events_info["x"][5]-0.5)/self.reblock+0.5,
                              (self.events_info["y"][5]-0.5)/self.reblock+0.5]
        self.wcs.wcs.ctype = [self.events_info["x"][2], self.events_info["y"][2]]
        self.wcs.wcs.cunit = ["deg", "deg"]
        self.wcs.wcs.crval = [self.events_info["x"][3], self.events_info["y"][3]]
        self.dims = [(self.events_info["x"][1]-self.events_info["x"][0])/self.reblock,
                     (self.events_info["y"][1]-self.events_info["y"][0])/self.reblock]
    else:
        self.events_data = False
        # Sometimes the primary hdu doesn't have an image
        if len(self._handle) > 1 and self._handle[0].header["naxis"] == 0:
            self.first_image = 1
        else:
            self.first_image = 0
        self.primary_header = self._handle[self.first_image].header
        self.naxis = self.primary_header["naxis"]
        self.axis_names = [self.primary_header.get("ctype%d" % (i+1), "LINEAR")
                           for i in range(self.naxis)]
        self.dims = [self.primary_header["naxis%d" % (i+1)]
                     for i in range(self.naxis)]
        wcs = _astropy.pywcs.WCS(header=self.primary_header)
        if self.naxis == 4:
            # Keep only the first three axes of a 4D WCS.
            self.wcs = _astropy.pywcs.WCS(naxis=3)
            self.wcs.wcs.crpix = wcs.wcs.crpix[:3]
            self.wcs.wcs.cdelt = wcs.wcs.cdelt[:3]
            self.wcs.wcs.crval = wcs.wcs.crval[:3]
            self.wcs.wcs.cunit = [str(unit) for unit in wcs.wcs.cunit][:3]
            self.wcs.wcs.ctype = list(wcs.wcs.ctype)[:3]
        else:
            self.wcs = wcs

    self.refine_by = 2

    Dataset.__init__(self, fn, dataset_type, units_override=units_override)
    self.storage_filename = storage_filename
def __init__(self, outputs, indices, fields=None, suppress_logging=False,
             ptype=None):
    """Collect trajectories of particles ``indices`` of type ``ptype``.

    Parameters
    ----------
    outputs : DatasetSeries-like
        Indexable, sized series providing ``piter``.
    indices : array-like
        Particle indices to follow.  Sorted in place.
    fields : optional
        Extra fields, passed to ``self._get_data`` after the positions are
        loaded.
    suppress_logging : bool, optional
        If True, the logger level is set to 40 while data is loaded.  It is
        reset from the ``log_level`` config value afterwards, even on error.
    ptype : str, optional
        Particle type; defaults to ``"all"``.

    Raises
    ------
    YTIllDefinedParticleData
        If an output contains duplicate particle indices.
    """
    indices.sort()  # Just in case the caller wasn't careful
    self.field_data = YTFieldData()
    self.data_series = outputs
    self.masks = []
    self.sorts = []
    self.array_indices = []
    self.indices = indices
    self.num_indices = len(indices)
    self.num_steps = len(outputs)
    self.times = []
    self.suppress_logging = suppress_logging
    self.ptype = ptype if ptype else "all"

    if fields is None:
        fields = []

    position_fields = [f"particle_position_{ax}" for ax in "xyz"]

    if self.suppress_logging:
        old_level = int(ytcfg.get("yt", "log_level"))
        mylog.setLevel(40)
    my_storage = {}
    try:
        ds_first = self.data_series[0]
        dd_first = ds_first.all_data()

        fds = {}
        for field in ("particle_index", *position_fields):
            fds[field] = dd_first._determine_fields((self.ptype, field))[0]

        # Note: we explicitly pass dynamic=False to prevent any change in piter from
        # breaking the assumption that the same processors load the same datasets
        pbar = get_pbar("Constructing trajectory information",
                        len(self.data_series))
        for i, (sto, ds) in enumerate(
                self.data_series.piter(storage=my_storage, dynamic=False)):
            dd = ds.all_data()
            newtags = dd[fds["particle_index"]].d.astype("int64")
            mask = np.isin(newtags, indices, assume_unique=True)
            sort = np.argsort(newtags[mask])
            array_indices = np.where(
                np.isin(indices, newtags, assume_unique=True))[0]
            self.array_indices.append(array_indices)
            self.masks.append(mask)
            self.sorts.append(sort)

            pfields = {}
            for field in position_fields:
                pfields[field] = dd[fds[field]].ndarray_view()[mask][sort]

            sto.result_id = ds.parameter_filename
            sto.result = (ds.current_time, array_indices, pfields)
            pbar.update(i + 1)
        pbar.finish()
    finally:
        # Restore the log level even if loading fails part-way.
        if self.suppress_logging:
            mylog.setLevel(old_level)

    # Order outputs by parameter filename; express all times in the units
    # of the first one.
    sorted_storage = sorted(my_storage.items())
    _fn, (time, *_) = sorted_storage[0]
    time_units = time.units
    times = [time.to(time_units) for _fn, (time, *_) in sorted_storage]
    self.times = self.data_series[0].arr([time.value for time in times],
                                         time_units)

    self.particle_fields = []
    output_field = np.empty((self.num_indices, self.num_steps))
    output_field.fill(np.nan)
    for field in position_fields:
        for i, (_fn, (_time, step_indices, pfields)) in enumerate(sorted_storage):
            try:
                # This will fail if particles ids are
                # duplicate. This is due to the fact that the rhs
                # would then have a different shape as the lhs
                output_field[step_indices, i] = pfields[field]
            except ValueError as e:
                raise YTIllDefinedParticleData(
                    "This dataset contains duplicate particle indices!"
                ) from e
        # Every field writes the same (row, column) slots, so reusing the
        # buffer is safe; the copy decouples the stored arrays.
        self.field_data[field] = array_like_field(dd_first, output_field.copy(),
                                                  fds[field])
        self.particle_fields.append(field)

    # Instantiate fields the caller requested
    self._get_data(fields)