Esempio n. 1
0
    def test_init(self):
        """Test initialisation with lists, tuples, dicts of arrays
        rather than Columns [regression test for #2647].

        For every supported input container, the resulting TableColumns must
        hold exactly the input names (none added, none dropped) and each name
        must map to the very same object that was passed in.
        """
        x1 = np.arange(10.)
        x2 = np.arange(5.)
        x3 = np.arange(7.)
        col_list = [('x1', x1), ('x2', x2), ('x3', x3)]
        tc_list = TableColumns(col_list)
        assert len(tc_list) == len(col_list)
        for name, arr in col_list:
            assert name in tc_list
            assert tc_list[name] is arr

        col_tuple = (('x1', x1), ('x2', x2), ('x3', x3))
        tc_tuple = TableColumns(col_tuple)
        assert len(tc_tuple) == len(col_tuple)
        for name, arr in col_tuple:
            assert name in tc_tuple
            assert tc_tuple[name] is arr

        col_dict = dict([('x1', x1), ('x2', x2), ('x3', x3)])
        tc_dict = TableColumns(col_dict)
        # Iterate over the *input* mapping: looping over tc_dict's own keys
        # (as before) made ``key in tc_dict`` tautological and would pass
        # vacuously if the constructor dropped entries.
        assert len(tc_dict) == len(col_dict)
        for name, arr in col_dict.items():
            assert name in tc_dict
            assert tc_dict[name] is arr

        columns = [Column(col[1], name=col[0]) for col in col_list]
        tc = TableColumns(columns)
        assert len(tc) == len(columns)
        for col in columns:
            assert col.name in tc
            assert tc[col.name] is col
Esempio n. 2
0
    def __new__(cls, data=None, name=None, dtype=None, shape=(), length=0, description=None, unit=None, format=None, meta=None, cols=None):
        """
        Create the grid and attach its axis definitions as ``self.cols``.

        ``cols`` may be None (axis ``i`` is then named ``str(i)`` with float
        values ``0 .. n_i - 1``), a TableColumns instance (used as-is), or a
        sequence whose entries are Columns or ``(name, values)`` lists/tuples.

        Raises
        ------
        ValueError
            If the number of axes differs from ``self.ndim``, if ``cols``
            cannot be converted into a TableColumns, or if the axis lengths
            do not match the shape of the data grid.
        """
        self = super(AtmoGrid, cls).__new__(cls, data=data, name=name, dtype=dtype, shape=shape, length=length, description=description, unit=unit, format=format, meta=meta)

        if cols is None:
            self.cols = TableColumns([Column(name=str(i), data=np.arange(self.shape[i], dtype=float)) for i in range(self.ndim)])
            return self

        if len(cols) != self.ndim:
            raise ValueError('cols must contain a number of elements equal to the dimension of the data grid.')

        if isinstance(cols, TableColumns):
            self.cols = cols
        else:
            try:
                self.cols = TableColumns([Column(name=col[0], data=col[1]) if isinstance(col, (list, tuple)) else col for col in cols])
            except Exception:
                # Narrowed from a bare ``except:`` so that KeyboardInterrupt and
                # SystemExit are no longer turned into a ValueError.
                raise ValueError('Cannot make a TableColumns out of the provided cols parameter.')

        # ``values()`` works on Python 2 and 3; ``itervalues()`` is Python 2 only.
        shape = tuple(col.size for col in self.cols.values())
        if self.shape != shape:
            raise ValueError('The dimension of the data grid and the cols are not matching.')
        return self
Esempio n. 3
0
    def __new__(cls, data=None, name=None, dtype=None, shape=(), length=0, description=None, unit=None, format=None, meta=None, cols=None):
        """
        Create the grid and attach its axis definitions as ``self.cols``.

        ``cols`` may be None (axis ``i`` is then named ``str(i)`` with float
        values ``0 .. n_i - 1``), a TableColumns instance (used as-is), or a
        sequence whose entries are Columns or ``(name, values)`` lists/tuples.

        Raises
        ------
        ValueError
            If the number of axes differs from ``self.ndim``, if ``cols``
            cannot be converted into a TableColumns, or if the axis lengths
            do not match the shape of the data grid.
        """
        self = super(AtmoGrid, cls).__new__(cls, data=data, name=name, dtype=dtype, shape=shape, length=length, description=description, unit=unit, format=format, meta=meta)

        if cols is None:
            self.cols = TableColumns([Column(name=str(i), data=np.arange(self.shape[i], dtype=float)) for i in range(self.ndim)])
            return self

        if len(cols) != self.ndim:
            raise ValueError('cols must contain a number of elements equal to the dimension of the data grid.')

        if isinstance(cols, TableColumns):
            self.cols = cols
        else:
            try:
                self.cols = TableColumns([Column(name=col[0], data=col[1]) if isinstance(col, (list, tuple)) else col for col in cols])
            except Exception:
                # Narrowed from a bare ``except:`` so that KeyboardInterrupt and
                # SystemExit are no longer turned into a ValueError.
                raise ValueError('Cannot make a TableColumns out of the provided cols parameter.')

        # ``values()`` works on Python 2 and 3; ``itervalues()`` is Python 2 only.
        shape = tuple(col.size for col in self.cols.values())
        if self.shape != shape:
            raise ValueError('The dimension of the data grid and the cols are not matching.')
        return self
Esempio n. 4
0
    def test_init(self):
        """Test initialisation with lists, tuples, dicts of arrays
        rather than Columns [regression test for #2647].

        For every supported input container, the resulting TableColumns must
        hold exactly the input names (none added, none dropped) and each name
        must map to the very same object that was passed in.
        """
        x1 = np.arange(10.)
        x2 = np.arange(5.)
        x3 = np.arange(7.)
        col_list = [('x1', x1), ('x2', x2), ('x3', x3)]
        tc_list = TableColumns(col_list)
        assert len(tc_list) == len(col_list)
        for name, arr in col_list:
            assert name in tc_list
            assert tc_list[name] is arr

        col_tuple = (('x1', x1), ('x2', x2), ('x3', x3))
        tc_tuple = TableColumns(col_tuple)
        assert len(tc_tuple) == len(col_tuple)
        for name, arr in col_tuple:
            assert name in tc_tuple
            assert tc_tuple[name] is arr

        col_dict = dict([('x1', x1), ('x2', x2), ('x3', x3)])
        tc_dict = TableColumns(col_dict)
        # Iterate over the *input* mapping: looping over tc_dict's own keys
        # (as before) made ``key in tc_dict`` tautological and would pass
        # vacuously if the constructor dropped entries.
        assert len(tc_dict) == len(col_dict)
        for name, arr in col_dict.items():
            assert name in tc_dict
            assert tc_dict[name] is arr

        columns = [Column(col[1], name=col[0]) for col in col_list]
        tc = TableColumns(columns)
        assert len(tc) == len(columns)
        for col in columns:
            assert col.name in tc
            assert tc[col.name] is col
Esempio n. 5
0
class AtmoGrid(Column):
    r"""
    Define the base atmosphere grid structure.

    AtmoGrid contains utilities to trim the grid, read/write to HDF5 format.

    Parameters
    ----------
    data : ndarray
        Grid of log(flux) values (e-base)
    name : str
        Keyword name of the atmosphere grid
    dtype : np.dtype compatible value
        Data type of the flux grid
    shape : tuple or ()
        Dimensions of a single row element in the flux grid
    length : int or 0
        Number of row elements in the grid
    description : str or None
        Full description of the atmosphere grid
    unit : str or None
        Physical unit
    format : str or None or function or callable
        Format string for outputting column values.  This can be an
        "old-style" (``format % value``) or "new-style" (`str.format`)
        format specification string or a function or any callable object that
        accepts a single value and returns a string.
    meta : dict-like or None
        Meta-data associated with the atmosphere grid
    cols : OrderedDict-like, list of Columns, list of lists/tuples, or None
        Full definition of the flux grid axes. This can be a list of entries
        ('colname', ndarray) with the ndarray corresponding to the axis values
        or a list of Columns containing this information. If None, axis ``i``
        is named ``str(i)`` and takes the float values ``0, 1, ..., n_i - 1``.

    Examples
    --------
    An AtmoGrid can be created like this:

      Examples::

        logtemp = np.log(np.arange(3000.,10001.,250.))
        logg = np.arange(2.0, 5.6, 0.5)
        mu = np.arange(0.,1.01,0.02)
        logflux = np.random.normal(size=(logtemp.size,logg.size,mu.size))
        atmo = AtmoGrid(data=logflux, cols=[('logtemp',logtemp), ('logg',logg), ('mu',mu)])

    To read/save a file:

        atmo = AtmoGridPhot.ReadHDF5('vband.h5')
        atmo.WriteHDF5('vband_new.h5')

    Notes
    -----
    Note that in principle the axis and data could be any format. However, we recommend using
    log(flux) and log(temperature) because the linear interpolation of such a grid would make
    more sense (say, from the blackbody $F \propto \sigma T^4$).

    """
    def __new__(cls, data=None, name=None, dtype=None, shape=(), length=0, description=None, unit=None, format=None, meta=None, cols=None):
        """
        Create the grid and attach its axis definitions as ``self.cols``.

        Raises
        ------
        ValueError
            If the number of axes differs from ``self.ndim``, if ``cols``
            cannot be converted into a TableColumns, or if the axis lengths
            do not match the shape of the data grid.
        """
        self = super(AtmoGrid, cls).__new__(cls, data=data, name=name, dtype=dtype, shape=shape, length=length, description=description, unit=unit, format=format, meta=meta)

        if cols is None:
            self.cols = TableColumns([Column(name=str(i), data=np.arange(self.shape[i], dtype=float)) for i in range(self.ndim)])
            return self

        if len(cols) != self.ndim:
            raise ValueError('cols must contain a number of elements equal to the dimension of the data grid.')

        if isinstance(cols, TableColumns):
            self.cols = cols
        else:
            try:
                self.cols = TableColumns([Column(name=col[0], data=col[1]) if isinstance(col, (list, tuple)) else col for col in cols])
            except Exception:
                # Narrowed from a bare ``except:`` so that KeyboardInterrupt and
                # SystemExit are no longer turned into a ValueError.
                raise ValueError('Cannot make a TableColumns out of the provided cols parameter.')

        # ``values()`` works on Python 2 and 3; ``itervalues()`` is Python 2 only.
        shape = tuple(col.size for col in self.cols.values())
        if self.shape != shape:
            raise ValueError('The dimension of the data grid and the cols are not matching.')
        return self

    def __copy__(self):
        """Shallow copy: the new grid views this grid's data."""
        return self.copy(copy_data=False)

    def __deepcopy__(self, memo=None):
        """
        Copy with the grid data copied.

        ``copy.deepcopy`` calls ``__deepcopy__(memo)``; the previous signature
        without ``memo`` made ``deepcopy(atmo)`` raise TypeError. The memo is
        unused: ``copy`` passes the axis columns through unchanged.
        """
        return self.copy(copy_data=True)

    def __getitem__(self, item):
        """
        Index the grid, or look up an axis by name.

        Parameters
        ----------
        item : str or any numpy index
            A string selects an axis column. If it is not an exact axis name,
            these derived lookups are tried in order: ``'log' + item``
            (returns ``np.exp`` of that axis), ``'log' + item[2:]`` (returns
            ``10**`` that axis) and ``'log10' + item`` (returns ``10**`` that
            axis). Anything else indexes the flux grid as a plain ndarray.

        Raises
        ------
        Exception
            If a string key matches none of the lookups above.
        """
        if isinstance(item, six.string_types):
            if item not in self.colnames:
                if 'log'+item in self.colnames:
                    return np.exp(self.cols['log'+item])
                elif 'log'+item[2:] in self.colnames:
                    return 10**(self.cols['log'+item[2:]])
                elif 'log10'+item in self.colnames:
                    return 10**(self.cols['log10'+item])
                else:
                    raise Exception('The provided column name is cannot be found.')
            else:
                return self.cols[item]
        else:
            # Index the underlying ndarray directly rather than going through
            # Column.__getitem__.
            return self.view(np.ndarray)[item]

    @property
    def colnames(self):
        """
        Axis names, in axis order, as a list.

        ``list(...)`` guarantees an indexable sequence (used by ``Pprint`` via
        ``self.colnames[i]`` and by ``Trim`` via ``.index``) regardless of
        whether ``TableColumns.keys()`` returns a list or a Python 3 keys view.
        """
        return list(self.cols.keys())

    def copy(self, order='C', data=None, copy_data=True):
        """
        Copy of the instance. If ``data`` is supplied
        then a view (reference) of ``data`` is used, and ``copy_data`` is ignored.

        ``meta`` is deep-copied; the axis definitions (``self.cols``) are
        passed to the new instance as the same TableColumns object.
        """
        if data is None:
            data = self.view(np.ndarray)
            if copy_data:
                data = data.copy(order)

        return self.__class__(name=self.name, data=data, unit=self.unit, format=self.format, description=self.description, meta=deepcopy(self.meta), cols=self.cols)

    def Fill_nan_old(self, axis=0, method='spline', bounds_error=False, fill_value=np.nan, k=1, s=1):
        """
        Fill the empty grid cells (marked as np.nan) with interpolated values
        along a given axis (i.e. interpolation is done in 1D).

        The grid is modified in place.

        Parameters
        ----------
        axis : interpolate
            Axis along which the interpolation should be performed.
        method : str
            Interpolation method to use. Possible choices are 'spline' and
            'interp1d'.
            'spline' allows for the use of optional keywords k (the order) and
            s (the smoothing parameter). See scipy.interpolate.splrep.
            'interp1d' allows for the use of optional keywords bounds_error and
            fill_value. See scipy.interpolate.interp1d.
        bounds_error : bool
            Whether to raise an error when attempting to extrapolate out of
            bounds. Only works with 'interp1d'.
        fill_value : float
            Value to use when bounds_error is False. Only works with 'interp1d'.
        k : int
            Order of the spline to use. We recommend 1. Only works with
            'spline'.
        s : int
            Smoothing parameter for the spline. We recommend 0 (exact
            interpolation), or 1. Only works with 'spline'.

        Examples
        ----------
          Examples::
            atmo.Fill_nan(axis=0, method='interp1d', bounds_error=False, fill_value=np.nan)

        This would fill in the value that are not out of bound with a linear fit. Values
        out of bound would be np.nan.

            atmo.Fill_nan(axis=0, method='spline', k=1, s=0)

        This would produce exactly the same interpolation as above, except that values
        out of bound would be extrapolated.

        Notes
        ----------
        From our experience, it is recommended to first fill the values within
        the bounds using 'interp1d' with bounds_error=False and fill_value=np.nan,
        and then use 'spline' with k=1 and s=1 in order to extrapolate outside
        the bounds. To interpolate within the bounds, the temperature axis
        (i.e. 0) is generally best and more smooth, whereas the logg axis (i.e. 1)
        works better to extrapolate outside.


          Examples::
            atmo.Fill_nan(axis=0, method='interp1d', bounds_error=False, fill_value=np.nan)
            atmo.Fill_nan(axis=1, method='spline', k=1, s=1)
        """
        if method not in ['interp1d','spline']:
            raise Exception('Wrong method input! Must be either interp1d or spline.')
        # Enumerate every index combination of the other axes, then insert a
        # full slice at ``axis`` so that each ``ind`` selects one 1D line.
        ndim = list(self.shape)
        ndim.pop(axis)
        inds_tmp = np.indices(ndim)
        inds = [ind.flatten() for ind in inds_tmp]
        niter = len(inds[0])
        inds.insert(axis, [slice(None)]*niter)
        for ind in zip(*inds):
            # A tuple index returns an ndarray view, so the assignments below
            # write directly into the grid.
            col = self.__getitem__(ind)
            inds_good = np.isfinite(col)
            inds_bad = ~inds_good
            if np.any(inds_bad):
                if method == 'interp1d':
                    interpolator = scipy.interpolate.interp1d(self.cols[axis][inds_good], col[inds_good], assume_sorted=True, bounds_error=bounds_error, fill_value=fill_value)
                    col[inds_bad] = interpolator(self.cols[axis][inds_bad])
                elif method == 'spline':
                    tck = scipy.interpolate.splrep(self.cols[axis][inds_good], col[inds_good], k=k, s=s)
                    col[inds_bad] = scipy.interpolate.splev(self.cols[axis][inds_bad], tck)

    def Fill_nan(self, axis=0, inds_fill=None, method='spline', bounds_error=False, fill_value=np.nan, k=1, s=0, extrapolate=True):
        """
        Fill the empty grid cells (marked as np.nan) with interpolated values
        along a given axis (i.e. interpolation is done in 1D).

        Each requested pixel is interpolated from the other finite values on
        its 1D line along ``axis``; the pixel itself is always excluded from
        the fit. Pixels whose line has too few usable values for the chosen
        method (fewer than ``k + 1`` for 'spline', fewer than 2 otherwise)
        are set to NaN. All values are computed from the grid as it was
        before the call and then written in place.

        Parameters
        ----------
        axis : interpolate
            Axis along which the interpolation should be performed.
        inds_fill : tuple(ndarray)
            Tuple/list containing the list of pixels to interpolate for, one
            index array per grid dimension (as returned by ``nonzero``).
            Defaults to all NaN pixels.
        method : str
            Interpolation method to use. Possible choices are 'spline',
            'interp1d' and 'pchip'.
            'spline' allows for the use of optional keywords k (the order) and
            s (the smoothing parameter). See scipy.interpolate.splrep.
            'interp1d' allows for the use of optional keywords bounds_error and
            fill_value. See scipy.interpolate.interp1d.
            'pchip' allows to interpolate out of bound, or set NaNs.
        bounds_error : bool
            Whether to raise an error when attempting to extrapolate out of
            bounds.
            Only works with 'interp1d'.
        fill_value : float
            Value to use when bounds_error is False.
            Only works with 'interp1d'.
        k : int
            Order of the spline to use. We recommend 1.
            Only works with 'spline'.
        s : int
            Smoothing parameter for the spline. We recommend 0 (exact
            interpolation).
            Only works with 'spline'.
        extrapolate : bool
            Whether to extrapolate out of bound or set NaNs.
            Only works with 'pchip'.

        Examples
        ----------
          Examples::
            atmo.Fill_nan(axis=0, method='interp1d', bounds_error=False, fill_value=np.nan)

        This would fill in the value that are not out of bound with a linear fit. Values
        out of bound would be np.nan.

            atmo.Fill_nan(axis=0, method='spline', k=1, s=0)

        This would produce exactly the same interpolation as above, except that values
        out of bound would be extrapolated.

        Notes
        ----------
        From our experience, it is recommended to first fill the values within
        the bounds using 'interp1d' with bounds_error=False and fill_value=np.nan,
        and then use 'spline' with k=1 and s=1 in order to extrapolate outside
        the bounds. To interpolate within the bounds, the temperature axis
        (i.e. 0) is generally best and more smooth, whereas the logg axis (i.e. 1)
        works better to extrapolate outside.


          Examples::
            atmo.Fill_nan(axis=0, method='interp1d', bounds_error=False, fill_value=np.nan)
            atmo.Fill_nan(axis=1, method='spline', k=1, s=1)
        """
        if method not in ['interp1d','spline','pchip']:
            raise Exception('Wrong method input! Must be either interp1d, spline or pchip.')
        if inds_fill is None:
            inds_fill = np.isnan(self.data).nonzero()
        else:
            assert len(inds_fill) == self.ndim, "The shape must be (ndim, nfill)."
        # Element-wise (fancy) assignment requires a tuple of index arrays; a
        # list would be interpreted differently by newer numpy.
        inds_fill = tuple(inds_fill)

        # splrep needs more than k points; interp1d and PchipInterpolator need 2.
        min_good = k + 1 if method == 'spline' else 2

        x = self.cols[axis]
        vals_fill = []
        for inds_fill_ in zip(*inds_fill):
            # Select the 1D line through this pixel along ``axis``.
            inds = list(inds_fill_)
            inds[axis] = slice(None)
            y = self.data[tuple(inds)]
            pos = inds_fill_[axis]
            x_interp = x[pos]
            # Exclude NaNs and the pixel being filled itself from the fit. The
            # pixel must be masked by its integer position; the previous code
            # indexed with the axis *value* ``x_interp``, which is a float.
            inds_bad = np.isnan(y)
            inds_bad[pos] = True
            inds_good = ~inds_bad
            if np.count_nonzero(inds_good) < min_good:
                vals_fill.append(np.nan)
                continue
            if method == 'interp1d':
                interpolator = scipy.interpolate.interp1d(x[inds_good], y[inds_good], assume_sorted=True, bounds_error=bounds_error, fill_value=fill_value)
                y_interp = interpolator(x_interp)
            elif method == 'spline':
                tck = scipy.interpolate.splrep(x[inds_good], y[inds_good], k=k, s=s)
                y_interp = scipy.interpolate.splev(x_interp, tck)
            else:  # method == 'pchip'
                interpolator = scipy.interpolate.PchipInterpolator(x[inds_good], y[inds_good], axis=0, extrapolate=extrapolate)
                y_interp = interpolator(x_interp)
            vals_fill.append(y_interp)
        self.data[inds_fill] = vals_fill

    def Getaxispos(self, colname, x):
        """
        Return the index and weight of the linear interpolation of the point
        along a given axis.

        Delegates to ``Utils.Series.Getaxispos_vector`` when ``x`` is a list,
        tuple or ndarray, and to ``Utils.Series.Getaxispos_scalar`` otherwise.

        Parameters
        ----------
        colname : str
            Name of the axis to interpolate from.
        x : float, ndarray
            Value to interpolate at.

        Examples
        ----------
          Examples::
            temp = atmo.Getaxispos('logtemp', np.log(3550.))
            logg = atmo.Getaxispos('logg', [4.11,4.13,4.02])
        """
        if isinstance(x, (list, tuple, np.ndarray)):
            return Utils.Series.Getaxispos_vector(self.cols[colname], x)
        else:
            return Utils.Series.Getaxispos_scalar(self.cols[colname], x)

    @property
    def IsFinite(self):
        """Integer array (1 where finite, 0 otherwise) with the grid's shape."""
        return np.isfinite(self.data).astype(int)

    def Pprint(self, slices):
        """
        Print a 2-dimensional slice of the atmosphere grid for visualisation.

        Axis values label the rows (first sliced axis) and the columns
        (second sliced axis); a 'logtemp' axis is shown as exp(logtemp).

        Parameters
        ----------
        slices : list
            List of sliceable elements to extract the 2-dim slice to display.
            ``None`` is treated as ``slice(None)``.

        Examples
        ----------
          Examples::
            # Display the equivalent of atmo[:,:,4]
            atmo.Pprint([None,None,4])
            # Same as above but using fancier slice objects
            atmo.Pprint([slice(None),slice(None),4])
            # Display the equivalent of atmo[3:9,3,:]
            atmo.Pprint([slice(3,9),3,None])
        """
        slices = list(slices)
        labels = []
        for i,s in enumerate(slices):
            if s is None:
                s = slice(None)
                slices[i] = s
            if isinstance(s, (int, np.integer, slice)):
                tmp_label = self.cols[i][s]
                if self.colnames[i] == 'logtemp':
                    tmp_label = np.exp(tmp_label)
                # Axes reduced to a single value are not used as labels.
                if tmp_label.size > 1:
                    labels.append(tmp_label)
            else:
                raise Exception("The element {} is not a slice or integer or cannot be converted to a sliceable entity.".format(s))
        if len(labels) != 2:
            raise Exception("The slices should generate a 2 dimensional array. Verify your input slices.")
        # Multi-axis indexing must use a tuple; numpy no longer accepts a list.
        t = Table(data=self.__getitem__(tuple(slices)), names=labels[1].astype(str), copy=True)
        t.add_column(Column(data=labels[0]), index=0)
        t.pprint()

    @classmethod
    def ReadHDF5(cls, fln):
        """
        Read an atmosphere grid from an HDF5 file (the layout written by WriteHDF5).

        The file must contain a 'flux' dataset, a 'cols' group holding one
        dataset per axis, and the file attributes 'colnames', 'name' and
        'description'. All remaining file attributes become ``meta``; each
        axis dataset's attributes become that Column's ``meta``.
        """
        try:
            import h5py
        except ImportError:
            raise Exception("h5py is needed for ReadHDF5")

        # The context manager closes the file even if reading fails.
        with h5py.File(fln, 'r') as f:
            # ``dset[()]`` reads the full dataset; ``Dataset.value`` was removed in h5py 3.0.
            flux = f['flux'][()]

            meta = dict(f.attrs.items())
            colnames = meta.pop('colnames')
            name = meta.pop('name')
            description = meta.pop('description')

            cols = []
            grp = f['cols']
            for col in colnames:
                dset = grp[col]
                cols.append(Column(data=dset[()], name=col, meta=dict(dset.attrs.items())))
            cols = TableColumns(cols)

        return cls(data=flux, name=name, description=description, meta=meta, cols=cols)

    def SubGrid(self, *args):
        """
        Return a sub-grid of the atmosphere grid.

        Integer arguments keep the selected position as a length-1 axis, so
        the sub-grid always has the same number of dimensions as the grid.
        The new grid gets a deep copy of ``meta``.

        Parameters
        ----------
        slices : slice
            Slice/sliceable object for each dimension of the atmosphere grid.

        Examples
        ----------
          Examples::
            This would extract atmo[:,1:4,:]
            new_atmo = atmo.SubGrid(slice(None),slice(1,4),slice(None))
        """
        assert len(args) == self.ndim, "The number of slices must match the dimension of the atmosphere grid."
        slices = []
        for s in args:
            if isinstance(s, (int, np.integer)):
                # ``s + 1 or None`` makes -1 select the last element instead of
                # producing the empty slice(-1, 0).
                slices.append(slice(s, s + 1 or None))
            else:
                slices.append(s)
        # Multi-axis indexing must use a tuple; numpy no longer accepts a list.
        data = self.data[tuple(slices)]
        cols = [(c, np.atleast_1d(self.cols[c][s])) for c, s in zip(self.cols, slices)]
        return self.__class__(name=self.name, data=data, unit=self.unit, format=self.format, description=self.description, meta=deepcopy(self.meta), cols=cols)

    def Trim(self, colname, low=None, high=None):
        """
        Return a copy of the atmosphere grid whose 'colname' axis has been
        trimmed at the 'low' and 'high' values: low <= colvalues <= high.

        Parameters
        ----------
        colname : str
            Name of the column to trim the grid on.
        low : float
            Lowest value to cut from. If None, will use the minimum value.
        high: float
            Highest value to cut from. If None, will use the maximum value.

        Examples
        ----------
          Examples::
          The following would trim along the temperature axis and keep values
          between 4000 and 6000, inclusively.
            new_atmo = atmo.Trim('logtemp', low=np.log(4000.), high=np.log(6000.))
        """
        if colname not in self.colnames:
            raise Exception("The provided column name is not valid.")
        colind = self.colnames.index(colname)
        axis_vals = self.cols[colname]
        if low is None:
            low = axis_vals.min()
        if high is None:
            high = axis_vals.max()
        keep = np.asarray(np.logical_and(axis_vals >= low, axis_vals <= high))
        inds = [slice(None)]*self.ndim
        inds[colind] = keep
        # Multi-axis indexing must use a tuple; numpy no longer accepts a list.
        data = self.data[tuple(inds)].copy()
        # Rebuild the axes as (name, values) pairs, as SubGrid does, instead of
        # assigning into a copied TableColumns (newer astropy refuses to
        # replace an existing key there).
        cols = [(key, col[keep] if key == colname else col) for key, col in self.cols.items()]
        meta = deepcopy(self.meta)
        return self.__class__(name=self.name, data=data, unit=self.unit, format=self.format, description=self.description, meta=meta, cols=cols)

    def WriteHDF5(self, fln, overwrite=False):
        """
        Write the grid to an HDF5 file in the layout read by ReadHDF5.

        Raises
        ------
        IOError
            If ``fln`` already exists and ``overwrite`` is False.
        """
        try:
            import h5py
        except ImportError:
            raise Exception("h5py is needed for WriteHDF5")

        if os.path.exists(fln):
            if overwrite:
                os.remove(fln)
            else:
                raise IOError("File exists: {}".format(fln))

        # The context manager closes the file even if a write fails.
        with h5py.File(fln, 'w') as f:
            f.create_dataset(name='flux', data=self.data)
            f.attrs['colnames'] = list(self.cols.keys())
            # NOTE(review): if name/description can be None, h5py may refuse to
            # store them as attributes — confirm with callers.
            f.attrs['name'] = self.name
            f.attrs['description'] = self.description

            # ``items()`` works on Python 2 and 3; ``iteritems()`` is Python 2 only.
            for key_attrs, val_attrs in self.meta.items():
                f.attrs[key_attrs] = val_attrs

            grp = f.create_group('cols')
            for key, val in self.cols.items():
                dset = grp.create_dataset(name=key, data=val)
                if hasattr(val, 'meta'):
                    for key_attrs, val_attrs in val.meta.items():
                        dset.attrs[key_attrs] = val_attrs
Esempio n. 6
0
class AtmoGrid(Column):
    """
    Define the base atmosphere grid structure.

    AtmoGrid contains utilities to trim the grid, read/write to HDF5 format.

    Parameters
    ----------
    data : ndarray
        Grid of log(flux) values (e-base)
    name : str
        Keyword name of the atmosphere grid
    dtype : np.dtype compatible value
        Data type the flux grid
    shape : tuple or ()
        Dimensions of a single row element in the flux grid
    length : int or 0
        Number of row elements in the grid
    description : str or None
        Full description of the atmosphere grid
    unit : str or None
        Physical unit
    format : str or None or function or callable
        Format string for outputting column values.  This can be an
        "old-style" (``format % value``) or "new-style" (`str.format`)
        format specification string or a function or any callable object that
        accepts a single value and returns a string.
    meta : dict-like or None
        Meta-data associated with the atmosphere grid
    cols : OrderedDict-like, list of Columns, list of lists/tuples
        Full definition of the flux grid axes. This can be a list of entries
        ('colname', ndarray) with the ndarray corresponding to the axis values
        or a list of Columns containing this information.

    Examples
    --------
    A AtmoGrid can be created like this:

      Examples::

        logtemp = np.log(np.arange(3000.,10001.,250.))
        logg = np.arange(2.0, 5.6, 0.5)
        mu = np.arange(0.,1.01,0.02)
        logflux = np.random.normal(size=(logtemp.size,logg.size,mu.size))
        atmo = AtmoGrid(data=logflux, cols=[('logtemp',logtemp), ('logg',logg), ('mu',mu)])

    To read/save a file:

        atmo = AtmoGridPhot.ReadHDF5('vband.h5')
        atmo.WriteHDF5('vband_new.h5')

    Notes
    --------------
    Note that in principle the axis and data could be any format. However, we recommend using
    log(flux) and log(temperature) because the linear interpolation of such a grid would make
    more sense (say, from the blackbody $F \propto sigma T^4$).

    """
    def __new__(cls, data=None, name=None, dtype=None, shape=(), length=0, description=None, unit=None, format=None, meta=None, cols=None):
        """
        Create the grid and attach its axis definitions as ``self.cols``.

        ``cols`` may be None (axis ``i`` is then named ``str(i)`` with float
        values ``0 .. n_i - 1``), a TableColumns instance (used as-is), or a
        sequence whose entries are Columns or ``(name, values)`` lists/tuples.

        Raises
        ------
        ValueError
            If the number of axes differs from ``self.ndim``, if ``cols``
            cannot be converted into a TableColumns, or if the axis lengths
            do not match the shape of the data grid.
        """
        self = super(AtmoGrid, cls).__new__(cls, data=data, name=name, dtype=dtype, shape=shape, length=length, description=description, unit=unit, format=format, meta=meta)

        if cols is None:
            self.cols = TableColumns([Column(name=str(i), data=np.arange(self.shape[i], dtype=float)) for i in range(self.ndim)])
            return self

        if len(cols) != self.ndim:
            raise ValueError('cols must contain a number of elements equal to the dimension of the data grid.')

        if isinstance(cols, TableColumns):
            self.cols = cols
        else:
            try:
                self.cols = TableColumns([Column(name=col[0], data=col[1]) if isinstance(col, (list, tuple)) else col for col in cols])
            except Exception:
                # Narrowed from a bare ``except:`` so that KeyboardInterrupt and
                # SystemExit are no longer turned into a ValueError.
                raise ValueError('Cannot make a TableColumns out of the provided cols parameter.')

        # ``values()`` works on Python 2 and 3; ``itervalues()`` is Python 2 only.
        shape = tuple(col.size for col in self.cols.values())
        if self.shape != shape:
            raise ValueError('The dimension of the data grid and the cols are not matching.')
        return self

    def __copy__(self):
        """Return a shallow copy: the new grid views, not copies, the flux data."""
        shallow = self.copy(copy_data=False)
        return shallow

    def __deepcopy__(self, memo=None):
        """
        Return a copy of the grid with its flux data copied.

        ``copy.deepcopy`` calls ``__deepcopy__(memo)``; the previous signature
        without ``memo`` made ``deepcopy(atmo)`` raise TypeError. The memo is
        not used here.
        """
        return self.copy(copy_data=True)

    def __getitem__(self, item):
        """
        Index the grid, or look up an axis by name.

        A non-string key indexes the flux grid as a plain ndarray. A string
        key returns the matching axis column; failing an exact match, it
        tries ``'log' + item`` (returning ``np.exp`` of that axis), then
        ``'log' + item[2:]`` and ``'log10' + item`` (both returning ``10**``
        of that axis). Raises Exception if no lookup matches.
        """
        if not isinstance(item, six.string_types):
            # Bypass Column.__getitem__ and index the raw ndarray.
            return self.view(np.ndarray)[item]

        names = self.colnames
        if item in names:
            return self.cols[item]

        natural_log = 'log' + item
        if natural_log in names:
            return np.exp(self.cols[natural_log])

        stripped = 'log' + item[2:]
        if stripped in names:
            return 10**(self.cols[stripped])

        decimal_log = 'log10' + item
        if decimal_log in names:
            return 10**(self.cols[decimal_log])

        raise Exception('The provided column name is cannot be found.')

    @property
    def colnames(self):
        """
        Axis names, in axis order, as a list.

        ``list(...)`` guarantees an indexable sequence (supporting ``[i]`` and
        ``.index``) regardless of whether ``TableColumns.keys()`` returns a
        list or a Python 3 keys view.
        """
        return list(self.cols.keys())

    def copy(self, order='C', data=None, copy_data=True):
        """
        Return a new instance of the same class built from this grid.

        When ``data`` is given it is used by reference and ``copy_data`` is
        ignored. Otherwise the grid's ndarray view is used, copied in memory
        layout ``order`` when ``copy_data`` is True. ``meta`` is deep-copied;
        ``cols`` is passed on as the same TableColumns object.
        """
        if data is None:
            raw = self.view(np.ndarray)
            data = raw.copy(order) if copy_data else raw

        return self.__class__(name=self.name, data=data, unit=self.unit, format=self.format, description=self.description, meta=deepcopy(self.meta), cols=self.cols)

    def Fill_nan_old(self, axis=0, method='spline', bounds_error=False, fill_value=np.nan, k=1, s=1):
        """
        Fill the empty grid cells (marked as np.nan) with interpolated values
        along a given axis (i.e. interpolation is done in 1D).

        The grid is modified in place.

        Parameters
        ----------
        axis : interpolate
            Axis along which the interpolation should be performed.
        method : str
            Interpolation method to use. Possible choices are 'spline' and
            'interp1d'.
            'spline' allows for the use of optional keywords k (the order) and
            s (the smoothing parameter). See scipy.interpolate.splrep.
            'interp1d' allows for the use of optional keywords bounds_error and
            fill_value. See scipy.interpolate.interp1d.
        bounds_error : bool
            Whether to raise an error when attempting to extrapolate out of
            bounds. Only works with 'interp1d'.
        fill_value : float
            Value to use when bounds_error is False. Only works with 'interp1d'.
        k : int
            Order of the spline to use. We recommend 1. Only works with
            'spline'.
        s : int
            Smoothing parameter for the spline. We recommend 0 (exact
            interpolation), or 1. Only works with 'spline'.

        Examples
        ----------
          Examples::
            atmo.Fill_nan(axis=0, method='interp1d', bounds_error=False, fill_value=np.nan)

        This would fill in the value that are not out of bound with a linear fit. Values
        out of bound would be np.nan.

            atmo.Fill_nan(axis=0, method='spline', k=1, s=0)

        This would produce exactly the same interpolation as above, except that values
        out of bound would be extrapolated.

        Notes
        ----------
        From our experience, it is recommended to first fill the values within
        the bounds using 'interp1d' with bounds_error=False and fill_value=np.nan,
        and then use 'spline' with k=1 and s=1 in order to extrapolate outside
        the bounds. To interpolate within the bounds, the temperature axis
        (i.e. 0) is generally best and more smooth, whereas the logg axis (i.e. 1)
        works better to extrapolate outside.


          Examples::
            atmo.Fill_nan(axis=0, method='interp1d', bounds_error=False, fill_value=np.nan)
            atmo.Fill_nan(axis=1, method='spline', k=1, s=1)
        """
        if method not in ['interp1d','spline']:
            raise Exception('Wrong method input! Must be either interp1d or spline.')
        # Enumerate every index combination of the other axes, then insert a
        # full slice at ``axis`` so that each ``ind`` selects one 1D line.
        ndim = list(self.shape)
        ndim.pop(axis)
        inds_tmp = np.indices(ndim)
        inds = [ind.flatten() for ind in inds_tmp]
        niter = len(inds[0])
        inds.insert(axis, [slice(None)]*niter)
        for ind in zip(*inds):
            # A tuple index returns an ndarray view, so the assignments below
            # write directly into the grid.
            col = self.__getitem__(ind)
            inds_good = np.isfinite(col)
            inds_bad = ~inds_good
            if np.any(inds_bad):
                if method == 'interp1d':
                    interpolator = scipy.interpolate.interp1d(self.cols[axis][inds_good], col[inds_good], assume_sorted=True, bounds_error=bounds_error, fill_value=fill_value)
                    col[inds_bad] = interpolator(self.cols[axis][inds_bad])
                elif method == 'spline':
                    tck = scipy.interpolate.splrep(self.cols[axis][inds_good], col[inds_good], k=k, s=s)
                    col[inds_bad] = scipy.interpolate.splev(self.cols[axis][inds_bad], tck)

    def Fill_nan(self, axis=0, inds_fill=None, method='spline', bounds_error=False, fill_value=np.nan, k=1, s=0, extrapolate=True):
        """
        Fill the empty grid cells (marked as np.nan) with interpolated values
        along a given axis (i.e. interpolation is done in 1D).

        For every cell to fill, the 1D line of the grid passing through that
        cell along `axis` is extracted. Its non-NaN values, excluding the
        target cell itself, are used as interpolation nodes. All values are
        computed first and written back at the end, so cells filled during
        this call are never used as nodes for one another. A cell whose line
        has no usable node is set to np.nan.

        Parameters
        ----------
        axis : int
            Axis along which the interpolation should be performed.
        inds_fill : tuple(ndarray)
            Tuple/list of `ndim` index arrays (as returned by
            ndarray.nonzero) selecting the cells to interpolate. If None, all
            NaN cells of the grid are filled.
        method : str
            Interpolation method to use. Possible choices are 'spline',
            'interp1d' and 'pchip'.
            'spline' allows for the use of optional keywords k (the order) and
            s (the smoothing parameter). See scipy.interpolate.splrep.
            'interp1d' allows for the use of optional keywords bounds_error and
            fill_value. See scipy.interpolate.interp1d.
            'pchip' allows to interpolate out of bound, or set NaNs.
        bounds_error : bool
            Whether to raise an error when attempting to extrapolate out of
            bounds.
            Only works with 'interp1d'.
        fill_value : float
            Value to use when bounds_error is False.
            Only works with 'interp1d'.
        k : int
            Order of the spline to use. We recommend 1.
            Only works with 'spline'.
        s : int
            Smoothing parameter for the spline. We recommend 0 (exact
            interpolation).
            Only works with 'spline'.
        extrapolate : bool
            Whether to extrapolate out of bound or set NaNs.
            Only works with 'pchip'.

        Raises
        ------
        Exception
            If `method` is not 'interp1d', 'spline' or 'pchip'.
        AssertionError
            If `inds_fill` does not contain one index array per dimension.

        Examples
        ----------
          Examples::
            atmo.Fill_nan(axis=0, method='interp1d', bounds_error=False, fill_value=np.nan)

        This would fill in the value that are not out of bound with a linear fit. Values
        out of bound would be np.nan.

            atmo.Fill_nan(axis=0, method='spline', k=1, s=0)

        This would produce exactly the same interpolation as above, except that values
        out of bound would be extrapolated.

        Notes
        ----------
        From our experience, it is recommended to first fill the values within
        the bounds using 'interp1d' with bounds_error=False and fill_value=np.nan,
        and then use 'spline' with k=1 and s=1 in order to extrapolate outside
        the bounds. To interpolate within the bounds, the temperature axis
        (i.e. 0) is generally best and more smooth, whereas the logg axis (i.e. 1)
        works better to extrapolate outside.

          Examples::
            atmo.Fill_nan(axis=0, method='interp1d', bounds_error=False, fill_value=np.nan)
            atmo.Fill_nan(axis=1, method='spline', k=1, s=1)
        """
        if method not in ['interp1d', 'spline', 'pchip']:
            raise Exception('Wrong method input! Must be either interp1d, spline or pchip.')
        if inds_fill is None:
            inds_fill = np.isnan(self.data).nonzero()
        else:
            assert len(inds_fill) == self.ndim, "The shape must be (ndim, nfill)."
        # numpy only treats a *tuple* of index arrays as "one index array per
        # dimension"; a list would be interpreted differently (or rejected).
        inds_fill = tuple(inds_fill)

        x = self.cols[axis]
        vals_fill = []
        for cell in zip(*inds_fill):
            # Build the index of the 1D line through `cell` along `axis`.
            line = list(cell)
            line[axis] = slice(None)
            y = self.data[tuple(line)]
            pos = cell[axis]
            x_interp = x[pos]
            # Nodes are the non-NaN points of the line. The target cell is
            # excluded through its integer position along the axis, not its
            # (floating-point) axis value.
            inds_bad = np.isnan(y)
            inds_bad[pos] = True
            inds_good = ~inds_bad
            if not np.any(inds_good):
                vals_fill.append(np.nan)
                continue
            x_good = x[inds_good]
            y_good = y[inds_good]
            if method == 'interp1d':
                interpolator = scipy.interpolate.interp1d(x_good, y_good, assume_sorted=True, bounds_error=bounds_error, fill_value=fill_value)
                y_interp = interpolator(x_interp)
            elif method == 'spline':
                tck = scipy.interpolate.splrep(x_good, y_good, k=k, s=s)
                y_interp = scipy.interpolate.splev(x_interp, tck)
            else:  # 'pchip'
                interpolator = scipy.interpolate.PchipInterpolator(x_good, y_good, axis=0, extrapolate=extrapolate)
                y_interp = interpolator(x_interp)
            vals_fill.append(y_interp)
        self.data[inds_fill] = vals_fill

    def Getaxispos(self, colname, x):
        """
        Locate `x` along the axis `colname` for linear interpolation.

        Sequences (list, tuple or ndarray) are handed to
        Utils.Series.Getaxispos_vector; any other value is handed to
        Utils.Series.Getaxispos_scalar. The result is returned unchanged; the
        original description of that result is "the index and weight of the
        linear interpolation of the point along a given axis".

        Parameters
        ----------
        colname : str
            Name of the axis to interpolate from.
        x : float, list, tuple or ndarray
            Value(s) to interpolate at.

        Examples
        ----------
          Examples::
            temp = atmo.Getaxispos('logtemp', np.log(3550.))
            logg = atmo.Getaxispos('logg', [4.11,4.13,4.02])
        """
        series = Utils.Series
        if isinstance(x, (list, tuple, np.ndarray)):
            locate = series.Getaxispos_vector
        else:
            locate = series.Getaxispos_scalar
        return locate(self.cols[colname], x)

    @property
    def IsFinite(self):
        """
        Integer mask of the grid data: 1 where the value is finite, 0 where
        it is NaN or infinite.
        """
        finite_mask = np.isfinite(self.data)
        return finite_mask.astype(int)

    def Pprint(self, slices):
        """
        Print a 2-dimensional slice of the atmosphere grid for visualisation.

        The axes that keep more than one value after slicing become the table
        dimensions. The values of the first such axis are shown as the first
        column, and the values of the second one become the column names.
        'logtemp' axis values are converted with np.exp before display.

        Parameters
        ----------
        slices : list
            List of sliceable elements (None, integers or slice objects) to
            extract the 2-dim slice to display. None means the full axis.

        Raises
        ------
        Exception
            If an element is not None, an integer or a slice, or if the
            slicing does not produce a 2-dimensional array.

        Examples
        ----------
          Examples::
            # Display the equivalent of atmo[:,:,4]
            atmo.Pprint([None,None,4])
            # Same as above but using fancier slice objects
            atmo.Pprint([slice(None),slice(None),4])
            # Display the equivalent of atmo[3:9,3,:]
            atmo.Pprint([slice(3,9),3,None])
        """
        slices = list(slices)
        labels = []
        for i, s in enumerate(slices):
            if s is None:
                s = slice(None)
                slices[i] = s
            if not isinstance(s, (int, np.integer, slice)):
                raise Exception("The element {} is not a slice or integer or cannot be converted to a sliceable entity.".format(s))
            tmp_label = self.cols[i][s]
            if self.colnames[i] == 'logtemp':
                tmp_label = np.exp(tmp_label)
            # Axes reduced to a single value are not table dimensions.
            if np.size(tmp_label) > 1:
                labels.append(tmp_label)
        if len(labels) != 2:
            raise Exception("The slices should generate a 2 dimensional array. Verify your input slices.")
        # numpy requires a tuple (not a list) for multidimensional indexing.
        t = Table(data=self.__getitem__(tuple(slices)), names=labels[1].astype(str), copy=True)
        t.add_column(Column(data=labels[0]), index=0)
        t.pprint()

    @classmethod
    def ReadHDF5(cls, fln):
        """
        Load an atmosphere grid from an HDF5 file (the layout WriteHDF5 uses).

        The file must contain a 'flux' dataset, a 'cols' group with one
        dataset per axis, and root attributes 'colnames', 'name' and
        'description'. All other root attributes become the grid's meta, and
        each axis dataset's attributes become that column's meta.

        Parameters
        ----------
        fln : str
            Path of the HDF5 file.

        Returns
        -------
        A new instance of this class.

        Raises
        ------
        Exception
            If h5py is not installed.
        """
        try:
            import h5py
        except ImportError:
            raise Exception("h5py is needed for ReadHDF5")

        # The context manager closes the file even if reading fails.
        with h5py.File(fln, 'r') as f:
            # `dset[()]` reads the whole dataset (`Dataset.value` was removed
            # in h5py 3).
            flux = np.ascontiguousarray(f['flux'][()], dtype=float)

            meta = dict(f.attrs.items())
            # Fixed-length string attributes come back as bytes.
            colnames = [c.decode('utf-8') if isinstance(c, bytes) else c
                        for c in meta.pop('colnames')]
            name = meta.pop('name')
            description = meta.pop('description')

            cols = []
            grp = f['cols']
            for col in colnames:
                dset = grp[col]
                cols.append(Column(data=np.ascontiguousarray(dset[()]), name=col, meta=dict(dset.attrs.items())))
        cols = TableColumns(cols)

        return cls(data=flux, name=name, description=description, meta=meta, cols=cols)

    def SubGrid(self, *args):
        """
        Return a sub-grid of the atmosphere grid.

        Parameters
        ----------
        *args : int or slice
            One slice/sliceable object per dimension of the atmosphere grid.
            An integer ``i`` (Python or numpy) selects the single position
            ``i``, but the dimension is kept with length 1. Negative integers
            count from the end, as usual.

        Returns
        -------
        A new instance of the same class holding the sliced data and the
        matching sliced axis values. Name, unit, format, description and meta
        are passed through unchanged.

        Raises
        ------
        AssertionError
            If the number of slices differs from the grid dimension.

        Examples
        ----------
          Examples::
            # This would extract atmo[:,1:4,:]
            new_atmo = atmo.SubGrid(slice(None),slice(1,4),slice(None))
        """
        assert len(args) == self.ndim, "The number of slices must match the dimension of the atmosphere grid."
        slices = []
        for s in args:
            if isinstance(s, (int, np.integer)):
                # `(s + 1) or None` maps s == -1 to slice(-1, None) instead of
                # the empty slice(-1, 0).
                slices.append(slice(s, (s + 1) or None))
            else:
                slices.append(s)
        # numpy requires a tuple (not a list) for multidimensional indexing.
        slices = tuple(slices)
        data = self.data[slices]
        cols = [(c, np.atleast_1d(self.cols[c][s])) for c, s in zip(self.cols, slices)]
        return self.__class__(name=self.name, data=data, unit=self.unit, format=self.format, description=self.description, meta=self.meta, cols=cols)

    def Trim(self, colname, low=None, high=None):
        """
        Return a copy of the atmosphere grid whose 'colname' axis has been
        trimmed at the 'low' and 'high' values: low <= colvalues <= high.

        The data and meta are copied (meta deeply), so the result does not
        share them with the original grid.

        Parameters
        ----------
        colname : str
            Name of the column to trim the grid on.
        low : float
            Lowest value to cut from. If None, will use the minimum value.
        high: float
            Highest value to cut from. If None, will use the maximum value.

        Raises
        ------
        Exception
            If `colname` is not one of the grid's column names.

        Examples
        ----------
          Examples::
          The following would trim along the temperature axis and keep values
          between 4000 and 6000, inclusively.
            new_atmo = atmo.Trim('logtemp', low=np.log(4000.), high=np.log(6000.))
        """
        if colname not in self.colnames:
            raise Exception("The provided column name is not valid.")
        axis_values = self.cols[colname]
        if low is None:
            low = axis_values.min()
        if high is None:
            high = axis_values.max()
        slices = [slice(None)] * self.ndim
        # list() so this also works if colnames is a keys view or a tuple.
        colind = list(self.colnames).index(colname)
        slices[colind] = np.logical_and(axis_values >= low, axis_values <= high)
        # numpy requires a tuple (not a list) for multidimensional indexing.
        slices = tuple(slices)
        cols = [(c, np.atleast_1d(self.cols[c][s])) for c, s in zip(self.cols, slices)]
        data = self.data[slices].copy()
        meta = deepcopy(self.meta)
        return self.__class__(name=self.name, data=data, unit=self.unit, format=self.format, description=self.description, meta=meta, cols=cols)

    def WriteHDF5(self, fln, overwrite=False):
        """
        Save the atmosphere grid to an HDF5 file.

        Layout:
        - the grid data goes in a 'flux' dataset;
        - the column names, name and description go in root attributes
          'colnames', 'name' and 'description';
        - every meta entry becomes a root attribute;
        - every axis goes in its own dataset under a 'cols' group, with the
          column's meta (if any) stored as that dataset's attributes.

        Parameters
        ----------
        fln : str
            Path of the HDF5 file to create.
        overwrite : bool
            Whether to delete and replace an existing file.

        Raises
        ------
        Exception
            If h5py is not installed.
        IOError
            If the file exists and `overwrite` is False.
        """
        try:
            import h5py
        except ImportError:
            raise Exception("h5py is needed for WriteHDF5")

        if os.path.exists(fln):
            if overwrite:
                os.remove(fln)
            else:
                raise IOError("File exists: {}".format(fln))

        # The context manager closes the file even if a write fails.
        with h5py.File(fln, 'w') as f:
            f.create_dataset(name='flux', data=self.data)
            # h5py cannot store a dict keys view or a numpy unicode array as an
            # attribute, so store the names as a fixed-length bytes array.
            f.attrs['colnames'] = np.array(list(self.cols.keys()), dtype='S')
            f.attrs['name'] = self.name
            f.attrs['description'] = self.description

            for key_attrs, val_attrs in self.meta.items():
                f.attrs[key_attrs] = val_attrs

            grp = f.create_group('cols')
            for key, val in self.cols.items():
                dset = grp.create_dataset(name=key, data=val)
                if hasattr(val, 'meta'):
                    for key_attrs, val_attrs in val.meta.items():
                        dset.attrs[key_attrs] = val_attrs
Esempio n. 7
0
def fit(planfile,
        model_name=None,
        spectrum_filenames=None,
        threads=8,
        clobber=True,
        from_filename=False,
        fit_velocity=False,
        chunk_size=1000,
        output_suffix=None,
        **kwargs):
    """
    Fit a series of spectra.

    The spectra of every field listed in the ASPCAP section of a yanny
    planfile are fitted with a trained Cannon model.

    Settings read from the planfile:
    - the apred/apstar/aspcap/results versions;
    - ``ncpus`` (default '16'), which overrides the ``threads`` argument;
    - ``cannon_vers``;
    - the logg/teff/mh/alpha ranges used to select stars from the allStar
      FPARAM values.
    ``model_name`` and ``output_suffix`` are read from the planfile only when
    they are None.

    For each field:
    - every full batch of ``chunk_size`` spectra is fitted and its
      (result, cov, meta) tuples are pickled to ``<input>-<suffix>.pkl``;
    - the remaining spectra are fitted at the end of the field. Each finite
      result is written to its own FITS file and added to a cannonField
      summary FITS table.

    Parameters
    ----------
    planfile : str
        Path of the yanny planfile.
    model_name : str, optional
        Base name of the ``.model`` / ``.initial`` files in the Cannon
        directory.
    spectrum_filenames : ignored
        Kept for backward compatibility; the file list comes from the
        planfile.
    threads : int
        Ignored: the planfile ``ncpus`` value is used instead.
    clobber : bool
        Whether to overwrite existing outputs.
    from_filename : ignored
        Kept for backward compatibility.
    fit_velocity : bool
        Passed to ``model.fit`` as ``model_redshift``.
    chunk_size : int
        Number of spectra fitted together before writing pickle outputs.
    output_suffix : str, optional
        Suffix appended to the output file names.
    **kwargs
        Ignored.

    Returns
    -------
    None
    """
    p = yanny.yanny(planfile, np=True)
    apred = p['apred_vers'].strip("'")
    apstar = p['apstar_vers'].strip("'")
    aspcap = p['aspcap_vers'].strip("'")
    # Not called `results`: that name holds the fit results below, which used
    # to overwrite this version string before the next field was processed.
    results_vers = p['results_vers'].strip("'")
    threads = int(getval(p, 'ncpus', '16'))
    cannon = getval(p, 'cannon_vers', 'cannon_aspcap')
    if model_name is None:
        model_name = getval(p, 'model_name', 'apogee-dr14-giants')
    if output_suffix is None:
        output_suffix = getval(p, 'output_suffix', 'result')
    logg = getrange(getval(p, 'logg', '-1 3.9'))
    teff = getrange(getval(p, 'teff', '3500 5500'))
    mh = getrange(getval(p, 'mh', '-3. 1.'))
    alpha = getrange(getval(p, 'alpha', '-0.5 1.'))

    root = os.environ[
        'APOGEE_REDUX'] + '/' + apred + '/' + apstar + '/' + aspcap + '/' + results_vers + '/' + cannon + '/'
    model = tc.load_model(os.path.join(root, "{}.model".format(model_name)),
                          threads=threads)
    assert model.is_trained
    label_names = model.vectorizer.label_names
    # Read the .initial table once. 'col0' gives the starting labels and
    # 'col2' the floor applied to the reported label errors.
    initial_table = Table.read(os.path.join(root,
                                            "{}.initial".format(model_name)),
                               format='ascii')
    mean_labels = initial_table['col0']
    sig_labels = initial_table['col2']

    logger = logging.getLogger("AnniesLasso")

    # get allStar file for initial labels
    apl = apload.ApLoad(apred=apred,
                        apstar=apstar,
                        aspcap=aspcap,
                        results=results_vers)
    allstar = apl.allStar()[1].data
    fparam = allstar['FPARAM']

    # Keys removed from the fit metadata, to save space.
    delete_meta_keys = ("fjac", )
    # Fit metadata entries copied into the summary table.
    meta_keys = ['chi_sq', 'r_chi_sq', 'model_flux']
    # Output of the `idlwrap_version` command, fetched on first use.
    idlwrap_version = None

    # loop over fields in planfile
    for field in p['ASPCAP']['field']:
        # Per-spectrum lists: they are always appended to and cleared together
        # so that they stay aligned with each other and with the fit results.
        metadatas = []
        fluxes = []
        ivars = []
        output_filenames = []
        apogee_ids = []
        # Stars selected in this field.
        apogee_names = []
        spectrum_filenames = []
        failures = 0

        # get file names to fit
        try:
            paths = getfiles(apred, apstar, aspcap, results_vers, cannon, field)
        except Exception:
            # Skip only this field. A bare `return` here used to silently
            # abandon every remaining field of the planfile.
            logger.exception("Could not get the files of field {}".format(field))
            continue

        for apogee_id, inpath, outpath in paths:
            # only take stars within certain parameter ranges
            print(apogee_id)
            j = np.where(((allstar['REDUCTION_ID'] == apogee_id)
                          | (allstar['APOGEE_ID'] == apogee_id))
                         & (allstar['COMMISS'] == 0))[0]
            if len(j) == 0:
                print('missing target', apogee_id)
                continue
            if len(j) > 1:
                j = j[0]
            # FPARAM columns used: 0 = teff, 1 = logg, 3 = mh, 6 = alpha.
            if ((fparam[j, 1] >= logg[0]) & (fparam[j, 1] <= logg[1]) &
                    (fparam[j, 0] >= teff[0]) & (fparam[j, 0] <= teff[1]) &
                    (fparam[j, 3] >= mh[0]) & (fparam[j, 3] <= mh[1]) &
                    (fparam[j, 6] >= alpha[0]) & (fparam[j, 6] <= alpha[1])):
                spectrum_filenames.append(outpath)
                apogee_names.append(apogee_id)

        if len(apogee_names) == 0:
            # Nothing selected here: move on to the next field instead of
            # returning from the whole function.
            logger.info("No target selected in field {}".format(field))
            continue

        # Every star starts from the same mean labels.
        initial_labels = mean_labels

        summary_file = root + field + '/cannonField-' + os.path.basename(
            field) + '-' + output_suffix + '.fits'
        N = len(spectrum_filenames)
        for i, (apogee_id, filename) in enumerate(zip(apogee_names, spectrum_filenames)):
            logger.info("At spectrum {0}/{1}: {2}".format(i + 1, N, filename))

            basename, _ = os.path.splitext(filename)
            output_filename = "-".join([basename, output_suffix]) + ".pkl"

            if os.path.exists(output_filename) and not clobber:
                logger.info("Output filename {} already exists and not clobbering."\
                    .format(output_filename))
                continue

            try:
                # NOTE(review): pickle.load can execute arbitrary code; only
                # use this on trusted pipeline products.
                with open(filename, "rb") as fp:
                    metadata, data = pickle.load(fp)
                flux, ivar = data
            except Exception:
                logger.exception("Error occurred loading {}".format(filename))
                failures += 1
                continue

            # Append only after a complete load so the lists stay aligned.
            metadatas.append(metadata)
            fluxes.append(flux)
            ivars.append(ivar)
            output_filenames.append(output_filename)
            apogee_ids.append(apogee_id)

            if len(output_filenames) >= chunk_size:

                results, covs, metas = model.fit(
                    fluxes,
                    ivars,
                    initial_labels=initial_labels,
                    model_redshift=fit_velocity,
                    full_output=True)

                for result, cov, meta, chunk_filename \
                in zip(results, covs, metas, output_filenames):

                    for key in delete_meta_keys:
                        if key in meta:
                            del meta[key]

                    with open(chunk_filename, "wb") as fp:
                        pickle.dump((result, cov, meta), fp,
                                    2)  # For legacy.
                    logger.info(
                        "Saved output to {}".format(chunk_filename))

                # Clear *all* per-spectrum lists. Clearing only some of them
                # used to pair later results with the wrong stars.
                # NOTE(review): spectra fitted in this branch only get pickle
                # outputs, not FITS files or summary-table rows.
                del output_filenames[:], fluxes[:], ivars[:], metadatas[:], apogee_ids[:]

        # Reset per field so a previous field's table is never rewritten here.
        data_dict = None
        if len(output_filenames) > 0:

            results, covs, metas = model.fit(fluxes,
                                             ivars,
                                             initial_labels=initial_labels,
                                             model_redshift=fit_velocity,
                                             full_output=True)

            # Create an ordered dictionary of lists for all the data. The key
            # order must match the order in which `outlist` is built below.
            data_dict = OrderedDict([("FILENAME", [])])
            data_dict['APOGEE_ID'] = []
            data_dict['LOCATION_ID'] = []
            data_dict['FIELD'] = []
            for label_name in label_names:
                data_dict[label_name] = []
            for label_name in label_names:
                data_dict["{}_RAWERR".format(label_name)] = []
            for label_name in label_names:
                data_dict["{}_ERR".format(label_name)] = []
            for key in meta_keys:
                data_dict[key] = []
            data_dict['flux'] = []
            data_dict['ivar'] = []

            # loop over spectra, output individual files, and accumulate for summary file
            for result, cov, meta, output_filename, apogee_id, metadata, flux, ivar \
            in zip(results, covs, metas, output_filenames, apogee_ids, metadatas, fluxes, ivars):

                if not np.isfinite(result).all():
                    continue

                try:
                    rawerr = np.diag(cov)**0.5
                except Exception:
                    # Previously dropped into pdb; skip this star instead.
                    logger.exception(
                        "Could not get label errors for {}".format(apogee_id))
                    continue

                outlist = [
                    os.path.basename(output_filename), apogee_id,
                    metadata['LOCATION_ID'], metadata['FIELD']
                ] + result.tolist()
                outlist.extend(rawerr)
                # Reported errors are floored at the .initial sigma values.
                outlist.extend(np.max([rawerr, sig_labels], axis=0))
                for key in delete_meta_keys:
                    if key in meta:
                        del meta[key]
                outlist += [meta.get(k) for k in meta_keys]
                outlist.append(flux)
                outlist.append(ivar)
                for key, value in zip(data_dict.keys(), outlist):
                    data_dict[key].append(value)

                # save to FITS cannonStar file
                if idlwrap_version is None:
                    # universal_newlines=True returns text, so strip('\n')
                    # works on both Python 2 and 3.
                    idlwrap_version = subprocess.check_output(
                        'idlwrap_version', universal_newlines=True).strip('\n')
                hdr = fits.Header()
                hdr['HISTORY'] = 'IDLWRAP_VERSION: ' + idlwrap_version
                hdr['OBJ'] = apogee_id
                hdr['LOCID'] = metadata['LOCATION_ID']
                hdr['FIELD'] = metadata['FIELD']
                hdr['CHI2'] = meta.get('r_chi_sq')
                for label_index, label_name in enumerate(label_names):
                    hdr[label_name] = result[label_index]
                hdulist = fits.HDUList(fits.PrimaryHDU(header=hdr))
                hdr = fits.Header()
                hdr['OBSERVER'] = 'Edwin Hubble'
                hdr['CRVAL1'] = 4.179e0
                hdr['CDELT1'] = 6.e-6
                hdr['CRPIX1'] = 1
                hdr['CTYPE1'] = 'LOG-LINEAR'
                hdr['DC-FLAG'] = 1
                hdulist.append(fits.ImageHDU(flux, header=hdr))
                hdulist.append(
                    fits.ImageHDU(1. / np.sqrt(ivar), header=hdr))
                hdulist.append(
                    fits.ImageHDU(meta.get('model_flux'), header=hdr))
                hdulist.writeto(output_filename.replace(
                    '-result', '').replace('.pkl', '.fits'),
                                overwrite=True)

            del output_filenames[:], fluxes[:], ivars[:], metadatas[:], apogee_ids[:]

        logger.info("Number of failures: {}".format(failures))
        logger.info("Number of successes: {}".format(N - failures))
        if data_dict is None:
            # No spectrum was left for the final fit (e.g. all outputs existed
            # and clobber=False), so there is no summary table to write.
            logger.info("No summary written for field {}".format(field))
            continue
        summary_path = summary_file.replace('-result', '')
        table = Table(TableColumns(data_dict))
        table.write(summary_path, overwrite=clobber)
        logger.info("Written to {}".format(summary_path))

    return None