def simulation(parameter_filename, simulation_type, find_outputs=False):
    """
    Loads a simulation time series object of the specified
    simulation type.
    """

    if simulation_type not in simulation_time_series_registry:
        raise YTSimulationNotIdentified(simulation_type)

    if os.path.exists(parameter_filename):
        valid_file = True
    elif os.path.exists(os.path.join(ytcfg.get("yt", "test_data_dir"),
                                     parameter_filename)):
        parameter_filename = os.path.join(ytcfg.get("yt", "test_data_dir"),
                                          parameter_filename)
        valid_file = True
    else:
        valid_file = False

    if not valid_file:
        raise YTOutputNotIdentified((parameter_filename, simulation_type),
                                    dict(find_outputs=find_outputs))

    return simulation_time_series_registry[simulation_type](parameter_filename,
                                                            find_outputs=find_outputs)
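
A minimal usage sketch, assuming the "Enzo" type is registered and that the
parameter file (an illustrative path) exists directly or under test_data_dir:

# Hypothetical call: load an Enzo simulation's full time series.
my_sim = simulation("enzo_tiny_cosmology/32Mpc_32.enzo", "Enzo",
                    find_outputs=True)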
    def upload(self):
        api_key = ytcfg.get("yt", "hub_api_key")
        url = ytcfg.get("yt", "hub_url")
        if api_key == '':
            raise YTHubRegisterError
        metadata, (final_name, chunks) = self._generate_post()
        if hasattr(self, "_ds_mrep"):
            self._ds_mrep.upload()
        for i in metadata:
            if isinstance(metadata[i], np.ndarray):
                metadata[i] = metadata[i].tolist()
            elif hasattr(metadata[i], 'dtype'):
                metadata[i] = np.asscalar(metadata[i])
        metadata['obj_type'] = self.type
        if len(chunks) == 0:
            chunk_info = {'chunks': []}
        else:
            chunk_info = {'final_name': final_name, 'chunks': []}
            for cn, cv in chunks:
                chunk_info['chunks'].append((cn, cv.size * cv.itemsize))
        metadata = json.dumps(metadata)
        chunk_info = json.dumps(chunk_info)
        datagen, headers = multipart_encode({'metadata': metadata,
                                             'chunk_info': chunk_info,
                                             'api_key': api_key})
        request = urllib.request.Request(url, datagen, headers)
        # Actually do the request, and get the response
        try:
            rv = urllib.request.urlopen(request).read()
        except urllib.error.HTTPError as ex:
            if ex.code == 401:
                mylog.error("You must create an API key before uploading.")
                mylog.error("https://data.yt-project.org/getting_started.html")
                return
            else:
                raise ex
        uploader_info = json.loads(rv)
        new_url = url + "/handler/%s" % uploader_info['handler_uuid']
        for i, (cn, cv) in enumerate(chunks):
            remaining = cv.size * cv.itemsize
            f = TemporaryFile()
            np.save(f, cv)
            f.seek(0)
            pbar = UploaderBar("%s, % 2i/% 2i" %
                               (self.type, i + 1, len(chunks)))
            datagen, headers = multipart_encode({'chunk_data': f}, cb=pbar)
            request = urllib.request.Request(new_url, datagen, headers)
            rv = urllib.request.urlopen(request).read()

        datagen, headers = multipart_encode({'status': 'FINAL'})
        request = urllib.request.Request(new_url, datagen, headers)
        rv = json.loads(urllib.request.urlopen(request).read())
        mylog.info("Upload succeeded!  View here: %s", rv['url'])
        return rv
    def __init__(self, ds, bcdir="", model="chabrier", time_now=None,
                 star_filter=None):
        self._ds = ds
        if not os.path.isdir(bcdir):
            bcdir = os.path.join(ytcfg.get("yt", "test_data_dir"), bcdir)
            if not os.path.isdir(bcdir):
                raise RuntimeError("Failed to locate %s" % bcdir)
        self.bcdir = bcdir
        self._filter = star_filter
        self.filter_provided = self._filter is not None
        if model == "chabrier":
            self.model = CHABRIER
        elif model == "salpeter":
            self.model = SALPETER
        # Set up for time conversion.
        self.cosm = Cosmology(
            hubble_constant=self._ds.hubble_constant,
            omega_matter=self._ds.omega_matter,
            omega_lambda=self._ds.omega_lambda)
        # Find the time right now.

        if time_now is None:
            self.time_now = self._ds.current_time
        else:
            self.time_now = time_now

        # Read the tables.
        self.read_bclib()
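
A construction sketch; the enclosing class is assumed to be yt's star-analysis
SpectrumBuilder, and the bcdir name below is illustrative (it is resolved
against test_data_dir when it is not already a directory):

# Hypothetical: ds is an already-loaded dataset; "bc03" names a stellar
# population table directory that may live under test_data_dir.
spec = SpectrumBuilder(ds, bcdir="bc03", model="salpeter")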
    def __init__(self, path=None):
        if path is None:
            path = ytcfg.get("yt", "enzo_db")
            if len(path) == 0:
                raise RuntimeError
        import sqlite3
        self.conn = sqlite3.connect(path)
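
find_uuid (used by load() later in this listing) is assumed to be defined on
the same class; a lookup sketch with a placeholder UUID:

erdb = EnzoRunDatabase()
fn = erdb.find_uuid("00000000-0000-0000-0000-000000000000")  # placeholder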
    def __init__(self, outputs, indices, fields=None, suppress_logging=False):

        indices.sort() # Just in case the caller wasn't careful
        self.field_data = YTFieldData()
        if isinstance(outputs, DatasetSeries):
            self.data_series = outputs
        else:
            self.data_series = DatasetSeries(outputs)
        self.masks = []
        self.sorts = []
        self.array_indices = []
        self.indices = indices
        self.num_indices = len(indices)
        self.num_steps = len(outputs)
        self.times = []
        self.suppress_logging = suppress_logging

        # Default fields

        if fields is None:
            fields = []
        fields.append("particle_position_x")
        fields.append("particle_position_y")
        fields.append("particle_position_z")
        fields = list(OrderedDict.fromkeys(fields))

        if self.suppress_logging:
            old_level = int(ytcfg.get("yt","loglevel"))
            mylog.setLevel(40)
        my_storage = {}
        pbar = get_pbar("Constructing trajectory information", len(self.data_series))
        for i, (sto, ds) in enumerate(self.data_series.piter(storage=my_storage)):
            dd = ds.all_data()
            idx_field = dd._determine_fields("particle_index")[0]
            newtags = dd[idx_field].ndarray_view().astype("int64")
            mask = np.in1d(newtags, indices, assume_unique=True)
            sorts = np.argsort(newtags[mask])
            self.array_indices.append(np.where(np.in1d(indices, newtags, assume_unique=True))[0])
            self.masks.append(mask)
            self.sorts.append(sorts)
            sto.result_id = ds.parameter_filename
            sto.result = ds.current_time
            pbar.update(i)
        pbar.finish()

        if self.suppress_logging:
            mylog.setLevel(old_level)

        times = []
        for fn, time in sorted(my_storage.items()):
            times.append(time)

        self.times = self.data_series[0].arr([time for time in times], times[0].units)

        self.particle_fields = []

        # Instantiate fields the caller requested

        for field in fields:
            self._get_data(field)
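
A usage sketch, assuming the enclosing class is yt's ParticleTrajectories; the
filename glob and index range are illustrative:

import numpy as np
ts = DatasetSeries.from_filenames("DD????/DD????")  # illustrative glob
trajs = ParticleTrajectories(ts, indices=np.arange(1000),
                             fields=["particle_mass"])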
    def _get_data(self, field):
        """
        Get a field to include in the trajectory collection.
        The trajectory collection itself is a dict of 2D numpy arrays,
        with shape (num_indices, num_steps).
        """
        if field not in self.field_data:
            if self.suppress_logging:
                old_level = int(ytcfg.get("yt", "loglevel"))
                mylog.setLevel(40)
            ds_first = self.data_series[0]
            dd_first = ds_first.all_data()
            fd = dd_first._determine_fields(field)[0]
            if field not in self.particle_fields:
                if self.data_series[0].field_info[fd].particle_type:
                    self.particle_fields.append(field)
            particles = np.empty((self.num_indices, self.num_steps))
            particles[:] = np.nan
            step = 0
            pbar = get_pbar("Generating field %s in trajectories." % field,
                            self.num_steps)
            my_storage = {}
            for i, (sto, ds) in enumerate(self.data_series.piter(storage=my_storage)):
                mask = self.masks[i]
                sort = self.sorts[i]
                if field in self.particle_fields:
                    # This is easy... just get the particle fields
                    dd = ds.all_data()
                    pfield = dd[fd].ndarray_view()[mask][sort]
                else:
                    # This is hard... must loop over grids
                    pfield = np.zeros(self.num_indices)
                    x = self["particle_position_x"][:, step].ndarray_view()
                    y = self["particle_position_y"][:, step].ndarray_view()
                    z = self["particle_position_z"][:, step].ndarray_view()
                    # This will fail for non-grid index objects
                    particle_grids, particle_grid_inds = ds.index._find_points(x, y, z)
                    for grid in particle_grids:
                        cube = grid.retrieve_ghost_zones(1, [fd])
                        CICSample_3(x, y, z, pfield,
                                    self.num_indices,
                                    cube[fd],
                                    np.array(grid.LeftEdge).astype(np.float64),
                                    np.array(grid.ActiveDimensions).astype(np.int32),
                                    grid.dds[0])
                sto.result_id = ds.parameter_filename
                sto.result = (self.array_indices[i], pfield)
                pbar.update(step)
                step += 1
            pbar.finish()
            for i, (fn, (indices, pfield)) in enumerate(sorted(my_storage.items())):
                particles[indices, i] = pfield
            self.field_data[field] = array_like_field(dd_first, particles, fd)
            if self.suppress_logging:
                mylog.setLevel(old_level)
        return self.field_data[field]
Example #7
def requires_file(req_file):
    path = ytcfg.get("yt", "test_data_dir")
    def ffalse(func):
        return lambda: None
    def ftrue(func):
        return func
    if os.path.exists(req_file):
        return ftrue
    else:
        if os.path.exists(os.path.join(path,req_file)):
            return ftrue
        else:
            return ffalse
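
requires_file is a decorator factory: the wrapped test survives only when the
file exists locally or under test_data_dir. A sketch reusing a path defined
later in this listing:

@requires_file("xray_data/tbabs_table.h5")
def test_needs_tbabs():
    ...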
Example #8
def data_dir_load(ds_fn, cls = None, args = None, kwargs = None):
    args = args or ()
    kwargs = kwargs or {}
    path = ytcfg.get("yt", "test_data_dir")
    if isinstance(ds_fn, Dataset): return ds_fn
    if not os.path.isdir(path):
        return False
    with temp_cwd(path):
        if cls is None:
            ds = load(ds_fn, *args, **kwargs)
        else:
            ds = cls(ds_fn, *args, **kwargs)
        ds.index
        return ds
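
A call sketch; the path matches the ETC dataset defined later in this listing,
and a False return signals that test_data_dir does not exist:

ds = data_dir_load("enzo_tiny_cosmology/DD0046/DD0046")
if ds is False:
    ...  # no test data directory; skip the test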
Example #9
def can_run_sim(sim_fn, sim_type, file_check = False):
    if isinstance(sim_fn, SimulationTimeSeries):
        return AnswerTestingTest.result_storage is not None
    path = ytcfg.get("yt", "test_data_dir")
    if not os.path.isdir(path):
        return False
    with temp_cwd(path):
        if file_check:
            return os.path.isfile(sim_fn) and \
                AnswerTestingTest.result_storage is not None
        try:
            simulation(sim_fn, sim_type)
        except YTOutputNotIdentified:
            return False
    return AnswerTestingTest.result_storage is not None
Example #10
def can_run_ds(ds_fn, file_check = False):
    if isinstance(ds_fn, Dataset):
        return AnswerTestingTest.result_storage is not None
    path = ytcfg.get("yt", "test_data_dir")
    if not os.path.isdir(path):
        return False
    with temp_cwd(path):
        if file_check:
            return os.path.isfile(ds_fn) and \
                AnswerTestingTest.result_storage is not None
        try:
            load(ds_fn)
        except YTOutputNotIdentified:
            return False
    return AnswerTestingTest.result_storage is not None
    def __call__(self, args):
        kwargs = {}
        try:
            # IPython 1.0+
            from IPython.html.notebookapp import NotebookApp
        except ImportError:
            # pre-IPython v1.0
            from IPython.frontend.html.notebook.notebookapp import NotebookApp
        print("You must choose a password so that others cannot connect to "
              "your notebook.")
        pw = ytcfg.get("yt", "notebook_password")
        if len(pw) == 0 and not args.no_password:
            import IPython.lib
            pw = IPython.lib.passwd()
            print("If you would like to use this password in the future,")
            print("place a line like this inside the [yt] section in your")
            print("yt configuration file at ~/.yt/config")
            print()
            print("notebook_password = %s" % pw)
            print()
        elif args.no_password:
            pw = None
        if args.port != 0:
            kwargs['port'] = int(args.port)
        if args.profile is not None:
            kwargs['profile'] = args.profile
        if pw is not None:
            kwargs['password'] = pw
        app = NotebookApp(open_browser=args.open_browser,
                          **kwargs)
        app.initialize(argv=[])
        print()
        print("***************************************************************")
        print()
        print("The notebook is now live at:")
        print()
        print("     http://127.0.0.1:%s/" % app.port)
        print()
        print("Recall you can create a new SSH tunnel dynamically by pressing")
        print("~C and then typing -L%s:localhost:%s" % (app.port, app.port))
        print("where the first number is the port on your local machine. ")
        print()
        print("If you are using %s on your machine already, try "
              "-L8889:localhost:%s" % (app.port, app.port))
        print()
        print("***************************************************************")
        print()
        app.start()
Example #12
def enable_plugins():
    import yt
    from yt.fields.my_plugin_fields import my_plugins_fields
    from yt.config import ytcfg
    my_plugin_name = ytcfg.get("yt","pluginfilename")
    # We assume that it is with respect to the $HOME/.yt directory
    if os.path.isfile(my_plugin_name):
        _fn = my_plugin_name
    else:
        _fn = os.path.expanduser("~/.yt/%s" % my_plugin_name)
    if os.path.isfile(_fn):
        mylog.info("Loading plugins from %s", _fn)
        execdict = yt.__dict__.copy()
        execdict['add_field'] = my_plugins_fields.add_field
        with open(_fn) as f:
            code = compile(f.read(), _fn, 'exec')
            exec(code, execdict)
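
The plugin file runs inside a copy of the yt namespace with add_field rebound
to the plugin-field registry, so a hypothetical ~/.yt/my_plugins.py could read
(the field name and scaling are invented for illustration):

# Hypothetical plugin file; add_field comes from the exec namespace above.
def _dust_density(field, data):
    return data["density"] * 0.01

add_field("dust_density", function=_dust_density, units="g/cm**3")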
Example #13
    def _get_owls_ion_data_dir(self):

        txt = "Attempting to download ~ 30 Mb of owls ion data from %s to %s."
        data_file = "owls_ion_data.tar.gz"
        data_url = "http://yt-project.org/data"

        # get test_data_dir from yt config (ytcfg)
        #----------------------------------------------
        tdir = ytcfg.get("yt", "test_data_dir")

        # set download destination to tdir, or ./ if tdir isn't defined
        #----------------------------------------------
        if tdir == "/does/not/exist":
            data_dir = "./"
        else:
            data_dir = tdir

        # check for owls_ion_data directory in data_dir
        # if not there download the tarball and untar it
        #----------------------------------------------
        owls_ion_path = os.path.join( data_dir, "owls_ion_data" )

        if not os.path.exists(owls_ion_path):
            mylog.info(txt % (data_url, data_dir))
            fname = os.path.join(data_dir, data_file)
            download_file(os.path.join(data_url, data_file), fname)

            cmnd = "cd " + data_dir + "; " + "tar xf " + data_file
            os.system(cmnd)

        if not os.path.exists(owls_ion_path):
            raise RuntimeError("Failed to download owls ion data.")

        return owls_ion_path
Example #14
def prep_dirs():
    for directory in glob.glob('%s/*' % ytcfg.get("yt", "test_data_dir")):
        os.symlink(directory, os.path.basename(directory))

from yt.testing import *
from yt.config import ytcfg
from yt.analysis_modules.photon_simulator.api import *
from yt.utilities.answer_testing.framework import requires_ds, \
     GenericArrayTest, data_dir_load
import numpy as np

def setup():
    from yt.config import ytcfg
    ytcfg["yt", "__withintesting"] = "True"

test_dir = ytcfg.get("yt", "test_data_dir")

ETC = test_dir+"/enzo_tiny_cosmology/DD0046/DD0046"
APEC = test_dir+"/xray_data/atomdb_v2.0.2"
TBABS = test_dir+"/xray_data/tbabs_table.h5"
ARF = test_dir+"/xray_data/chandra_ACIS-S3_onaxis_arf.fits"
RMF = test_dir+"/xray_data/chandra_ACIS-S3_onaxis_rmf.fits"

@requires_ds(ETC)
@requires_file(APEC)
@requires_file(TBABS)
@requires_file(ARF)
@requires_file(RMF)
def test_etc():

    np.random.seed(seed=0x4d3d3d3)
Example #16
import os
import logging

from nose.plugins import Plugin

from yt.convenience import load, simulation
from yt.config import ytcfg
from yt.data_objects.static_output import Dataset
from yt.data_objects.time_series import SimulationTimeSeries
from yt.utilities.logger import disable_stream_logging
from yt.utilities.command_line import get_yt_version

import matplotlib.image as mpimg
import yt.visualization.plot_window as pw
import yt.extern.progressbar as progressbar

mylog = logging.getLogger('nose.plugins.answer-testing')
run_big_data = False

# Set the latest gold and local standard filenames
_latest = ytcfg.get("yt", "gold_standard_filename")
_latest_local = ytcfg.get("yt", "local_standard_filename")
_url_path = ytcfg.get("yt", "answer_tests_url")

class AnswerTesting(Plugin):
    name = "answer-testing"
    _my_version = None

    def options(self, parser, env=os.environ):
        super(AnswerTesting, self).options(parser, env=env)
        parser.add_option("--answer-name", dest="answer_name", metavar='str',
            default=None, help="The name of the standard to store/compare against")
        parser.add_option("--answer-store", dest="store_results", metavar='bool',
            default=False, action="store_true",
            help="Should we store this result instead of comparing?")
        parser.add_option("--local", dest="local_results",
    def _get_db_name(self):
        base_file_name = ytcfg.get("yt", "ParameterFileStore")
        if not os.access(os.path.expanduser("~/"), os.W_OK):
            return os.path.abspath(base_file_name)
        return os.path.expanduser("~/.yt/%s" % base_file_name)
    def __call__(self, args):
        # We need these pieces of information:
        #   1. Name
        #   2. Email
        #   3. Username
        #   4. Password (and password2)
        #   5. (optional) URL
        #   6. "Secret" key to make it epsilon harder for spammers
        if ytcfg.get("yt", "hub_api_key") != "":
            print("You seem to already have an API key for the hub in")
            print("~/.yt/config .  Delete this if you want to force a")
            print("new user registration.")
        print("Awesome!  Let's start by registering a new user for you.")
        print("Here's the URL, for reference: http://hub.yt-project.org/ ")
        print()
        print("As always, bail out with Ctrl-C at any time.")
        print()
        print("What username would you like to go by?")
        print()
        username = raw_input("Username? ")
        if len(username) == 0: sys.exit(1)
        print()
        print("To start out, what's your name?")
        print()
        name = raw_input("Name? ")
        if len(name) == 0: sys.exit(1)
        print()
        print("And your email address?")
        print()
        email = raw_input("Email? ")
        if len(email) == 0: sys.exit(1)
        print()
        print("Please choose a password:")
        print()
        while 1:
            password1 = getpass.getpass("Password? ")
            password2 = getpass.getpass("Confirm? ")
            if len(password1) == 0: continue
            if password1 == password2: break
            print("Sorry, they didn't match!  Let's try again.")
            print()
        print()
        print("Would you like a URL displayed for your user?")
        print("Leave blank if no.")
        print()
        url = raw_input("URL? ")
        print()
        print("Okay, press enter to register.  You should receive a welcome")
        print("message at %s when this is complete." % email)
        print()
        loki = raw_input()
        data = dict(name=name, email=email, username=username,
                    password=password1, password2=password2,
                    url=url, zap="rowsdower")
        data = urllib.parse.urlencode(data)
        hub_url = "https://hub.yt-project.org/create_user"
        req = urllib.request.Request(hub_url, data)
        try:
            status = urllib.request.urlopen(req).read()
        except urllib.error.HTTPError as exc:
            if exc.code == 400:
                print("Sorry, the Hub couldn't create your user.")
                print("You can't register duplicate users, which is the most")
                print("common cause of this error.  All values for username,")
                print("name, and email must be unique in our system.")
                sys.exit(1)
        except urllib.error.URLError as exc:
            print("Something has gone wrong.  Here's the error message.")
            raise exc
        print()
        print("SUCCESS!")
        print()
Example #19
        #                        '{mol}_surf'.format(mol=mol, level=level,
        #                                            transparency=transparency))
        #surf.export_obj(filename, transparency=transparency,
        #                color_field=ytcubes[mol].dataset.field_list[0],
        #                #color_map=colors[mol],
        #                plot_index=ii)
        filename = os.path.join('IsoSurfs',
                                'all_surfs')
        surf.export_obj(filename, transparency=transparency,
                        color_field='ones', #ytcubes[mol].dataset.field_list[0],
                        color_map=colornames[mol],
                        plot_index=jj+ii*len(surfaces))

import zipfile
zfn = 'IsoSurfs/all_surfs12.zip'
zf = zipfile.ZipFile(zfn, mode='w')
zf.write('IsoSurfs/all_surfs.obj')
zf.write('IsoSurfs/all_surfs.mtl')
zf.close()

from yt.config import ytcfg
api_key = ytcfg.get("yt","sketchfab_api_key")
import requests
import os
data = {'title': 'Sgr B2 meshes colored (try 5)',
        'token': api_key,
        'fileModel': zfn,
        'filenameModel': os.path.basename(zfn)}
response = requests.post('https://api.sketchfab.com/v1/models', files=data)
response.raise_for_status()
Example #20
def load(*args, **kwargs):
    """
    This function attempts to determine the base data type of a filename or
    other set of arguments by calling
    :meth:`yt.data_objects.api.Dataset._is_valid` until it finds a
    match, at which point it returns an instance of the appropriate
    :class:`yt.data_objects.api.Dataset` subclass.
    """
    if len(args) == 0:
        try:
            from yt.extern.six.moves import tkinter
            from yt.extern.six.moves import tkinter_filedialog as tkFileDialog
        except ImportError:
            raise YTOutputNotIdentified(args, kwargs)
        root = tkinter.Tk()
        filename = tkFileDialog.askopenfilename(parent=root, title='Choose a file')
        if filename is not None:
            return load(filename)
        else:
            raise YTOutputNotIdentified(args, kwargs)
    candidates = []
    args = [os.path.expanduser(arg) if isinstance(arg, str)
            else arg for arg in args]
    valid_file = []
    for argno, arg in enumerate(args):
        if isinstance(arg, str):
            if os.path.exists(arg):
                valid_file.append(True)
            elif arg.startswith("http"):
                valid_file.append(True)
            else:
                if os.path.exists(os.path.join(ytcfg.get("yt", "test_data_dir"), arg)):
                    valid_file.append(True)
                    args[argno] = os.path.join(ytcfg.get("yt", "test_data_dir"), arg)
                else:
                    valid_file.append(False)
        else:
            valid_file.append(False)
    if not any(valid_file):
        try:
            from yt.data_objects.time_series import DatasetSeries
            ts = DatasetSeries.from_filenames(*args, **kwargs)
            return ts
        except YTOutputNotIdentified:
            pass
        mylog.error("None of the arguments provided to load() is a valid file")
        mylog.error("Please check that you have used a correct path")
        raise YTOutputNotIdentified(args, kwargs)
    for n, c in output_type_registry.items():
        if n is None: continue
        if c._is_valid(*args, **kwargs): candidates.append(n)

    # convert to classes
    candidates = [output_type_registry[c] for c in candidates]
    # Find only the lowest subclasses, i.e. most specialised front ends
    candidates = find_lowest_subclasses(candidates)
    if len(candidates) == 1:
        return candidates[0](*args, **kwargs)
    if len(candidates) == 0:
        if ytcfg.get("yt", "enzo_db") != '' \
           and len(args) == 1 \
           and isinstance(args[0], str):
            erdb = EnzoRunDatabase()
            fn = erdb.find_uuid(args[0])
            n = "EnzoDataset"
            if n in output_type_registry \
               and output_type_registry[n]._is_valid(fn):
                return output_type_registry[n](fn)
        mylog.error("Couldn't figure out output type for %s", args[0])
        raise YTOutputNotIdentified(args, kwargs)

    mylog.error("Multiple output type candidates for %s:", args[0])
    for c in candidates:
        mylog.error("    Possible: %s", c)
    raise YTOutputNotIdentified(args, kwargs)
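
A typical call, mirroring the fallback logic above (the path reuses the
dataset named earlier in this listing):

ds = load("enzo_tiny_cosmology/DD0046/DD0046")
ds.index  # force index construction, as data_dir_load does above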
Example #21
    def write_image(self,
                    filename,
                    color_bounds=None,
                    channel=None,
                    cmap_name=None,
                    func=lambda x: x):
        r"""Writes a single channel of the ImageArray to a png file.

        Parameters
        ----------
        filename : string
            Name of the output file.  A '.png' extension is appended if
            the filename does not already have one.

        Other Parameters
        ----------------
        channel : int, optional
            Which channel to write out as an image.  Defaults to 0.
        color_bounds : tuple of floats, optional
            The min and max to scale between.  Outlying values will be clipped.
        cmap_name : string, optional
            An acceptable colormap.  See either yt.visualization.color_maps or
            http://www.scipy.org/Cookbook/Matplotlib/Show_colormaps .
        func : function, optional
            A function to transform the buffer before applying a colormap.

        Returns
        -------
        scaled_image : uint8 image that has been saved

        Examples
        --------

        >>> im = np.zeros([64,128])
        >>> for i in range(im.shape[0]):
        ...     im[i,:] = np.linspace(0.,0.3*i, im.shape[1])

        >>> myinfo = {'field':'dinosaurs', 'east_vector':np.array([1.,0.,0.]),
        ...     'north_vector':np.array([0.,0.,1.]), 'normal_vector':np.array([0.,1.,0.]),
        ...     'width':0.245, 'units':'cm', 'type':'rendering'}

        >>> im_arr = ImageArray(im, info=myinfo)
        >>> im_arr.write_image('test_ImageArray.png')

        """
        if cmap_name is None:
            cmap_name = ytcfg.get("yt", "default_colormap")
        if filename is not None and filename[-4:] != '.png':
            filename += '.png'

        #TODO: Write info dict as png metadata
        if channel is None:
            return write_image(self.swapaxes(0, 1).to_ndarray(),
                               filename,
                               color_bounds=color_bounds,
                               cmap_name=cmap_name,
                               func=func)
        else:
            return write_image(self.swapaxes(0, 1)[:, :, channel].to_ndarray(),
                               filename,
                               color_bounds=color_bounds,
                               cmap_name=cmap_name,
                               func=func)
    def __init__(self, filename,
                 dataset_type='fits',
                 auxiliary_files=[],
                 nprocs=None,
                 storage_filename=None,
                 nan_mask=None,
                 spectral_factor=1.0,
                 z_axis_decomp=False,
                 suppress_astropy_warnings=True,
                 parameters=None,
                 units_override=None):

        if parameters is None:
            parameters = {}
        parameters["nprocs"] = nprocs
        self.specified_parameters = parameters

        self.z_axis_decomp = z_axis_decomp
        self.spectral_factor = spectral_factor

        if suppress_astropy_warnings:
            warnings.filterwarnings('ignore', module="astropy", append=True)
        auxiliary_files = ensure_list(auxiliary_files)
        self.filenames = [filename] + auxiliary_files
        self.num_files = len(self.filenames)
        self.fluid_types += ("fits",)
        if nan_mask is None:
            self.nan_mask = {}
        elif isinstance(nan_mask, float):
            self.nan_mask = {"all":nan_mask}
        elif isinstance(nan_mask, dict):
            self.nan_mask = nan_mask
        self._handle = FITSFileHandler(self.filenames[0])
        if (isinstance(self.filenames[0], _astropy.pyfits.hdu.image._ImageBaseHDU) or
            isinstance(self.filenames[0], _astropy.pyfits.HDUList)):
            fn = "InMemoryFITSFile_%s" % uuid.uuid4().hex
        else:
            fn = self.filenames[0]
        self._handle._fits_files.append(self._handle)
        if self.num_files > 1:
            for fits_file in auxiliary_files:
                if isinstance(fits_file, _astropy.pyfits.hdu.image._ImageBaseHDU):
                    f = _astropy.pyfits.HDUList([fits_file])
                elif isinstance(fits_file, _astropy.pyfits.HDUList):
                    f = fits_file
                else:
                    if os.path.exists(fits_file):
                        fn = fits_file
                    else:
                        fn = os.path.join(ytcfg.get("yt","test_data_dir"),fits_file)
                    f = _astropy.pyfits.open(fn, memmap=True,
                                             do_not_scale_image_data=True,
                                             ignore_blank=True)
                self._handle._fits_files.append(f)

        if len(self._handle) > 1 and self._handle[1].name == "EVENTS":
            self.events_data = True
            self.first_image = 1
            self.primary_header = self._handle[self.first_image].header
            self.naxis = 2
            self.wcs = _astropy.pywcs.WCS(naxis=2)
            self.events_info = {}
            for k,v in self.primary_header.items():
                if k.startswith("TTYP"):
                    if v.lower() in ["x","y"]:
                        num = k.strip("TTYPE")
                        self.events_info[v.lower()] = (self.primary_header["TLMIN"+num],
                                                       self.primary_header["TLMAX"+num],
                                                       self.primary_header["TCTYP"+num],
                                                       self.primary_header["TCRVL"+num],
                                                       self.primary_header["TCDLT"+num],
                                                       self.primary_header["TCRPX"+num])
                    elif v.lower() in ["energy","time"]:
                        num = k.strip("TTYPE")
                        unit = self.primary_header["TUNIT"+num].lower()
                        if unit.endswith("ev"): unit = unit.replace("ev","eV")
                        self.events_info[v.lower()] = unit
            self.axis_names = [self.events_info[ax][2] for ax in ["x","y"]]
            self.reblock = 1
            if "reblock" in self.specified_parameters:
                self.reblock = self.specified_parameters["reblock"]
            self.wcs.wcs.cdelt = [self.events_info["x"][4]*self.reblock,
                                  self.events_info["y"][4]*self.reblock]
            self.wcs.wcs.crpix = [(self.events_info["x"][5]-0.5)/self.reblock+0.5,
                                  (self.events_info["y"][5]-0.5)/self.reblock+0.5]
            self.wcs.wcs.ctype = [self.events_info["x"][2],self.events_info["y"][2]]
            self.wcs.wcs.cunit = ["deg","deg"]
            self.wcs.wcs.crval = [self.events_info["x"][3],self.events_info["y"][3]]
            self.dims = [(self.events_info["x"][1]-self.events_info["x"][0])/self.reblock,
                         (self.events_info["y"][1]-self.events_info["y"][0])/self.reblock]
        else:
            self.events_data = False
            # Sometimes the primary hdu doesn't have an image
            if len(self._handle) > 1 and self._handle[0].header["naxis"] == 0:
                self.first_image = 1
            else:
                self.first_image = 0
            self.primary_header = self._handle[self.first_image].header
            self.naxis = self.primary_header["naxis"]
            self.axis_names = [self.primary_header.get("ctype%d" % (i+1),"LINEAR")
                               for i in range(self.naxis)]
            self.dims = [self.primary_header["naxis%d" % (i+1)]
                         for i in range(self.naxis)]
            wcs = _astropy.pywcs.WCS(header=self.primary_header)
            if self.naxis == 4:
                self.wcs = _astropy.pywcs.WCS(naxis=3)
                self.wcs.wcs.crpix = wcs.wcs.crpix[:3]
                self.wcs.wcs.cdelt = wcs.wcs.cdelt[:3]
                self.wcs.wcs.crval = wcs.wcs.crval[:3]
                self.wcs.wcs.cunit = [str(unit) for unit in wcs.wcs.cunit][:3]
                self.wcs.wcs.ctype = [type for type in wcs.wcs.ctype][:3]
            else:
                self.wcs = wcs

        self.refine_by = 2

        Dataset.__init__(self, fn, dataset_type, units_override=units_override)
        self.storage_filename = storage_filename
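
A construction sketch, assuming the enclosing class is yt's FITS frontend
Dataset subclass (FITSDataset); the filename is illustrative:

ds = FITSDataset("m33_hi.fits", nan_mask=0.0, spectral_factor=1.0)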
Example #23
    def __init__(self,
                 outputs,
                 indices,
                 fields=None,
                 suppress_logging=False,
                 ptype=None):

        indices.sort()  # Just in case the caller wasn't careful
        self.field_data = YTFieldData()
        self.data_series = outputs
        self.masks = []
        self.sorts = []
        self.array_indices = []
        self.indices = indices
        self.num_indices = len(indices)
        self.num_steps = len(outputs)
        self.times = []
        self.suppress_logging = suppress_logging
        self.ptype = ptype if ptype else "all"

        if fields is None:
            fields = []

        if self.suppress_logging:
            old_level = int(ytcfg.get("yt", "log_level"))
            mylog.setLevel(40)
        ds_first = self.data_series[0]
        dd_first = ds_first.all_data()

        fds = {}
        for field in (
                "particle_index",
                "particle_position_x",
                "particle_position_y",
                "particle_position_z",
        ):
            fds[field] = dd_first._determine_fields((self.ptype, field))[0]

        # Note: we explicitly pass dynamic=False to prevent any change in piter from
        # breaking the assumption that the same processors load the same datasets
        my_storage = {}
        pbar = get_pbar("Constructing trajectory information",
                        len(self.data_series))
        for i, (sto, ds) in enumerate(
                self.data_series.piter(storage=my_storage, dynamic=False)):
            dd = ds.all_data()
            newtags = dd[fds["particle_index"]].d.astype("int64")
            mask = np.in1d(newtags, indices, assume_unique=True)
            sort = np.argsort(newtags[mask])
            array_indices = np.where(
                np.in1d(indices, newtags, assume_unique=True))[0]
            self.array_indices.append(array_indices)
            self.masks.append(mask)
            self.sorts.append(sort)

            pfields = {}
            for field in (f"particle_position_{ax}" for ax in "xyz"):
                pfields[field] = dd[fds[field]].ndarray_view()[mask][sort]

            sto.result_id = ds.parameter_filename
            sto.result = (ds.current_time, array_indices, pfields)
            pbar.update(i + 1)
        pbar.finish()

        if self.suppress_logging:
            mylog.setLevel(old_level)

        sorted_storage = sorted(my_storage.items())
        _fn, (time, *_) = sorted_storage[0]
        time_units = time.units
        times = [time.to(time_units) for _fn, (time, *_) in sorted_storage]
        self.times = self.data_series[0].arr([time.value for time in times],
                                             time_units)

        self.particle_fields = []
        output_field = np.empty((self.num_indices, self.num_steps))
        output_field.fill(np.nan)
        for field in (f"particle_position_{ax}" for ax in "xyz"):
            for i, (_fn, (_time, indices,
                          pfields)) in enumerate(sorted_storage):
                try:
                    # This will fail if particles ids are
                    # duplicate. This is due to the fact that the rhs
                    # would then have a different shape as the lhs
                    output_field[indices, i] = pfields[field]
                except ValueError as e:
                    raise YTIllDefinedParticleData(
                        "This dataset contains duplicate particle indices!"
                    ) from e
            self.field_data[field] = array_like_field(dd_first,
                                                      output_field.copy(),
                                                      fds[field])
            self.particle_fields.append(field)

        # Instantiate fields the caller requested
        self._get_data(fields)