コード例 #1
0
ファイル: xdmf.py プロジェクト: KristoforMaynard/Viscid
    def _parse_time(self, timetag):
        """ returns the time(s) as float, or numpy array, time attributes"""
        attrs = self._fill_attrs(timetag)
        timetype = attrs["Type"]

        if timetype == 'Single':
            return float(timetag.get('Value')), attrs
        elif timetype == 'List':
            return self._parse_dataitem(timetag.find('.//DataItem'))[0], attrs
        elif timetype == 'Range':
            raise NotImplementedError("Time Range not yet implemented")
            # dat, dattrs = self._parse_dataitem(timetag.find('.//DataItem'))
            # TODO: this is not the most general, but I think it'll work
            # as a first stab, plus I will probably not need Range ever
            # tgridtag = timetag.find("ancestor::Grid[@GridType='Collection']"
            #                         "[@CollectionType='Temporal'][1]"))
            # n = len(tgridtag.find(.//Grid[@GridType='Collection']
            #         [CollectionType=['Spatial']]))
            # return np.linspace(dat[0], dat[1], n)
            # return np.arange(dat[0], dat[1])
        elif timetype == 'HyperSlab':
            dat, dattrs = self._parse_dataitem(timetag.find('.//DataItem'))
            arr = np.array([dat[0] + i * dat[1] for i in range(int(dat[2]))])
            return arr, attrs
        else:
            logger.warning("invalid TimeType.\n")
コード例 #2
0
ファイル: openggcm.py プロジェクト: KristoforMaynard/Viscid
    def read_logfile(self):
        """Find this run's log file and merge its parameters into info.

        Does nothing but record ``_viscid_log_fname = False`` when
        ``self.read_log_file`` is off. Otherwise the first log file found
        among the candidates below is recorded as ``_viscid_log_fname``
        and every entry of its ``info`` dict is fed to ``update_info``.
        If no log file turns up, ``_viscid_log_fname`` is set to None and
        a warning is logged once per process (tracked on ``GGCMFile``).
        """
        if not self.read_log_file:
            self.set_info("_viscid_log_fname", False)
            return

        log_basename = "{0}.log".format(self.find_info('run'))
        # (search directory, file name) pairs, tried in order; the first
        # hit wins. FYI, default max_depth should be 8
        candidates = [(self.dirname, log_basename),
                      (".", log_basename),
                      (self.dirname, "log.txt"),
                      (self.dirname, "log.log"),
                      (self.dirname, "log")]
        found = None
        for search_dir, candidate_name in candidates:
            found = find_file_uptree(search_dir, candidate_name)
            if found is not None:
                break

        if found is None:
            self.set_info("_viscid_log_fname", None)
            if not GGCMFile._already_warned_about_logfile:
                logger.warning("You wanted to read parameters from the "
                               "logfile, but I couldn't find one. Maybe "
                               "you need to copy it from somewhere?")
                GGCMFile._already_warned_about_logfile = True
            return

        self.set_info("_viscid_log_fname", found)
        with GGCMLogFile(found) as log_file:
            for info_key, info_val in log_file.info.items():
                self.update_info(info_key, info_val)
コード例 #3
0
def run_div_test(fld, exact, title='', show=False, ignore_inexact=False):
    """Check numexpr's divergence of ``fld`` against ``exact`` and plot it.

    Args:
        fld: field whose divergence is computed with
            ``viscid.div(..., preferred="numexpr")``
        exact: field holding the expected divergence
        title (str): figure suptitle
        show (bool): if True, display the figure after saving it
        ignore_inexact (bool): if True, don't warn when the numexpr
            result differs from ``exact`` by 5e-5 or more anywhere

    The figure has one column per plane ("y=0j", "z=0j"); the top row is
    the numexpr result and the bottom row is ``numexpr - exact``. It is
    saved to ``next_plot_fname(__file__)``.
    """
    t0 = time()
    result_numexpr = viscid.div(fld, preferred="numexpr", only=False)
    t1 = time()
    logger.info("numexpr div runtime: %g", t1 - t0)

    # trim one cell off every edge before comparing
    result_diff = viscid.diff(result_numexpr, exact)['x=1:-1, y=1:-1, z=1:-1']
    # compare the magnitude of the error so that large negative errors
    # also trigger the warning (and match what the log line reports)
    abs_err = np.abs(result_diff.data)
    if not ignore_inexact and not (abs_err < 5e-5).all():
        logger.warning("numexpr result is far from the exact result")
    logger.info("min/max(abs(numexpr - exact)): %g / %g",
                np.min(abs_err), np.max(abs_err))

    planes = ["y=0j", "z=0j"]
    nrows = 2
    ncols = len(planes)
    _, axes = plt.subplots(nrows, ncols, squeeze=False)

    for i, p in enumerate(planes):
        vlt.plot(result_numexpr, p, ax=axes[0, i], show=False)
        vlt.plot(result_diff, p, ax=axes[1, i], show=False)

    plt.suptitle(title)
    vlt.auto_adjust_subplots(subplot_params=dict(top=0.9))

    plt.savefig(next_plot_fname(__file__))
    if show:
        vlt.mplshow()
コード例 #4
0
ファイル: vseaborn.py プロジェクト: KristoforMaynard/Viscid
def activate_from_viscid():
    """You should not need to call this

    This function is called by viscid.plot.vpyplot at time of import.
    All you need to do is set the global config options above before
    importing `viscid.plot.vpyplot`. This can be done with the rc file.

    When the module-level ``enabled`` flag is set, this applies the
    module-level ``context``, ``style``, ``palette``, ``font``,
    ``font_scale`` and ``rc`` settings with ``seaborn.set``, then copies
    seaborn's public names into this module's namespace. If seaborn
    can't be imported, a warning is logged instead.
    """
    if enabled:
        import re
        import matplotlib

        # distutils (and its LooseVersion) was removed in Python 3.12
        # (PEP 632), so compare the leading (major, minor) integers of the
        # version string instead
        mpl_version = tuple(int(part) for part in
                            re.findall(r"\d+", matplotlib.__version__)[:2])
        if mpl_version >= (1, 5):
            logger.warning("Using this shim to seaborn for pretty plots "
                           "is deprecated since you have matplotlib >= 1.5.\n"
                           "Instead, use matplotlib's style sheets through "
                           "`viscid.mpl_style`.")

        try:
            import seaborn as real_seaborn
            # just some fancyness so i can specify an arbitrary
            # color palette function with arbitrary arguments
            # from the rc file
            _palette_func = real_seaborn.color_palette
            if isinstance(palette, (list, tuple)):
                _palette = palette
                # ##### NOPE, letting the user run an attribute of
                # ##### seaborn might be a security problem...
                # try:
                #     func_name = palette[0].strip().strip("_")
                #     func = getattr(real_seaborn, func_name)
                #     if hasattr(func, "__call__"):
                #         _palette_func = func
                #         _palette = palette[1:]
                # except (AttributeError, TypeError):
                #     pass
                # #####
            else:
                _palette = [palette]
            _palette = _palette_func(*_palette)

            # Ok, now set the defaults
            real_seaborn.set(context=context,
                             style=style,
                             palette=_palette,
                             font=font,
                             font_scale=font_scale,
                             rc=rc)

            # now pull in the public namespace
            g = globals()
            for key, value in real_seaborn.__dict__.items():
                if not key.startswith("_"):
                    g[key] = value
            g['SEABORN_ACTIVATED'] = True
            viscid.plot.mpl_extra.post_rc_actions(show_warning=False)
            viscid.mpl_style.post_rc_actions(show_warning=False)
        except ImportError:
            logger.warning("seaborn package not installed")
コード例 #5
0
ファイル: vseaborn.py プロジェクト: KristoforMaynard/Viscid
def activate_from_viscid():
    """You should not need to call this

    This function is called by viscid.plot.vpyplot at time of import.
    All you need to do is set the global config options above before
    importing `viscid.plot.vpyplot`. This can be done with the rc file.

    When the module-level ``enabled`` flag is set, this applies the
    module-level ``context``, ``style``, ``palette``, ``font``,
    ``font_scale`` and ``rc`` settings with ``seaborn.set``, then copies
    seaborn's public names into this module's namespace. If seaborn
    can't be imported, a warning is logged instead.
    """
    if enabled:
        import re
        import matplotlib

        # distutils (and its LooseVersion) was removed in Python 3.12
        # (PEP 632), so compare the leading (major, minor) integers of the
        # version string instead
        mpl_version = tuple(int(part) for part in
                            re.findall(r"\d+", matplotlib.__version__)[:2])
        if mpl_version >= (1, 5):
            logger.warning("Using this shim to seaborn for pretty plots "
                           "is deprecated since you have matplotlib >= 1.5.\n"
                           "Instead, use matplotlib's style sheets through "
                           "`viscid.mpl_style`.")

        try:
            import seaborn as real_seaborn
            # just some fancyness so i can specify an arbitrary
            # color palette function with arbitrary arguments
            # from the rc file
            _palette_func = real_seaborn.color_palette
            if isinstance(palette, (list, tuple)):
                _palette = palette
                # ##### NOPE, letting the user run an attribute of
                # ##### seaborn might be a security problem...
                # try:
                #     func_name = palette[0].strip().strip("_")
                #     func = getattr(real_seaborn, func_name)
                #     if hasattr(func, "__call__"):
                #         _palette_func = func
                #         _palette = palette[1:]
                # except (AttributeError, TypeError):
                #     pass
                # #####
            else:
                _palette = [palette]
            _palette = _palette_func(*_palette)

            # Ok, now set the defaults
            real_seaborn.set(context=context, style=style,
                             palette=_palette, font=font,
                             font_scale=font_scale, rc=rc)

            # now pull in the public namespace
            g = globals()
            for key, value in real_seaborn.__dict__.items():
                if not key.startswith("_"):
                    g[key] = value
            g['SEABORN_ACTIVATED'] = True
            viscid.plot.mpl_extra.post_rc_actions(show_warning=False)
            viscid.mpl_style.post_rc_actions(show_warning=False)
        except ImportError:
            logger.warning("seaborn package not installed")
コード例 #6
0
    def xinclude(tree, base_url=None, **kwargs):
        """Process XInclude directives by calling ``tree.xinclude``.

        Args:
            tree (Tree): The object returned by parse
            base_url (str): Not used; lxml ignores it, so a warning is
                logged when a non-empty value is given
            **kwargs: passed to tree.xinclude()

        Returns:
            Whatever ``tree.xinclude(**kwargs)`` returns.
        """
        # TODO: ignore if an xincluded xdmf file doesn't exist?
        if not base_url:
            return tree.xinclude(**kwargs)
        logger.warning("lxml will ignore base_url: %s", base_url)
        return tree.xinclude(**kwargs)
コード例 #7
0
ファイル: sliceutil.py プロジェクト: KristoforMaynard/Viscid
def _warn_deprecated_float(val, varname='value'):
    """Log a one-time deprecation warning about slicing by float.

    The message names the file, line number, and source text of the
    user frame returned by ``_user_written_stack_frame``. After the first
    call, the module-level ``_emit_deprecated_float_warning`` flag is
    cleared, so later calls do nothing.

    Args:
        val: the float that was sliced by (currently unused)
        varname (str): name of the value (currently unused)
    """
    global _emit_deprecated_float_warning  # pylint: disable=global-statement
    if _emit_deprecated_float_warning:
        frame = _user_written_stack_frame()
        # NOTE(review): frame is indexed like an inspect.FrameInfo record
        # (filename at [1], lineno at [2], code_context at [4]); the code
        # context can be None / empty when the source isn't available
        # (e.g. interactive sessions), so don't index into it blindly
        code_context = frame[4]
        if code_context:
            src_line = code_context[0].strip()
        else:
            src_line = "<source unavailable>"
        s = ("DEPRECATION...\n"
             "Slicing by float is deprecated. Slicing by location is now \n"
             "performed with an imaginary number, or a string with a trailing \n"
             "'f', as in 0j, 'x=0j', or 'x=0f'. This warning comes from:\n"
             "    {0}:{1}\n"
             "    >>> {2}"
             "".format(frame[1], frame[2], src_line))
        logger.warning(s)
        _emit_deprecated_float_warning = False
コード例 #8
0
 def load_file(self, fname, index_handle=True, **kwargs):
     """Load a single file and return one vFile instance (or None).

     Unlike ``load_files``, which returns a list of vFiles, this returns
     only the first one. If more than one file was loaded, a warning is
     logged suggesting ``load_files()``; if none were, None is returned.
     """
     loaded = self.load_files(fname, index_handle=index_handle, **kwargs)
     n_loaded = len(loaded)
     if n_loaded == 0:
         return None
     if n_loaded > 1:
         logger.warning(
             "Loaded > 1 file for %s, did you mean to call "
             "load_files()?", fname)
     return loaded[0]
コード例 #9
0
def _warn_deprecated_float(val, varname='value'):
    """Log a one-time deprecation warning about slicing by float.

    The message names the file, line number, and source text of the
    user frame returned by ``_user_written_stack_frame``. After the first
    call, the module-level ``_emit_deprecated_float_warning`` flag is
    cleared, so later calls do nothing.

    Args:
        val: the float that was sliced by (currently unused)
        varname (str): name of the value (currently unused)
    """
    global _emit_deprecated_float_warning  # pylint: disable=global-statement
    if _emit_deprecated_float_warning:
        frame = _user_written_stack_frame()
        # NOTE(review): frame is indexed like an inspect.FrameInfo record
        # (filename at [1], lineno at [2], code_context at [4]); the code
        # context can be None / empty when the source isn't available
        # (e.g. interactive sessions), so don't index into it blindly
        code_context = frame[4]
        if code_context:
            src_line = code_context[0].strip()
        else:
            src_line = "<source unavailable>"
        s = (
            "DEPRECATION...\n"
            "Slicing by float is deprecated. Slicing by location is now \n"
            "performed with an imaginary number, or a string with a trailing \n"
            "'f', as in 0j, 'x=0j', or 'x=0f'. This warning comes from:\n"
            "    {0}:{1}\n"
            "    >>> {2}"
            "".format(frame[1], frame[2], src_line))
        logger.warning(s)
        _emit_deprecated_float_warning = False
コード例 #10
0
ファイル: evaluator.py プロジェクト: KristoforMaynard/Viscid
def _evaluate_numexpr(grid, result_name, eqn, slc=Ellipsis):
    """Evaluate ``eqn`` with numexpr, looking variables up in ``grid``.

    Every identifier in ``eqn`` that isn't immediately followed by ``(``
    (i.e. isn't a function call) is treated as a field name. It is
    prefixed with "SALT", fetched once with
    ``grid.get_field(name, slc=slc)``, and handed to numexpr through
    ``local_dict``. Double underscores are stripped from ``eqn`` first,
    and numexpr is given empty builtins.

    Args:
        grid: grid the fields are looked up in
        result_name (str): name / pretty_name of the resulting field
        eqn (str): the expression to evaluate
        slc: selection passed to ``grid.get_field``

    Returns:
        Field

    Raises:
        RuntimeError, if no numexpr, or if evaluate is not enabled, or if
            the first looked-up symbol isn't a Field
        TypeError, if numexpr couldn't understand a math input
    """
    if not _has_numexpr:
        raise RuntimeError("Evaluate not enabled, or numexpr not installed.")

    salt_prefix = "SALT"
    symbol_pattern = r"\b([_A-Za-z][_a-zA-Z0-9]*)\b"
    # a symbol followed by "(" is a function call, not a variable
    variable_pattern = symbol_pattern + r"(?!\s*\()"
    # for security: no dunder access inside the expression
    eqn = eqn.replace("__", "")
    salted_fields = {}
    # holds the first looked-up Field; its wrap() shapes the result
    template_flds = []

    def _salt_match(match):
        raw_name = match.group(1)
        salted_name = salt_prefix + raw_name
        # look each distinct symbol up only once (not via dict.update on
        # purpose since grid's getitem might copy the array)
        if salted_name in salted_fields:
            return salted_name
        looked_up = grid.get_field(raw_name, slc=slc)
        salted_fields[salted_name] = looked_up
        if not template_flds:
            if not isinstance(looked_up, field.Field):
                raise RuntimeError("reduced to scalar, no need for numexpr")
            template_flds.append(looked_up)
        return salted_name

    salted_eqn = re.sub(variable_pattern, _salt_match, eqn)

    arr = ne.evaluate(salted_eqn, local_dict=salted_fields,
                      global_dict={"__builtins__": {}})

    # FIXME: actually detect the type of field instead of asserting it's
    # a scalar... also maybe auto-detect reduction operations?
    if not template_flds:
        logger.warning("Strange input to numexpr evaluator: %s", eqn)
        return field.wrap_field(arr, grid.crds, name=result_name)
    ctx = dict(name=result_name, pretty_name=result_name)
    return template_flds[0].wrap(arr, context=ctx)
コード例 #11
0
ファイル: evaluator.py プロジェクト: KristoforMaynard/Viscid
def evaluate(grid, result_name, eqn, try_numexpr=True, slc=Ellipsis):
    """Evaluate an equation on a grid

    Examples:
        To use this function directly

            >>> evaluator.enabled = True
            >>> f = viscid.load_file("...")
            >>> evaluator.evaluate(f.get_grid(), "speed",
                                   "sqrt(vx**2+vy**2+vz**2)")
            <viscid.field.ScalarField object at ...>

        Or, for short, you can ask a grid to evaluate implicitly,

            >> evaluator.enabled = True
            >> f = viscid.load_file("...")
            >> speed = f["speed=sqrt(vx**2+vy**2+vz**2)"]
            <viscid.field.ScalarField object at ...>

    Parameters:
        grid: a grid instance where the fields live
        result_name (str): Used for the name and pretty_name of the
            resulting field
        eqn (str): the equation, if a symbol exists in the numpy
            namespace, then that's how it is interpreted, otherwise,
            the symbol will be looked up in the grid
        try_numexpr (bool): try the numexpr evaluator first; if it
            raises RuntimeError or TypeError, fall back to numpy
        slc: selection passed through to the evaluator as ``slc``

    Returns:
        Field instance

    Raises:
        RuntimeError: if the evaluator is not enabled
    """
    if not enabled:
        raise RuntimeError("You must enable the evaluator with "
                           "`viscid.calculator.evaluator.enabled = True`, "
                           "or in your viscidrc.")

    if try_numexpr:
        try:
            return _evaluate_numexpr(grid, result_name, eqn, slc=slc)
        except RuntimeError:
            # numexpr isn't available, or the expression doesn't need it;
            # quietly fall back to numpy
            pass
        except TypeError:
            logger.warning("Numexpr couldn't understand a math function you "
                           "tried to use in '%s', falling back to numpy", eqn)
    return _evaluate_numpy(grid, result_name, eqn, slc=slc)
コード例 #12
0
ファイル: mpl_style.py プロジェクト: KristoforMaynard/Viscid
    def post_rc_actions(show_warning=True):
        """Apply the configured style sheets and rc overrides to matplotlib.

        "viscid-default" is put at the front of the module-level
        ``use_styles`` list (modifying it) if it isn't there already, and
        each sheet is applied in order; sheets that raise ValueError are
        logged and skipped. If ``matplotlib.style`` can't be imported, a
        warning is logged when ``show_warning`` is set and styles were
        requested. Afterwards, ``rc_params`` and every group in ``rc``
        are applied.
        """
        try:
            from matplotlib import style as mpl_style_mod

            default_sheet = u"viscid-default"
            if default_sheet not in use_styles:
                use_styles.insert(0, default_sheet)

            for sheet in use_styles:
                try:
                    mpl_style_mod.use(sheet)
                except ValueError as err:
                    logger.warning(str(err))
        except ImportError:
            if show_warning and use_styles:
                logger.warning(
                    "Upgrade to matplotlib >= 1.5.0 to use style sheets")

        matplotlib.rcParams.update(rc_params)
        for group_name, group_params in rc.items():
            matplotlib.rc(group_name, **group_params)
コード例 #13
0
ファイル: xdmf.py プロジェクト: KristoforMaynard/Viscid
    def _parse_dataitem(self, item, keep_flat=False):
        """ returns the data as a numpy array, or HDF data item """
        attrs = self._fill_attrs(item)

        dimensions = attrs["Dimensions"]
        if dimensions:
            dimensions = [int(d) for d in dimensions.split(' ')]

        numbertype = attrs["NumberType"]
        precision = attrs["Precision"]
        nptype = np.dtype({'Float': 'f', 'Int': 'i', 'UInt': 'u',
                           'Char': 'i', 'UChar': 'u'}[numbertype] + str(precision))

        fmt = attrs["Format"]

        if fmt == "XML":
            arr = np.fromstring(item.text, sep=' ', dtype=nptype)
            if dimensions and not keep_flat:
                arr = arr.reshape(dimensions)
            return arr, attrs

        if fmt == "HDF":
            fname, loc = item.text.strip().split(':')

            # FIXME: startswith '/' is unix path name specific
            if self.h5_root_dir is not None:
                fname = os.path.join(self.h5_root_dir, fname)
            elif not fname.startswith('/'):
                fname = os.path.join(self.dirname, fname)
            h5file = self._load_child_file(fname, index_handle=False,
                                           file_type=FileLazyHDF5)
            arr = h5file.get_data(loc)
            return arr, attrs

        if fmt == "Binary":
            raise NotImplementedError("binary xdmf data not implemented")

        logger.warning("Invalid DataItem Format.")
        return (None, None)
コード例 #14
0
ファイル: sliceutil.py プロジェクト: KristoforMaynard/Viscid
def std_sel2index(std_sel, crd_arr, val_endpoint=True, interior=False,
                  tdunit='s', epoch=None):
    """Turn single standardized selection into slice (by int or None)

    Normally (val_endpoint=True, interior=False), the rules for float
    lookup val_endpoints are::

        - The slice will never include an element whose value in arr
          is < start (or > if the slice is backward)
        - The slice will never include an element whose value in arr
          is > stop (or < if the slice is backward)
        - !! The slice WILL INCLUDE stop if you don't change
          val_endpoint. This is different from normal slicing, but
          it's more natural when specifying a slice as a float.

    If interior=True, then the slice is expanded such that start and
    stop are interior to the sliced array.

    Args:
        std_sel: single standardized selection
        crd_arr (ndarray): filled with floats to do the lookup
        val_endpoint (bool): iff True then include stop in the slice when
            slicing-by-value (DOES NOT AFFECT SLICE-BY-INDEX).
            Set to False to get python slicing semantics when it
            comes to excluding stop, but fair warning, python
            semantics feel awkward here. Consider the case
            [0.1, 0.2, 0.3][:0.25]. If you think this should include
            0.2, then leave val_endpoint=True.
        interior (bool): if True, then extend both ends of the slice
            such that slice-by-location endpoints are interior to the
            slice
        tdunit (str): Presumed time unit for floats; passed through to
            ``_unify_sbv_types``
        epoch (datetime64-like): Epoch used to convert between
            datetime64 and float

    Returns:
        A slice for slice selections; otherwise an integer index or an
        ndarray of indices.
    """
    idx = None

    if interior and not val_endpoint:
        logger.warning("For interior slices, val_endpoint must be True, I'll "
                       "change that for you.")
        val_endpoint = True

    if isinstance(std_sel, slice):
        assert isinstance(std_sel.step, (int, np.integer, type(None)))
        start_val = None
        stop_val = None

        orig_step = std_sel.step
        ustep = 1 if std_sel.step is None else int(std_sel.step)
        sgn = np.sign(ustep)

        if (isinstance(std_sel.start, (int, np.integer, type(None)))
            and not isinstance(std_sel.start, (np.datetime64, np.timedelta64))):
            ustart = std_sel.start
        else:
            # slice-by-value: mask out values before start (in the step
            # direction, with tol slack), then take the closest remaining
            ustart, tol = _unify_sbv_types(std_sel.start, crd_arr,
                                           tdunit=tdunit, epoch=epoch)
            start_val = ustart
            diff = crd_arr - ustart + (tol * sgn)
            zero = np.array([0]).astype(diff.dtype)[0]

            if ustep > 0:
                diff = np.ma.masked_less(diff, zero)
            else:
                diff = np.ma.masked_greater(diff, zero)

            if np.ma.count(diff) == 0:
                # start value is past the wrong end of the array
                if ustep > 0:
                    ustart = len(crd_arr)
                else:
                    # start = -len(arr) - 1
                    # having a value < -len(arr) won't play
                    # nice with make_fwd_slice, but in this
                    # case, the slice will have no data, so...
                    return slice(0, 0, ustep)
            else:
                ustart = np.argmin(np.abs(diff))

        if (isinstance(std_sel.stop, (int, np.integer, type(None)))
            and not isinstance(std_sel.stop, (np.datetime64, np.timedelta64))):
            ustop = std_sel.stop
        else:
            # slice-by-value: mask out values past stop (in the step
            # direction, with tol slack), then take the closest remaining
            ustop, tol = _unify_sbv_types(std_sel.stop, crd_arr,
                                          tdunit=tdunit, epoch=epoch)
            stop_val = ustop
            diff = crd_arr - ustop - (tol * sgn)
            zero = np.array([0]).astype(diff.dtype)[0]

            if ustep > 0:
                diff = np.ma.masked_greater(diff, zero)
            else:
                diff = np.ma.masked_less(diff, zero)

            if ustep > 0:
                if ustop < crd_arr[0]:
                    # stop value is past the wrong end of the array
                    ustop = 0
                else:
                    ustop = int(np.argmin(np.abs(diff)))
                    if val_endpoint:
                        ustop += 1
            else:
                if ustop > crd_arr[-1]:
                    # stop value is past the wrong end of the array
                    ustop = len(crd_arr)
                else:
                    ustop = int(np.argmin(np.abs(diff)))
                    if val_endpoint:
                        if ustop > 0:
                            ustop -= 1
                        else:
                            # 0 - 1 == -1 which would wrap to the end of
                            # of the array... instead, just make it None
                            ustop = None
        idx = slice(ustart, ustop, orig_step)

        if interior:
            _a, _b, _c = _interiorize_slice(crd_arr, start_val, stop_val,
                                            idx.start, idx.stop, idx.step)
            idx = slice(_a, _b, _c)

    else:
        # slice by single value or ndarray of single values (int, float, times)
        usel, _ = _unify_sbv_types(std_sel, crd_arr, tdunit=tdunit,
                                   epoch=epoch)
        if (isinstance(usel, (int, np.integer, type(None)))
            and not isinstance(usel, (np.datetime64, np.timedelta64))):
            idx = usel
        elif isinstance(usel, np.ndarray):
            if isinstance(usel[0, 0], np.integer):
                idx = usel.reshape(-1)
            else:
                idx = np.argmin(np.abs(crd_arr.reshape(-1, 1) - usel), axis=0)
        else:
            idx = np.argmin(np.abs(crd_arr - usel))

    return idx
コード例 #15
0
ファイル: multiplot.py プロジェクト: KristoforMaynard/Viscid
def _do_multiplot(tind, grid, plot_vars=None, global_popts=None, kwopts=None,
                  share_axes=False, show=False, subplot_params=None,
                  first_run_result=None, first_run=False, **kwargs):
    """Plot the fields of one timestep as a column (or row) of subplots.

    Args:
        tind (int): timestep index; logged, and used as ``tind + 1`` in
            the output file name when ``out_prefix`` is given
        grid: grid whose fields are plotted
        plot_vars (list): ``(name_spec, plot_kwargs)`` pairs.
            ``name_spec`` is "name[,slice...]"; if its first
            comma-separated part contains "=", the whole spec is an
            equation with no slice. A leading "^" draws on the previous
            subplot instead of starting a new one. ``plot_kwargs`` is a
            dict passed to ``vlt.plot``; it is not modified.
        global_popts (str): plot options appended to every entry's
            "plot_opts"
        kwopts (dict): transpose, plot_size, dpi, out_prefix, out_format,
            selection, timeformat, tighten
        share_axes (bool): share x/y axes with the first subplot
        show (bool): call ``plt.show()`` after plotting
        subplot_params: passed to ``vlt.auto_adjust_subplots``; falls
            back to ``first_run_result`` when falsy
        first_run_result: result returned by an earlier first run
        first_run (bool): if True, return the result of
            ``vlt.auto_adjust_subplots``
        **kwargs: unused; logged at info level

    Returns:
        The ``vlt.auto_adjust_subplots`` result if ``first_run``,
        otherwise None (also None when there is nothing to plot).

    Raises:
        ValueError: if ``plot_vars`` is None, or the first entry starts
            with "^"
    """
    from viscid.plot import vpyplot as vlt
    import matplotlib.pyplot as plt

    logger.info("Plotting timestep: %d, %g", tind, grid.time)

    if plot_vars is None:
        raise ValueError("No plot_vars given to `_do_multiplot` :(")
    if kwargs:
        logger.info("Unused kwargs: {0}".format(kwargs))

    if kwopts is None:
        kwopts = {}
    transpose = kwopts.get("transpose", False)
    plot_size = kwopts.get("plot_size", None)
    dpi = kwopts.get("dpi", None)
    out_prefix = kwopts.get("out_prefix", None)
    out_format = kwopts.get("out_format", "png")
    selection = kwopts.get("selection", None)
    timeformat = kwopts.get("timeformat", ".02f")
    tighten = kwopts.get("tighten", False)

    # entries starting with '^' share the previous subplot, so only the
    # others get their own panel
    n_panels = sum(1 for pv in plot_vars if not pv[0].startswith('^'))
    # check before transposing, otherwise transpose=True would hide an
    # empty plot list behind nrows == 1
    if n_panels == 0:
        logger.warning("I have no variables to plot")
        return
    nrows, ncols = n_panels, 1
    if transpose:
        nrows, ncols = ncols, nrows

    fig = plt.gcf()
    if plot_size is not None:
        fig.set_size_inches(*plot_size, forward=True)
    if dpi is not None:
        fig.set_dpi(dpi)

    shareax = None

    this_row = -1
    for i, fld_meta in enumerate(plot_vars):
        if not fld_meta[0].startswith('^'):
            this_row += 1
            same_axis = False
        else:
            same_axis = True

        fld_name_meta = fld_meta[0].lstrip('^')
        fld_name_split = fld_name_meta.split(',')
        if '=' in fld_name_split[0]:
            # if fld_name is actually an equation, assume
            # there's no slice, and commas are part of the
            # equation
            fld_name = fld_name_meta
            fld_slc = ""
        else:
            fld_name = fld_name_split[0]
            fld_slc = ",".join(fld_name_split[1:])
        if selection is not None:
            if fld_slc != "":
                fld_slc = ",".join([fld_slc, selection])
            else:
                fld_slc = selection
        if fld_slc.strip() == "":
            fld_slc = Ellipsis

        if this_row < 0:
            raise ValueError("first plot can't begin with a ^")
        row = this_row
        col = 0
        if transpose:
            row, col = col, row
        if not same_axis:
            ax = plt.subplot2grid((nrows, ncols), (row, col),
                                  sharex=shareax, sharey=shareax)
        if i == 0 and share_axes:
            shareax = ax

        # work on a copy so the caller's plot_vars aren't modified;
        # mutating them made "plot_opts" grow by another global_popts
        # every time the same plot_vars were reused (e.g. per timestep)
        plot_kwargs = dict(fld_meta[1])
        if "plot_opts" not in plot_kwargs:
            plot_kwargs["plot_opts"] = global_popts
        elif global_popts is not None:
            plot_kwargs["plot_opts"] = "{0},{1}".format(
                plot_kwargs["plot_opts"], global_popts)

        with grid.get_field(fld_name, slc=fld_slc) as fld:
            vlt.plot(fld, masknan=True, **plot_kwargs)

    if timeformat and timeformat.lower() != "none":
        plt.suptitle(grid.format_time(timeformat))

    # for adjusting subplots / tight_layout and applying the various
    # hacks to keep plots from dancing around in movies
    if not subplot_params and first_run_result:
        subplot_params = first_run_result
    if tighten:
        tighten = dict(rect=[0, 0.03, 1, 0.90])
    ret = vlt.auto_adjust_subplots(tight_layout=tighten,
                                   subplot_params=subplot_params)
    if not first_run:
        ret = None

    if out_prefix:
        plt.savefig("{0}_{1:06d}.{2}".format(out_prefix, tind + 1, out_format))
    if show:
        plt.show()
    plt.clf()

    return ret
コード例 #16
0
ファイル: xdmf.py プロジェクト: KristoforMaynard/Viscid
    def _parse_grid(self, el, parent_node=None, time=None):
        """Parse an Xdmf ``<Grid>`` element into a grid / dataset.

        Topology, geometry (crds) and time come from the element itself
        when present, otherwise they are inherited from ``parent_node``.
        Collection grids recurse into their child ``<Grid>`` elements.

        Args:
            el: the ``<Grid>`` element
            parent_node: dataset this grid belongs to, or None
            time: time to assign to this grid; if None, it is read from
                a Single-type ``<Time>`` child, or inherited from
                ``parent_node``

        Returns:
            The new grid / dataset (already added to ``parent_node`` when
            one is given), or None if the grid was ignored.
        """
        attrs = self._fill_attrs(el)
        grd = None
        crds = None

        # parse topology, or cascade parent grid's topology
        topology = el.find("./Topology")
        topoattrs = None
        if topology is not None:
            topoattrs = self._fill_attrs(topology)
        elif parent_node and parent_node.topology_info:
            topoattrs = parent_node.topology_info

        # parse geometry, or cascade parent grid's geometry
        geometry = el.find("./Geometry")
        geoattrs = None
        if geometry is not None:
            crds, geoattrs = self._parse_geometry(geometry, topoattrs)
        elif parent_node and parent_node.geometry_info:
            geoattrs = parent_node.geometry_info
            crds = parent_node.crds  # this can be None and that's ok

        # parse time (only a Single-type Time tag sets this grid's time)
        if time is None:
            t = el.find("./Time")
            if t is not None:
                pt, tattrs = self._parse_time(t)
                if tattrs["Type"] == "Single":
                    time = pt
        # cascade a parent grid's time
        if time is None and parent_node and parent_node.time is not None:
            time = parent_node.time

        gt = attrs["GridType"]
        if gt == "Collection":
            times = None
            ct = attrs["CollectionType"]
            if ct == "Temporal":
                grd = self._make_dataset(parent_node, dset_type="temporal",
                                         name=attrs["Name"])
                self._inject_info(el, grd)
                ttag = el.find("./Time")
                if ttag is not None:
                    times, tattrs = self._parse_time(ttag)
            elif ct == "Spatial":
                grd = self._make_dataset(parent_node, name=attrs["Name"])
                self._inject_info(el, grd)
            else:
                logger.warning("Unknown collection type %s, ignoring grid", ct)

            # grd is None for an unknown collection type; really ignore
            # the grid instead of crashing on grd.children below
            if grd is not None:
                for i, subgrid in enumerate(el.findall("./Grid")):
                    # hand each child its entry of the collection's time
                    # list, if there is one, else this grid's time
                    if times is not None and i < len(times):
                        t = times[i]
                    else:
                        t = time
                    self._parse_grid(subgrid, parent_node=grd, time=t)
                if len(grd.children) > 0:
                    grd.activate(0)

        elif gt == "Uniform":
            if not (topoattrs and geoattrs):
                logger.warning("Xdmf Uniform grids must have "
                               "topology / geometry.")
            else:
                grd = self._make_grid(parent_node, name=attrs["Name"],
                                      **self._grid_opts)
                self._inject_info(el, grd)
                for attribute in el.findall("./Attribute"):
                    fld = self._parse_attribute(grd, attribute, crds,
                                                topoattrs, time)
                    # `is not None` so that time == 0.0 is still applied
                    if time is not None:
                        fld.time = time
                    grd.add_field(fld)

        elif gt == "Tree":
            logger.warning("Xdmf Tree Grids not implemented, ignoring "
                           "this grid")
        elif gt == "Subset":
            logger.warning("Xdmf Subset Grids not implemented, ignoring "
                           "this grid")
        else:
            logger.warning("Unknown grid type %s, ignoring this grid", gt)

        if grd:
            if time is not None:
                grd.time = time
            if topoattrs is not None:
                grd.topology_info = topoattrs
            if geoattrs is not None:
                grd.geometry_info = geoattrs
            if crds is not None:
                grd.set_crds(crds)

            # EXPERIMENTAL AMR support, _last_amr_grid shouldn't be an attribute
            # of self, since that will only remember the most recently generated
            # amr grid, but that's ok for now
            if gt == "Collection" and ct == "Spatial":
                grd, is_amr = amr_grid.dataset_to_amr_grid(grd,
                                                           self._last_amr_skeleton)
                if is_amr:
                    self._last_amr_skeleton = grd.skeleton

            if parent_node is not None:
                parent_node.add(grd)

        return grd  # can be None
コード例 #17
0
ファイル: openggcm.py プロジェクト: KristoforMaynard/Viscid
    def _do_mhd_to_gse_on_read(self):
        """Decide whether data should be flipped from MHD to GSE crds on read

        The decision is cached under the "_viscid_do_mhd_to_gse_on_read"
        info key, so after the first call it is simply looked up again.

        If "crd_system" is not already known, it is inferred and stored
        with ``set_info("crd_system", ...)`` as a side effect. The
        inferred value is "mhd", "gse", "other", or "unknown". Inference
        first checks the 'assume_mhd_crds' info flag, then (only when
        "_viscid_log_fname" is truthy) the libmrc parameters
        'ggcm_mhd_type', 'ggcm_mhd_ic_type', 'mrc_crds_l' and
        'mrc_crds_h'.

        ``self.mhd_to_gse_on_read`` is stringified, stripped and
        lower-cased, then interpreted as:
            'true': deprecated; logs a warning, then acts like 'force'
            'force': always flip
            'false': never flip
            'auto...': flip if crd_system is "mhd", don't if it is "gse";
                otherwise fall back to True when the request ends with
                'true' (e.g. 'auto_true'), else False, and log why the
                crd system could not be determined

        Returns:
            bool: whether to do the mhd -> gse translation (or whatever
            value was previously cached)

        Raises:
            ValueError: if mhd_to_gse_on_read is not a recognized value
        """
        cache_key = "_viscid_do_mhd_to_gse_on_read"
        if self.has_info(cache_key):
            return self.find_info(cache_key)

        # crd_system is only written back if it wasn't known beforehand
        crd_system = self.find_info("crd_system", None)
        store_crd_system = crd_system is None

        if store_crd_system and self.find_info('assume_mhd_crds', False):
            crd_system = "mhd"

        if crd_system is None and self.find_info("_viscid_log_fname"):
            # Guess from the libmrc parameters in the log file. For an
            # OpenGGCM-like run whose x extent straddles 0:
            #   x_hi more than twice |x_lo|  -> "mhd"
            #   |x_lo| more than twice x_hi  -> "gse"
            # anything else is "other"
            try:
                mhd_type = self.find_info('ggcm_mhd_type')
                # the ic type is a second check, since the ggcm_mhd view in
                # the log file is sometimes mangled
                ic_type = self.find_info('ggcm_mhd_ic_type', '')
                from_mirdip = ic_type.startswith("mirdip")
                openggcm_like = (mhd_type == "ggcm") | from_mirdip
                # the [0] entry of the libmrc crd extents is taken as x
                x_lo = float(self.find_info('mrc_crds_l')[0])
                x_hi = float(self.find_info('mrc_crds_h')[0])
                straddles_zero = openggcm_like and x_lo < 0.0 and x_hi > 0.0
                if straddles_zero and x_hi > -2 * x_lo:
                    crd_system = "mhd"
                elif straddles_zero and x_lo < -2 * x_hi:
                    crd_system = "gse"
                else:
                    crd_system = "other"
            except KeyError as e:
                logger.warning("Could not determine coordiname system; "
                               "either the logfile is mangled, or "
                               "the libmrc options I'm using in infer "
                               "crd system have changed (%s)", e.args[0])

        if crd_system is None:
            crd_system = "unknown"
        if store_crd_system:
            self.set_info("crd_system", crd_system)

        # with crd_system settled, turn the user's request into a decision
        request = str(self.mhd_to_gse_on_read).strip().lower()

        if request in ('true', 'force'):
            if request == 'true':
                viscid.logger.warning("'mhd_to_gse_on_read = true' is deprecated due "
                                      "to lack of clarity. Please use 'auto', or if "
                                      "you really want to always flip the axes, use "
                                      "'force'. Only use 'force' if you are certain, "
                                      "since even non-magnetosphere OpenGGCM grids "
                                      "will be flipped, and you will be confused "
                                      "some day when you open an MHD-in-a-box run, "
                                      "and you have forgetten about this message.")
            do_flip = True
        elif request == 'false':
            do_flip = False
        elif request.startswith("auto"):
            fallback = request.endswith('true')
            if crd_system in ("mhd", "gse"):
                do_flip = crd_system == "mhd"
            else:
                # crd_system is 'other' or 'unknown'; explain why based on
                # the log file state: reading disabled (False), not found
                # (None), or found but not informative
                log_fname = self.find_info("_viscid_log_fname")
                if fallback:
                    action = "flipping axes since default is True"
                else:
                    action = "not flipping axes since default is False"

                if log_fname is False:
                    logger.error("If you're using 'auto' for mhd->gse "
                                 "conversion, reading the logfile MUST be "
                                 "turned on. ({0})".format(action))
                elif log_fname is None:
                    logger.warning("Tried to determine coordinate system using "
                                   "logfile parameters, but no logfile found. "
                                   "Copy over the log file to use auto mhd->gse "
                                   "conversion. ({0})".format(action))
                else:
                    logger.warning("Could not determine crd_system used for this "
                                   "grid on disk ({0})".format(action))
                do_flip = fallback
        else:
            raise ValueError("Invalid value for mhd_to_gse_on_read: "
                             "'{0}'; valid choices: (True, False, auto, "
                             "auto_true, force)".format(request))

        self.set_info(cache_key, do_flip)
        return do_flip
コード例 #18
0
ファイル: multiplot.py プロジェクト: KristoforMaynard/Viscid
def _do_multiplot(tind,
                  grid,
                  plot_vars=None,
                  global_popts=None,
                  kwopts=None,
                  share_axes=False,
                  show=False,
                  subplot_params=None,
                  first_run_result=None,
                  first_run=False,
                  **kwargs):
    """Plot the fields in `plot_vars` for a single time step of `grid`

    Each entry of `plot_vars` is a pair ``(name_spec, plot_kwargs)``:

    - ``name_spec`` is ``"fld_name[,slice...]"``. A leading '^' means
      "draw on the same axes as the previous entry" instead of starting
      a new row. If the part before the first comma contains '=', the
      whole spec is treated as an equation with no slice.
    - ``plot_kwargs`` is a dict of keyword arguments for ``vlt.plot``.
      `global_popts` is appended to (or becomes) its "plot_opts" entry.
      The caller's dict is NOT modified.

    Parameters:
        tind (int): time step index, used in the log message and in the
            output file name ``{out_prefix}_{tind + 1:06d}.{out_format}``
        grid: grid providing ``time``, ``get_field`` and ``format_time``
        plot_vars (list): see above; required
        global_popts (str): plot options appended to every field's
            "plot_opts"
        kwopts (dict): recognized keys are transpose, plot_size, dpi,
            out_prefix, out_format, selection, timeformat, tighten
        share_axes (bool): share x/y axes with the first subplot
        show (bool): call ``plt.show()`` after plotting
        subplot_params: passed to ``vlt.auto_adjust_subplots``; falls
            back to `first_run_result` when falsy
        first_run_result: subplot params returned by the first run
        first_run (bool): if False, return None instead of the subplot
            adjustment result
        **kwargs: unused, only logged

    Returns:
        The result of ``vlt.auto_adjust_subplots`` when `first_run` is
        True, otherwise None. Also returns None (without plotting)
        when there is nothing to plot.

    Raises:
        ValueError: if `plot_vars` is None, or if the first entry starts
            with '^'
    """
    from viscid.plot import vpyplot as vlt
    import matplotlib.pyplot as plt

    logger.info("Plotting timestep: %d, %g", tind, grid.time)

    if plot_vars is None:
        raise ValueError("No plot_vars given to `_do_multiplot` :(")
    if kwargs:
        logger.info("Unused kwargs: {0}".format(kwargs))

    if kwopts is None:
        kwopts = {}
    transpose = kwopts.get("transpose", False)
    plot_size = kwopts.get("plot_size", None)
    dpi = kwopts.get("dpi", None)
    out_prefix = kwopts.get("out_prefix", None)
    out_format = kwopts.get("out_format", "png")
    selection = kwopts.get("selection", None)
    timeformat = kwopts.get("timeformat", ".02f")
    tighten = kwopts.get("tighten", False)

    # entries starting with '^' share the previous entry's axes, so only
    # the others get their own subplot; count BEFORE any transpose so the
    # "nothing to plot" check works in both layouts
    nplots = sum(1 for pv in plot_vars if not pv[0].startswith('^'))
    if nplots == 0:
        logger.warning("I have no variables to plot")
        return
    nrows, ncols = (1, nplots) if transpose else (nplots, 1)

    fig = plt.gcf()
    if plot_size is not None:
        fig.set_size_inches(*plot_size, forward=True)
    if dpi is not None:
        fig.set_dpi(dpi)

    shareax = None
    ax = None
    this_row = -1
    for i, fld_meta in enumerate(plot_vars):
        same_axis = fld_meta[0].startswith('^')
        if not same_axis:
            this_row += 1
        if this_row < 0:
            raise ValueError("first plot can't begin with a '^'")

        fld_name_split = fld_meta[0].lstrip('^').split(',')
        if '=' in fld_name_split[0]:
            # if fld_name is actually an equation, assume there's no
            # slice, and commas are part of the equation
            fld_name = ",".join(fld_name_split)
            fld_slc = ""
        else:
            fld_name = fld_name_split[0]
            fld_slc = ",".join(fld_name_split[1:])
        if selection is not None:
            if fld_slc != "":
                fld_slc = ",".join([fld_slc, selection])
            else:
                fld_slc = selection
        if fld_slc.strip() == "":
            fld_slc = Ellipsis

        row, col = (0, this_row) if transpose else (this_row, 0)
        if not same_axis:
            ax = plt.subplot2grid((nrows, ncols), (row, col),
                                  sharex=shareax,
                                  sharey=shareax)
        if i == 0 and share_axes:
            shareax = ax

        # work on a copy: this function runs once per time step with the
        # same plot_vars, and mutating the caller's dict would append
        # global_popts to "plot_opts" again on every call
        plot_kwargs = dict(fld_meta[1])
        if "plot_opts" not in plot_kwargs:
            plot_kwargs["plot_opts"] = global_popts
        elif global_popts is not None:
            plot_kwargs["plot_opts"] = "{0},{1}".format(
                plot_kwargs["plot_opts"], global_popts)

        with grid.get_field(fld_name, slc=fld_slc) as fld:
            vlt.plot(fld, masknan=True, **plot_kwargs)

    if timeformat and timeformat.lower() != "none":
        plt.suptitle(grid.format_time(timeformat))

    # for adjusting subplots / tight_layout and applying the various
    # hacks to keep plots from dancing around in movies
    if not subplot_params and first_run_result:
        subplot_params = first_run_result
    if tighten:
        tighten = dict(rect=[0, 0.03, 1, 0.90])
    ret = vlt.auto_adjust_subplots(tight_layout=tighten,
                                   subplot_params=subplot_params)
    if not first_run:
        ret = None

    if out_prefix:
        plt.savefig("{0}_{1:06d}.{2}".format(out_prefix, tind + 1, out_format))
    if show:
        plt.show()
    plt.clf()

    return ret
コード例 #19
0
def calc_psi(B, rev=False):
    """Calculate the flux function psi of an (effectively) 2D vector field

    psi is integrated with the trapezoid rule. The integration runs first
    along one edge of the grid in the first in-plane axis, then outward
    along the second in-plane axis. Calling the in-plane axes (y, z) and
    the matching components (By, Bz), the result satisfies
    d(psi)/dy = Bz and d(psi)/dz = -By, with psi = 0 at the starting
    corner.

    Before integrating, B goes through ``slice_reduce(":")`` and is then
    cut down to two spatial dims. If B has more than two spatial dims,
    every dim with <= 2 points is indexed at 0. If B is still 3D after
    that, its smallest dim is indexed at 0 and a warning is logged.

    Parameters:
        B (VectorField): magnetic field, should only have two
            spatial dimensions so we can infer the symmetry dimension
        rev (bool): start integrating at the last point of both axes
            instead of the first. This integration doesn't like going
            through undefined regions (like within 1 earth radius of
            the origin for openggcm), so this lets you start from the
            opposite corner.

    Returns:
        ScalarField: the flux function, named "psi". If any axes were
        indexed away, psi is sliced with ``"..., <axis>=None"`` for each
        of them (presumably re-adding them as length-1 axes; confirm).

    Raises:
        ValueError: if B does not end up with exactly 2 spatial dims
    """
    # TODO: the integration below loops in python; moving it into cython
    # or a vectorized numpy form could speed it up a lot
    B = B.slice_reduce(":")
    dropped = []

    def _drop_dims(fld, dims):
        """Index fld at 0 along the spatial dims in dims, keep components"""
        sel = [slice(None)] * fld.nr_sdims
        for d in dims:
            sel[d] = 0
        sel.insert(fld.nr_comp, slice(None))
        return fld[sel]

    # first guess: tiny dims (<= 2 points) of a 3D field are invariant
    if B.nr_sdims > 2:
        tiny = [d for d, nd in enumerate(B.sshape) if nd <= 2]
        dropped.extend(B.crds.axes[d] for d in tiny)
        B = _drop_dims(B, tiny)

    # still 3D? then just cut away the smallest dim
    if B.nr_sdims == 3:
        smallest = np.argmin(B.sshape)
        dropped.append(B.crds.axes[smallest])
        logger.warning("Tried to get the flux function of a 3D field. "
                       "I can't do that, so I'm\njust ignoring the {0} "
                       "dimension".format(dropped[-1]))
        B = _drop_dims(B, [smallest])

    if B.nr_sdims != 2:
        raise ValueError("flux function only implemented for 2D fields")

    # in-plane axes in xyz order, e.g. "yz" -> component indices [1, 2]
    comps = "".join(ax for ax in "xyz" if ax in B.crds.axes)
    comp_inds = ["xyz".index(ax) for ax in comps]

    # names below read as if the plane were (y, z), but any two axes work
    ycc, zcc = B.get_crds(comps)
    views = B.component_views()
    hy = views[comp_inds[0]]
    hz = views[comp_inds[1]]
    dy = ycc[1:] - ycc[:-1]
    dz = zcc[1:] - zcc[:-1]
    ny = len(ycc)
    nz = len(zcc)

    flux = np.empty((ny, nz), dtype=B.dtype)

    if rev:
        # start in the (last, last) corner and integrate backwards
        flux[-1, -1] = 0.0
        for i in reversed(range(ny - 1)):
            flux[i, -1] = flux[i + 1, -1] - dy[i] * 0.5 * (hz[i, -1] + hz[i + 1, -1])
        for j in reversed(range(nz - 1)):
            flux[:, j] = flux[:, j + 1] + dz[j] * 0.5 * (hy[:, j + 1] + hy[:, j])
    else:
        # start in the (0, 0) corner and integrate forwards
        flux[0, 0] = 0.0
        for i in range(1, ny):
            flux[i, 0] = flux[i - 1, 0] + dy[i - 1] * 0.5 * (hz[i, 0] + hz[i - 1, 0])
        for j in range(1, nz):
            flux[:, j] = flux[:, j - 1] - dz[j - 1] * 0.5 * (hy[:, j - 1] + hy[:, j])

    psi = field.wrap_field(flux, B.crds, name="psi", center=B.center,
                           pretty_name=r"$\psi$", parents=[B])
    if dropped:
        psi = psi["..., " + ", ".join("{0}=None".format(ax) for ax in dropped)]
    return psi
コード例 #20
0
ファイル: xdmf.py プロジェクト: KristoforMaynard/Viscid
    def _parse_geometry(self, geo, topoattrs):
        """Turn an XDMF Geometry element into a Coordinates object

        Parameters:
            geo: element tree node of the Geometry element
            topoattrs (dict): filled-in attributes of the matching
                Topology element. "TopologyType" is always read;
                "Dimensions" / "NumberOfElements" are read for
                ORIGIN_DXDYDZ geometry and 2DSMesh topology.

        Returns:
            tuple: (crds, geoattrs), where crds is the result of
            ``coordinate.wrap_crds`` and geoattrs are the geometry
            element's filled-in attributes

        Raises:
            RuntimeError: X_Y_Z / VXVYVZ geometry missing one of x, y, z
            KeyError: X_Y_Z / VXVYVZ DataItem with an unexpected or
                repeated Name
            ValueError: ORIGIN_DXDYDZ geometry whose topology has neither
                Dimensions nor NumberOfElements
            NotImplementedError: 3DSMesh, or any topology type not
                handled below (treated as unstructured)
        """
        geoattrs = self._fill_attrs(geo)
        crdlist = None
        crdtype = None
        crdkwargs = {}

        topotype = topoattrs["TopologyType"]

        geotype = geoattrs["GeometryType"]
        gtype = geotype.upper()

        if gtype in ("XYZ", "XY"):
            # Interleaved point lists. The data item is parsed but its
            # values are discarded: unstructured and 3D spherical grids
            # aren't supported, and 2D spherical crds are rebuilt from the
            # topology alone below (assuming the grid spans the sphere).
            _data, _attrs = self._parse_dataitem(geo.find("./DataItem"),
                                                 keep_flat=True)
            crdlist = None

        elif gtype in ("X_Y_Z", "VXVYVZ"):
            # One DataItem per axis, named X/Y/Z (or VX/VY/VZ). Names are
            # matched by hand since ./DataItem[@Name='X'] doesn't work
            # with python2.6's ElementTree.
            pending = {'x': 0, 'y': 1, 'z': 2}
            crdlist = [['x', None], ['y', None], ['z', None]]
            for item in geo.findall("./DataItem"):
                axis_name = item.attrib["Name"]
                if gtype == "VXVYVZ":
                    axis_name = axis_name.lstrip('V')
                axis_name = axis_name.lower()
                arr, _ = self._parse_dataitem(item, keep_flat=True)
                crdlist[pending.pop(axis_name)][1] = arr
            if pending:
                raise RuntimeError("XDMF format error: Coords not specified "
                                   "for {0} dimesions"
                                   "".format(list(pending.keys())))
            crdkwargs["full_arrays"] = True

        elif gtype == "ORIGIN_DXDYDZ":
            # uniform grid described by an origin and a spacing
            items = geo.findall("./DataItem")
            origin, _ = self._parse_dataitem(items[0])
            spacing, _ = self._parse_dataitem(items[1])
            dtype = origin.dtype
            nstr = topoattrs["Dimensions"] or topoattrs["NumberOfElements"]
            if not nstr:
                raise ValueError("ORIGIN_DXDYDZ has no number of elements...")
            counts = [int(tok) for tok in nstr.split()]
            # FIXME: OpenGGCM writes these in zyx order even though the
            # XDMF website says xyz; Paraview opens such files correctly as
            # zyx, so flip everything to xyz here
            origin, spacing, counts = origin[::-1], spacing[::-1], counts[::-1]
            crdlist = []
            for k, axis in enumerate('xyz'):
                npts = counts[k]
                lo = origin[k]
                hi = lo + ((npts - 1) * spacing[k])
                crdlist.append((axis, [lo, hi, npts]))
            crdkwargs["dtype"] = dtype
            crdkwargs["full_arrays"] = False

        else:
            logger.warning("Invalid GeometryType: %s", geotype)

        if topotype in ('3DCoRectMesh', '2DCoRectMesh'):
            crdtype = "uniform_cartesian"
        elif topotype in ('3DRectMesh', '2DRectMesh'):
            # RectMesh normally comes with full crd arrays; without them
            # (e.g. ORIGIN_DXDYDZ geometry) fall back to uniform -- a HACK
            if crdkwargs.get("full_arrays", True):
                crdtype = "nonuniform_cartesian"
            else:
                crdtype = "uniform_cartesian"
        elif topotype in ('2DSMesh',):
            # HACK: a 2D structured mesh is taken to be a spherical shell
            # that covers the whole sphere. Then the radius doesn't matter,
            # and only the point counts are needed. Rebuilding phi / theta
            # from the xyz point data works poorly and is too heavy to be
            # worth it. Crds are uniform in degrees and use map-style
            # ordering (phi first, then theta). Dimensions are read as
            # "ntheta nphi".
            ntheta, nphi = [int(s) for s in topoattrs["Dimensions"].split(' ')]
            crdtype = "uniform_spherical"
            crdlist = [['phi', [0.0, 360.0, nphi]],
                       ['theta', [0.0, 180.0, ntheta]]]
            crdkwargs["full_arrays"] = False
            crdkwargs["units"] = 'deg'
        elif topotype in ('3DSMesh',):
            raise NotImplementedError("3D spherical grids not yet supported")
        else:
            raise NotImplementedError("Unstructured grids not yet supported")

        crds = coordinate.wrap_crds(crdtype, crdlist, **crdkwargs)
        return crds, geoattrs
コード例 #21
0
    def load_files(self,
                   fnames,
                   index_handle=True,
                   file_type=None,
                   prefer=None,
                   force_reload=False,
                   _add_ref=False,
                   **kwargs):
        """Load files, and add them to the bucket

        File names are globbed. Each file's type is auto-detected (or
        taken from `file_type`). Files of the same type are grouped with
        ``ftype.group_fnames``, and each group becomes one file object.
        A group whose handle name is already in the bucket is reused
        (and reloaded if `force_reload`). A group whose constructor
        raises IOError is logged and skipped: it is neither added to the
        bucket nor returned.

        Parameters:
            fnames: a file name, or a list/tuple of file names (can
                contain glob patterns)
            index_handle: passed through to ``self.set_item``
            file_type: a class that is a subclass of VFile, or a string
                resolved with ``VFile.resolve_type``. If given, use this
                file type and skip the autodetect mechanism.
            prefer: passed to ``VFile.detect_type`` when autodetecting
            force_reload (bool): reload files already in the bucket
            _add_ref: passed through to ``self.set_item``
            kwargs: passed to file constructor

        Returns:
            A list of VFile instances. The length may not be the same
            as the length of fnames, and the order may not be the same,
            in order to accommodate globs and file grouping.

        Raises:
            RuntimeError: if the type of a file can't be determined
            ValueError: re-raised (after logging) if a file constructor
                raises one
        """
        orig_fnames = fnames

        if not isinstance(fnames, (list, tuple)):
            fnames = [fnames]

        # expand globs / slice-globs into concrete file names
        globbed_fnames = []
        for fname in fnames:
            slglob = slice_globbed_filenames(fname)
            if isinstance(slglob, string_types):
                slglob = [slglob]
            globbed_fnames += slglob
        fnames = globbed_fnames

        # detect file types, keeping the order in which types first appear
        types_detected = OrderedDict()
        for i, fname in enumerate(fnames):
            if file_type is None:
                _ftype = VFile.detect_type(fname, prefer=prefer)
            elif isinstance(file_type, string_types):
                _ftype = VFile.resolve_type(file_type)
            else:
                _ftype = file_type
            if not _ftype:
                raise RuntimeError("Can't determine type "
                                   "for {0}".format(fname))
            types_detected.setdefault(_ftype, []).append((fname, i))

        # see if the file's already been loaded, or load it, and add it
        # to the bucket and all that good stuff
        file_lst = []
        for ftype, vals in types_detected.items():
            names = [v[0] for v in vals]
            # group all file names of a given type
            groups = ftype.group_fnames(names)
            type_name = getattr(ftype, "__name__", str(ftype))

            for group in groups:
                f = None
                handle_name = ftype.collective_name(group)

                try:
                    f = self[handle_name]
                    if force_reload:
                        f.reload()
                except KeyError:
                    try:
                        f = ftype(group, parent_bucket=self, **kwargs)
                        f.handle_name = handle_name
                    except IOError as e:
                        s = " IOError on file: {0}\n".format(handle_name)
                        s += "              File Type: {0}\n".format(
                            type_name)
                        s += "              {0}".format(str(e))
                        logger.warning(s)
                    except ValueError as e:
                        # ... why am i explicitly catching ValueErrors?
                        # i'm probably breaking something by re-raising
                        # this exception, but i didn't document what :(
                        s = " ValueError on file load: {0}\n".format(
                            handle_name)
                        s += "              File Type: {0}\n".format(
                            type_name)
                        s += "              {0}".format(str(e))
                        logger.warning(s)
                        # re-raise the last exception
                        raise

                if f is None:
                    # the file failed to open (IOError above); don't store
                    # a None placeholder in the bucket or in the result
                    continue

                self.set_item([handle_name],
                              f,
                              index_handle=index_handle,
                              _add_ref=_add_ref)
                file_lst.append(f)

        if len(file_lst) == 0:
            logger.warning("No files loaded for '{0}', is the path "
                           "correct?".format(orig_fnames))
        return file_lst
コード例 #22
0
def std_sel2index(std_sel,
                  crd_arr,
                  val_endpoint=True,
                  interior=False,
                  tdunit='s',
                  epoch=None):
    """Turn single standardized selection into slice (by int or None)

    Normally (val_endpoint=True, interior=False), the rules for float
    lookup val_endpoints are::

        - The slice will never include an element whose value in arr
          is < start (or > if the slice is backward)
        - The slice will never include an element whose value in arr
          is > stop (or < if the slice is backward)
        - !! The slice WILL INCLUDE stop if you don't change
          val_endpoint. This is different from normal slicing, but
          it's more natural when specifying a slice as a float.

    If interior=True, then the slice is expanded such that start and
    stop are interior to the sliced array.

    Args:
        std_sel: single standardized selection; either a slice, or a
            single value / ndarray of values
        crd_arr (ndarray): coordinate values used for slice-by-value
            lookups (floats or times)
        val_endpoint (bool): iff True then include stop in the slice when
            slicing-by-value (DOES NOT AFFECT SLICE-BY-INDEX).
            Set to False to get python slicing semantics when it
            comes to excluding stop, but fair warning, python
            semantics feel awkward here. Consider the case
            [0.1, 0.2, 0.3][:0.25]. If you think this should include
            0.2, then keep val_endpoint=True.
        interior (bool): if True, then extend both ends of the slice
            such that slice-by-location endpoints are interior to the
            slice
        tdunit (str): Presumed time unit for floats; forwarded to every
            slice-by-value type unification
        epoch (datetime64-like): Epoch used to go datetime64 <-> float

    Returns:
        For a slice selection, a slice of integer indices (or None).
        Otherwise an integer index (or None), or an ndarray of indices.
    """
    if interior and not val_endpoint:
        logger.warning("For interior slices, val_endpoint must be True, I'll "
                       "change that for you.")
        val_endpoint = True

    if isinstance(std_sel, slice):
        assert isinstance(std_sel.step, (int, np.integer, type(None)))
        start_val = None
        stop_val = None

        orig_step = std_sel.step
        ustep = 1 if std_sel.step is None else int(std_sel.step)
        sgn = np.sign(ustep)

        if (isinstance(std_sel.start, (int, np.integer, type(None)))
                and not isinstance(std_sel.start,
                                   (np.datetime64, np.timedelta64))):
            ustart = std_sel.start
        else:
            # slice-by-value start: closest element on the inside of start
            # (within tol)
            ustart, tol = _unify_sbv_types(std_sel.start,
                                           crd_arr,
                                           tdunit=tdunit,
                                           epoch=epoch)
            start_val = ustart
            diff = crd_arr - ustart + (tol * sgn)
            zero = np.array([0]).astype(diff.dtype)[0]

            if ustep > 0:
                diff = np.ma.masked_less(diff, zero)
            else:
                diff = np.ma.masked_greater(diff, zero)

            if np.ma.count(diff) == 0:
                # start value is past the wrong end of the array
                if ustep > 0:
                    ustart = len(crd_arr)
                else:
                    # start = -len(arr) - 1
                    # having a value < -len(arr) won't play
                    # nice with make_fwd_slice, but in this
                    # case, the slice will have no data, so...
                    return slice(0, 0, ustep)
            else:
                ustart = np.argmin(np.abs(diff))

        if (isinstance(std_sel.stop, (int, np.integer, type(None)))
                and not isinstance(std_sel.stop,
                                   (np.datetime64, np.timedelta64))):
            ustop = std_sel.stop
        else:
            # slice-by-value stop: closest element on the inside of stop
            # (within tol), made inclusive if val_endpoint
            ustop, tol = _unify_sbv_types(std_sel.stop,
                                          crd_arr,
                                          tdunit=tdunit,
                                          epoch=epoch)
            stop_val = ustop
            diff = crd_arr - ustop - (tol * sgn)
            zero = np.array([0]).astype(diff.dtype)[0]

            if ustep > 0:
                diff = np.ma.masked_greater(diff, zero)
            else:
                diff = np.ma.masked_less(diff, zero)

            if ustep > 0:
                if ustop < crd_arr[0]:
                    # stop value is past the wrong end of the array
                    ustop = 0
                else:
                    ustop = int(np.argmin(np.abs(diff)))
                    if val_endpoint:
                        ustop += 1
            else:
                if ustop > crd_arr[-1]:
                    # stop value is past the wrong end of the array
                    ustop = len(crd_arr)
                else:
                    ustop = int(np.argmin(np.abs(diff)))
                    if val_endpoint:
                        if ustop > 0:
                            ustop -= 1
                        else:
                            # 0 - 1 == -1 which would wrap to the end of
                            # of the array... instead, just make it None
                            ustop = None
        idx = slice(ustart, ustop, orig_step)

        if interior:
            _a, _b, _c = _interiorize_slice(crd_arr, start_val, stop_val,
                                            idx.start, idx.stop, idx.step)
            idx = slice(_a, _b, _c)

    else:
        # slice by single value or ndarray of single values (int, float, times)
        usel, _ = _unify_sbv_types(std_sel, crd_arr, tdunit=tdunit,
                                   epoch=epoch)
        if (isinstance(usel, (int, np.integer, type(None)))
                and not isinstance(usel, (np.datetime64, np.timedelta64))):
            idx = usel
        elif isinstance(usel, np.ndarray):
            if isinstance(usel[0, 0], np.integer):
                idx = usel.reshape(-1)
            else:
                idx = np.argmin(np.abs(crd_arr.reshape(-1, 1) - usel), axis=0)
        else:
            idx = np.argmin(np.abs(crd_arr - usel))

    return idx