Example #1
 def is_numlike(x):
     """
     The Matplotlib datalim, autoscaling, locators etc work with
     scalars which are the units converted to floats given the
     current unit.  The converter may be passed these floats, or
     arrays of them, even when units are set.
     """
     if iterable(x):
         for thisx in x:
             # `iterable` and `is_numlike` are the matplotlib.cbook helpers
             # imported at module level in matplotlib.units (this function is
             # a @staticmethod of ConversionInterface there), so this call
             # does not recurse; only the first element of x is checked.
             return is_numlike(thisx)
     else:
         return is_numlike(x)
Example #3
 def is_numlike(x):
     """
     The matplotlib datalim, autoscaling, locators etc work with
     scalars which are the units converted to floats given the
     current unit.  The converter may be passed these floats, or
     arrays of them, even when units are set.  Derived conversion
     interfaces may opt to pass plain-ol unitless numbers through
     the conversion interface and this is a helper function for
     them.
     """
     if iterable(x):
         for thisx in x:
             # `iterable` and `is_numlike` are the matplotlib.cbook helpers
             # imported at module level in matplotlib.units (this function is
             # a @staticmethod of ConversionInterface there), so this call
             # does not recurse; only the first element of x is checked.
             return is_numlike(thisx)
     else:
         return is_numlike(x)
Example #4
    def _calculate_global(self, data):
        # Calculate breaks if x is not categorical
        binwidth = self.params['binwidth']
        self.breaks = self.params['breaks']
        right = self.params['right']
        x = data['x'].values

        # For categorical data we set labels and x-vals
        if is_categorical(x):
            labels = self.params['labels']
            if labels is None:
                labels = sorted(set(x))
            self.labels = labels
            self.length = len(self.labels)

        # For non-categorical data we set breaks
        if not (is_categorical(x) or self.breaks):
            # Check that x is numerical
            if not cbook.is_numlike(x[0]):
                raise GgplotError("Cannot recognise the type of x")
            if binwidth is None:
                _bin_count = 30
                self._print_warning(_MSG_BINWIDTH)
            else:
                _bin_count = int(np.ceil(np.ptp(x) / binwidth))
            _, self.breaks = pd.cut(x, bins=_bin_count, labels=False,
                                        right=right, retbins=True)
            self.length = len(self.breaks)
Example #5
File: vmc.py Project: EPFL-LQM/gpvmc
def ScanDir(folder='.',keys=[],pattern=r".*\.h5",return_dict=False,req={}):
    out={}
    for f in os.listdir(folder):
        if re.match(pattern,f) is not None:
            try:
                isreq=(len(req)==0)
                if not isreq:
                    isreq=True
                    fd=GetAttr("{0}/{1}".format(folder,f))
                    for k in req.keys():
                        try:
                            if is_numlike(req[k]):
                                isreq=isreq and (abs(req[k]-fd[k])<1e-9)
                            else:
                                isreq=isreq and (req[k]==fd[k])
                        except KeyError:
                            isreq=False
                if isreq:
                    out[folder+'/'+f]=dict(GetAttr("{0}/{1}".format(folder,f)))
                    s=f
                    if len(keys):
                        s="{0}: ".format(f)
                        if keys=='*':
                            keys=out[folder+'/'+f].keys()
                        for k in keys:
                            try:
                                s="{0} {1}:{2} /".format(s,k,out[folder+'/'+f][k])
                            except KeyError:
                                s="{0} None /".format(s)
                    print(s)
            except IOError:
                print('Could not open \"'+f+'\".')
    if return_dict:
        return out
Example #6
 def getname_val(identifier):
     'return the name and column data for identifier'
     if is_string_like(identifier):
         return identifier, r[identifier]
     elif is_numlike(identifier):
         name = r.dtype.names[int(identifier)]
         return name, r[name]
     else:
         raise TypeError('identifier must be a string or integer')
Example #7
File: layer.py Project: jwhendy/plotnine
def is_known_scalar(value):
    """
    Return True if value is a type we expect in a dataframe
    """
    def _is_datetime_or_timedelta(value):
        # Using pandas.Series helps catch python, numpy and pandas
        # versions of these types
        return pd.Series(value).dtype.kind in ('M', 'm')

    return not cbook.iterable(value) and (cbook.is_numlike(value) or
                                          _is_datetime_or_timedelta(value))
Example #9
 def getname_val(identifier):
     'return the name and column data for identifier'
     if is_string_like(identifier):
         print "Identifier " + identifier + " is a string"
         col_name = identifier.strip().lower().replace(' ', '_')
         col_name = ''.join([c for c in col_name if c not in delete])
         return identifier, r[col_name]
     elif is_numlike(identifier):
         name = r.dtype.names[int(identifier)]
         return name, r[name]
     else:
         raise TypeError('identifier must be a string or integer')
Example #10
File: dataset.py Project: awacha/sastool
 def _convert_numcompatible(self, c):
     """Convert c to a form usable by arithmetic operations"""
     #the compatible dataset to be returned, initialize it to zeros.
     comp = {'x':np.zeros_like(self._x),
           'y':np.zeros_like(self._x),
           'dy':np.zeros_like(self._x),
           'dx':np.zeros_like(self._x)}
     # if c is a DataSet:
     if isinstance(c, AliasedVectorAttributes):
         if self.shape() != c.shape(): # they are of incompatible size, fail.
             raise ValueError('incompatible length')
         # if the size of them is compatible, check if the abscissae are
         # compatible.
         xtol = min(self._xtolerance, c._xtolerance) # use the strictest
         if max(np.abs(self._x - c._x)) < xtol:
             try:
                 comp['x'] = c._x
                 comp['y'] = c._y
                 comp['dy'] = c._dy
                 comp['dx'] = c._dx
             except AttributeError:
                 pass # this is not a fatal error
         else:
             raise ValueError('incompatible abscissae')
     elif isinstance(c, ErrorValue):
         comp['x'] = self._x
         comp['y'] += c.val
         comp['dy'] += c.err
     elif isinstance(c, tuple): # if c is a tuple
         try:
             #the fields of comp were initialized to zero np arrays!
             comp['x'] += c[0]
             comp['y'] += c[1]
             comp['dy'] += c[2]
             comp['dx'] += c[3]
         except IndexError:
             pass # this is not fatal either
     else:
         if is_numlike(c):
             try:
                 comp['x'] = self._x
                 comp['y'] += c # leave this job to numpy.ndarray.__iadd__()
             except Exception:
                 raise DataSetError('Incompatible size')
         else:
             raise DataSetError('Incompatible type')
     return comp
Example #11
def from_any(size, fraction_ref=None):
    """
    Creates a Fixed unit when the first argument is a float, or a
    Fraction unit when it is a string ending with "%". The second
    argument is only meaningful when a Fraction unit is created.

      >>> a = Size.from_any(1.2) # => Size.Fixed(1.2)
      >>> Size.from_any("50%", a) # => Size.Fraction(0.5, a)

    """
    if cbook.is_numlike(size):
        return Fixed(size)
    elif cbook.is_string_like(size):
        if size[-1] == "%":
            return Fraction(float(size[:-1])/100., fraction_ref)

    raise ValueError("Unknown format")
Example #12
def from_any(size, fraction_ref=None):
    """
    Creates a Fixed unit when the first argument is a float, or a
    Fraction unit when it is a string ending with "%". The second
    argument is only meaningful when a Fraction unit is created.::

      >>> a = Size.from_any(1.2) # => Size.Fixed(1.2)
      >>> Size.from_any("50%", a) # => Size.Fraction(0.5, a)

    """
    if cbook.is_numlike(size):
        return Fixed(size)
    elif isinstance(size, six.string_types):
        if size[-1] == "%":
            return Fraction(float(size[:-1]) / 100, fraction_ref)

    raise ValueError("Unknown format")
Example #13
    def _calculate_global(self, data):
        # Calculate breaks if x is not categorical
        binwidth = self.params['binwidth']
        self.breaks = self.params['breaks']
        right = self.params['right']
        x = data['x'].values

        # For categorical data we set labels and x-vals
        if is_categorical(x):
            labels = self.params['labels']
            if labels is None:
                labels = sorted(set(x))
            self.labels = labels
            self.length = len(self.labels)

        # For non-categorical data we set breaks
        if not (is_categorical(x) or self.breaks):
            # Check that x is numerical
            # datetime.datetime is a subclass of datetime.date, so test for
            # the more specific type first
            if len(x) > 0 and isinstance(x[0], datetime.datetime):
                x = [time.mktime(d.timetuple()) for d in x]
            elif len(x) > 0 and isinstance(x[0], datetime.date):

                def convert(d):
                    d = datetime.datetime.combine(d,
                                                  datetime.datetime.min.time())
                    return time.mktime(d.timetuple())

                x = [convert(d) for d in x]
            elif len(x) > 0 and isinstance(x[0], datetime.time):
                raise GgplotError("Cannot recognise the type of x")
            elif not cbook.is_numlike(x[0]):
                raise GgplotError("Cannot recognise the type of x")
            if binwidth is None:
                _bin_count = 30
                self._print_warning(_MSG_BINWIDTH)
            else:
                _bin_count = int(np.ceil(np.ptp(x) / binwidth))
            _, self.breaks = pd.cut(x,
                                    bins=_bin_count,
                                    labels=False,
                                    right=right,
                                    retbins=True)
            self.length = len(self.breaks)
Example #14
File: dataset.py Project: awacha/sastool
 def _convert_numcompatible(self, c):
     """Convert c to a form usable by arithmetic operations"""
     #the compatible dataset to be returned, initialize it to zeros.
     if isinstance(c, MatrixAttrMixin):
         if self.shape() != c.shape(): # they are of incompatible size, fail.
             raise ValueError('incompatible shape')
         # if the size of them is compatible, check if the abscissae are
         # compatible.
         return (c._A, c._dA)
     elif isinstance(c, ErrorValue):
         return (c.val, c.err)
     elif isinstance(c, tuple): # if c is a tuple
         try:
             return (c[0], c[1])
         except IndexError:
             return (c[0], 0)
     else:
         if is_numlike(c):
             return c, 0
     raise DataSetError('Incompatible type')
Example #15
def ScanDir(folder='.', keys=[], pattern=r".*\.h5", return_dict=False, req={}):
    out = {}
    for f in os.listdir(folder):
        if re.match(pattern, f) is not None:
            try:
                isreq = (len(req) == 0)
                if not isreq:
                    isreq = True
                    fd = GetAttr("{0}/{1}".format(folder, f))
                    for k in req.keys():
                        try:
                            if is_numlike(req[k]):
                                isreq = isreq and (abs(req[k] - fd[k]) < 1e-9)
                            else:
                                isreq = isreq and (req[k] == fd[k])
                        except KeyError:
                            isreq = False
                if isreq:
                    out[folder + '/' + f] = dict(
                        GetAttr("{0}/{1}".format(folder, f)))
                    s = f
                    if len(keys):
                        s = "{0}: ".format(f)
                        if keys == '*':
                            keys = out[folder + '/' + f].keys()
                        for k in keys:
                            try:
                                s = "{0} {1}:{2} /".format(
                                    s, k, out[folder + '/' + f][k])
                            except KeyError:
                                s = "{0} None /".format(s)
                    print(s)
            except IOError:
                print('Could not open \"' + f + '\".')
    if return_dict:
        return out
Example #16
def hist(self,
         x,
         bins=10,
         range=None,
         normed=False,
         weights=None,
         cumulative=False,
         bottom=None,
         histtype='bar',
         align='mid',
         orientation='vertical',
         rwidth=None,
         log=False,
         color=None,
         label=None,
         **kwargs):
    """
    call signature::

      hist(x, bins=10, range=None, normed=False, cumulative=False,
           bottom=None, histtype='bar', align='mid',
           orientation='vertical', rwidth=None, log=False, **kwargs)

    Compute and draw the histogram of *x*. The return value is a
    tuple (*n*, *bins*, *patches*) or ([*n0*, *n1*, ...], *bins*,
    [*patches0*, *patches1*,...]) if the input contains multiple
    data.

    Multiple data can be provided via *x* as a list of datasets
    of potentially different length ([*x0*, *x1*, ...]), or as
    a 2-D ndarray in which each column is a dataset.  Note that
    the ndarray form is transposed relative to the list form.

    Masked arrays are not supported at present.

    Keyword arguments:

      *bins*:
        Either an integer number of bins or a sequence giving the
        bins.  If *bins* is an integer, *bins* + 1 bin edges
        will be returned, consistent with :func:`numpy.histogram`
        for numpy version >= 1.3, and with the *new* = True argument
        in earlier versions.
        Unequally spaced bins are supported if *bins* is a sequence.

      *range*:
        The lower and upper range of the bins. Lower and upper outliers
        are ignored. If not provided, *range* is (x.min(), x.max()).
        Range has no effect if *bins* is a sequence.

        If *bins* is a sequence or *range* is specified, autoscaling
        is based on the specified bin range instead of the
        range of x.

      *normed*:
        If *True*, the first element of the return tuple will
        be the counts normalized to form a probability density, i.e.,
        ``n/(len(x)*dbin)``.  In a probability density, the integral of
        the histogram should be 1; you can verify that with a
        trapezoidal integration of the probability density function::

          pdf, bins, patches = ax.hist(...)
          print(np.sum(pdf * np.diff(bins)))

        .. Note:: Until numpy release 1.5, the underlying numpy
                  histogram function was incorrect with *normed*=*True*
                  if bin sizes were unequal.  MPL inherited that
                  error.  It is now corrected within MPL when using
                  earlier numpy versions

      *weights*
        An array of weights, of the same shape as *x*.  Each value in
        *x* only contributes its associated weight towards the bin
        count (instead of 1).  If *normed* is True, the weights are
        normalized, so that the integral of the density over the range
        remains 1.

      *cumulative*:
        If *True*, then a histogram is computed where each bin
        gives the counts in that bin plus all bins for smaller values.
        The last bin gives the total number of datapoints.  If *normed*
        is also *True* then the histogram is normalized such that the
        last bin equals 1. If *cumulative* evaluates to less than 0
        (e.g. -1), the direction of accumulation is reversed.  In this
        case, if *normed* is also *True*, then the histogram is normalized
        such that the first bin equals 1.

      *histtype*: [ 'bar' | 'barstacked' | 'step' | 'stepfilled' ]
        The type of histogram to draw.

          - 'bar' is a traditional bar-type histogram.  If multiple data
            are given the bars are arranged side by side.

          - 'barstacked' is a bar-type histogram where multiple
            data are stacked on top of each other.

          - 'step' generates a lineplot that is by default
            unfilled.

          - 'stepfilled' generates a lineplot that is by default
            filled.

      *align*: ['left' | 'mid' | 'right' ]
        Controls how the histogram is plotted.

          - 'left': bars are centered on the left bin edges.

          - 'mid': bars are centered between the bin edges.

          - 'right': bars are centered on the right bin edges.

      *orientation*: [ 'horizontal' | 'vertical' ]
        If 'horizontal', :func:`~matplotlib.pyplot.barh` will be
        used for bar-type histograms and the *bottom* kwarg will be
        the left edges.

      *rwidth*:
        The relative width of the bars as a fraction of the bin
        width.  If *None*, automatically compute the width. Ignored
        if *histtype* = 'step' or 'stepfilled'.

      *log*:
        If *True*, the histogram axis will be set to a log scale.
        If *log* is *True* and *x* is a 1D array, empty bins will
        be filtered out and only the non-empty (*n*, *bins*,
        *patches*) will be returned.

      *color*:
        Color spec or sequence of color specs, one per
        dataset.  Default (*None*) uses the standard line
        color sequence.

      *label*:
        String, or sequence of strings to match multiple
        datasets.  Bar charts yield multiple patches per
        dataset, but only the first gets the label, so
        that the legend command will work as expected::

            ax.hist(10+2*np.random.randn(1000), label='men')
            ax.hist(12+3*np.random.randn(1000), label='women', alpha=0.5)
            ax.legend()

    kwargs are used to update the properties of the
    :class:`~matplotlib.patches.Patch` instances returned by *hist*:

    %(Patch)s

    **Example:**

    .. plot:: mpl_examples/pylab_examples/histogram_demo.py
    """
    if not self._hold: self.cla()

    # NOTE: the range keyword overwrites the built-in func range !!!
    #       needs to be fixed in numpy                           !!!

    # Validate string inputs here so we don't have to clutter
    # subsequent code.
    if histtype not in ['bar', 'barstacked', 'step', 'stepfilled']:
        raise ValueError("histtype %s is not recognized" % histtype)

    if align not in ['left', 'mid', 'right']:
        raise ValueError("align kwarg %s is not recognized" % align)

    if orientation not in ['horizontal', 'vertical']:
        raise ValueError("orientation kwarg %s is not recognized" %
                         orientation)

    if kwargs.get('width') is not None:
        raise DeprecationWarning(
            'hist now uses the rwidth to give relative width '
            'and not absolute width')

    # Massage 'x' for processing.
    # NOTE: Be sure any changes here is also done below to 'weights'
    if isinstance(x, np.ndarray) or not iterable(x[0]):
        # TODO: support masked arrays;
        x = np.asarray(x)
        if x.ndim == 2:
            x = x.T  # 2-D input with columns as datasets; switch to rows
        elif x.ndim == 1:
            x = x.reshape(1, x.shape[0])  # new view, single row
        else:
            raise ValueError("x must be 1D or 2D")
        if x.shape[1] < x.shape[0]:
            warnings.warn('2D hist input should be nsamples x nvariables;\n '
                          'this looks transposed (shape is %d x %d)' %
                          x.shape[::-1])
    else:
        # multiple hist with data of different length
        x = [np.array(xi) for xi in x]

    nx = len(x)  # number of datasets

    if color is None:
        color = [self._get_lines.color_cycle.next() for i in xrange(nx)]
    else:
        color = mcolors.colorConverter.to_rgba_array(color)
        if len(color) != nx:
            raise ValueError("color kwarg must have one color per dataset")

    # We need to do to 'weights' what was done to 'x'
    if weights is not None:
        if isinstance(weights, np.ndarray) or not iterable(weights[0]):
            w = np.array(weights)
            if w.ndim == 2:
                w = w.T
            elif w.ndim == 1:
                w.shape = (1, w.shape[0])
            else:
                raise ValueError("weights must be 1D or 2D")
        else:
            w = [np.array(wi) for wi in weights]

        if len(w) != nx:
            raise ValueError('weights should have the same shape as x')
        for i in xrange(nx):
            if len(w[i]) != len(x[i]):
                raise ValueError('weights should have the same shape as x')
    else:
        w = [None] * nx

    # Save autoscale state for later restoration; turn autoscaling
    # off so we can do it all a single time at the end, instead
    # of having it done by bar or fill and then having to be redone.
    _saved_autoscalex = self.get_autoscalex_on()
    _saved_autoscaley = self.get_autoscaley_on()
    self.set_autoscalex_on(False)
    self.set_autoscaley_on(False)

    # Save the datalimits for the same reason:
    _saved_bounds = self.dataLim.bounds

    # Check whether bins or range are given explicitly. In that
    # case use those values for autoscaling.
    binsgiven = (cbook.iterable(bins) or range is not None)

    # If bins are not specified either explicitly or via range,
    # we need to figure out the range required for all datasets,
    # and supply that to np.histogram.
    if not binsgiven:
        xmin = np.inf
        xmax = -np.inf
        for xi in x:
            xmin = min(xmin, xi.min())
            xmax = max(xmax, xi.max())
        range = (xmin, xmax)

    #hist_kwargs = dict(range=range, normed=bool(normed))
    # We will handle the normed kwarg within mpl until we
    # get to the point of requiring numpy >= 1.5.
    hist_kwargs = dict(range=range)
    if np.__version__ < "1.3":  # version 1.1 and 1.2
        hist_kwargs['new'] = True

    n = []
    for i in xrange(nx):
        # this will automatically overwrite bins,
        # so that each histogram uses the same bins
        m, bins = np.histogram(x[i], bins, weights=w[i], **hist_kwargs)
        if normed:
            db = np.diff(bins)
            m = (m.astype(float) / db) / m.sum()
        n.append(m)
    if normed and db.std() > 0.01 * db.mean():
        warnings.warn("""
        This release fixes a normalization bug in the NumPy histogram
        function prior to version 1.5, occurring with non-uniform
        bin widths. The returned and plotted value is now a density:
            n / (N * bin width),
        where n is the bin count and N the total number of points.
        """)

    if cumulative:
        slc = slice(None)
        if cbook.is_numlike(cumulative) and cumulative < 0:
            slc = slice(None, None, -1)

        if normed:
            n = [(m * np.diff(bins))[slc].cumsum()[slc] for m in n]
        else:
            n = [m[slc].cumsum()[slc] for m in n]

    patches = []

    if histtype.startswith('bar'):
        totwidth = np.diff(bins)

        if rwidth is not None:
            dr = min(1.0, max(0.0, rwidth))
        elif len(n) > 1:
            dr = 0.8
        else:
            dr = 1.0

        if histtype == 'bar':
            width = dr * totwidth / nx
            dw = width

            if nx > 1:
                boffset = -0.5 * dr * totwidth * (1.0 - 1.0 / nx)
            else:
                boffset = 0.0
            stacked = False
        elif histtype == 'barstacked':
            width = dr * totwidth
            boffset, dw = 0.0, 0.0
            stacked = True

        if align == 'mid' or align == 'edge':
            boffset += 0.5 * totwidth
        elif align == 'right':
            boffset += totwidth

        if orientation == 'horizontal':
            _barfunc = self.barh
        else:  # orientation == 'vertical'
            _barfunc = self.bar

        for m, c in zip(n, color):
            patch = _barfunc(bins[:-1] + boffset,
                             m,
                             width,
                             bottom,
                             align='center',
                             log=log,
                             color=c)
            patches.append(patch)
            if stacked:
                if bottom is None:
                    bottom = 0.0
                bottom += m
            boffset += dw

    elif histtype.startswith('step'):
        x = np.zeros(2 * len(bins), np.float)
        y = np.zeros(2 * len(bins), np.float)

        x[0::2], x[1::2] = bins, bins

        # FIX FIX FIX
        # This is the only real change.
        # minimum = min(bins)
        if log is True:
            minimum = 1.0
        elif log:
            minimum = float(log)
        else:
            minimum = 0.0
        # FIX FIX FIX end

        if align == 'left' or align == 'center':
            x -= 0.5 * (bins[1] - bins[0])
        elif align == 'right':
            x += 0.5 * (bins[1] - bins[0])

        if log:
            y[0], y[-1] = minimum, minimum
            if orientation == 'horizontal':
                self.set_xscale('log')
            else:  # orientation == 'vertical'
                self.set_yscale('log')

        fill = (histtype == 'stepfilled')

        for m, c in zip(n, color):
            y[1:-1:2], y[2::2] = m, m
            if log:
                y[y < minimum] = minimum
            if orientation == 'horizontal':
                x, y = y, x

            if fill:
                patches.append(self.fill(x, y, closed=False, facecolor=c))
            else:
                patches.append(
                    self.fill(x, y, closed=False, edgecolor=c, fill=False))

        # adopted from adjust_x/ylim part of the bar method
        if orientation == 'horizontal':
            xmin0 = max(_saved_bounds[0] * 0.9, minimum)
            xmax = self.dataLim.intervalx[1]
            for m in n:
                xmin = np.amin(m[m != 0])  # filter out the 0 height bins
            xmin = max(xmin * 0.9, minimum)
            xmin = min(xmin0, xmin)
            self.dataLim.intervalx = (xmin, xmax)
        elif orientation == 'vertical':
            ymin0 = max(_saved_bounds[1] * 0.9, minimum)
            ymax = self.dataLim.intervaly[1]
            for m in n:
                ymin = np.amin(m[m != 0])  # filter out the 0 height bins
            ymin = max(ymin * 0.9, minimum)
            ymin = min(ymin0, ymin)
            self.dataLim.intervaly = (ymin, ymax)

    if label is None:
        labels = ['_nolegend_']
    elif is_string_like(label):
        labels = [label]
    elif is_sequence_of_strings(label):
        labels = list(label)
    else:
        raise ValueError(
            'invalid label: must be string or sequence of strings')
    if len(labels) < nx:
        labels += ['_nolegend_'] * (nx - len(labels))

    for (patch, lbl) in zip(patches, labels):
        for p in patch:
            p.update(kwargs)
            p.set_label(lbl)
            lbl = '_nolegend_'

    if binsgiven:
        if orientation == 'vertical':
            self.update_datalim([(bins[0], 0), (bins[-1], 0)], updatey=False)
        else:
            self.update_datalim([(0, bins[0]), (0, bins[-1])], updatex=False)

    self.set_autoscalex_on(_saved_autoscalex)
    self.set_autoscaley_on(_saved_autoscaley)
    self.autoscale_view()

    if nx == 1:
        return n[0], bins, cbook.silent_list('Patch', patches[0])
    else:
        return n, bins, cbook.silent_list('Lists of Patches', patches)
Example #17
def draw_networkx_edges(G, pos,
                        edgelist=None,
                        width=1.0,
                        edge_color='k',
                        style='solid',
                        alpha=None,
                        edge_cmap=None,
                        edge_vmin=None,
                        edge_vmax=None, 
                        ax=None,
                        arrows=True,
                        **kwds):
    """Draw the edges of the graph G

    This draws only the edges of the graph G.

    pos is a dictionary keyed by vertex with a two-tuple
    of x-y positions as the value.
    See networkx.layout for functions that compute node positions.

    edgelist is an optional list of the edges in G to be drawn.
    If provided, only the edges in edgelist will be drawn. 

    edge_color can be a list of matplotlib color letters such as 'k' or
    'b' that lists the color of each edge; the list must be ordered in
    the same way as the edge list. Alternatively, this list can contain
    numbers and those number are mapped to a color scale using the color
    map edge_cmap.  Finally, it can also be a list of (r,g,b) or (r,g,b,a)
    tuples, in which case these will be used directly to color the edges.  If
    the latter mode is used, you should not provide a value for alpha, as it
    would be applied globally to all lines.
    
    For directed graphs, "arrows" (actually just thicker stubs) are drawn
    at the head end.  Arrows can be turned off with keyword arrows=False.

    See draw_networkx for the list of other optional parameters.

    """
    try:
        import matplotlib
        import matplotlib.pylab as pylab
        import matplotlib.cbook as cb
        from matplotlib.colors import colorConverter,Colormap
        from matplotlib.collections import LineCollection
        import numpy
    except ImportError:
        raise ImportError("Matplotlib required for draw()")
    except RuntimeError:
        pass # unable to open display

    if ax is None:
        ax=pylab.gca()

    if edgelist is None:
        edgelist=G.edges()

    if not edgelist or len(edgelist)==0: # no edges!
        return None

    # set edge positions
    edge_pos=numpy.asarray([(pos[e[0]],pos[e[1]]) for e in edgelist])
    
    if not cb.iterable(width):
        lw = (width,)
    else:
        lw = width

    if not cb.is_string_like(edge_color) \
           and cb.iterable(edge_color) \
           and len(edge_color)==len(edge_pos):
        if numpy.alltrue([cb.is_string_like(c) 
                         for c in edge_color]):
            # (should check ALL elements)
            # list of color letters such as ['k','r','k',...]
            edge_colors = tuple([colorConverter.to_rgba(c,alpha) 
                                 for c in edge_color])
        elif numpy.alltrue([not cb.is_string_like(c) 
                           for c in edge_color]):
            # If color specs are given as (rgb) or (rgba) tuples, we're OK
            if numpy.alltrue([cb.iterable(c) and len(c) in (3,4)
                             for c in edge_color]):
                edge_colors = tuple(edge_color)
            else:
                # numbers (which are going to be mapped with a colormap)
                edge_colors = None
        else:
            raise ValueError('edge_color must consist of either color names or numbers')
    else:
        if cb.is_string_like(edge_color) or len(edge_color)==1:
            edge_colors = ( colorConverter.to_rgba(edge_color, alpha), )
        else:
            raise ValueError('edge_color must be a single color or list of exactly m colors where m is the number of edges')

    edge_collection = LineCollection(edge_pos,
                                     colors       = edge_colors,
                                     linewidths   = lw,
                                     antialiaseds = (1,),
                                     linestyle    = style,     
                                     transOffset = ax.transData,             
                                     )

    # Note: there was a bug in mpl regarding the handling of alpha values for
    # each line in a LineCollection.  It was fixed in matplotlib in r7184 and
    # r7189 (June 6 2009).  We should then not set the alpha value globally,
    # since the user can instead provide per-edge alphas now.  Only set it
    # globally if provided as a scalar.
    if cb.is_numlike(alpha):
        edge_collection.set_alpha(alpha)

    # need 0.87.7 or greater for edge colormaps
    mpl_version=matplotlib.__version__
    if mpl_version.endswith('svn'):
        mpl_version=matplotlib.__version__[0:-3]
    if mpl_version.endswith('pre'):
        mpl_version=matplotlib.__version__[0:-3]
    if map(int,mpl_version.split('.'))>=[0,87,7]:
        if edge_colors is None:
            if edge_cmap is not None: assert(isinstance(edge_cmap, Colormap))
            edge_collection.set_array(numpy.asarray(edge_color))
            edge_collection.set_cmap(edge_cmap)
            if edge_vmin is not None or edge_vmax is not None:
                edge_collection.set_clim(edge_vmin, edge_vmax)
            else:
                edge_collection.autoscale()
            pylab.sci(edge_collection)

#    else:
#        sys.stderr.write(\
#            """matplotlib version >= 0.87.7 required for colormapped edges.
#        (version %s detected)."""%matplotlib.__version__)
#        raise UserWarning(\
#            """matplotlib version >= 0.87.7 required for colormapped edges.
#        (version %s detected)."""%matplotlib.__version__)

    arrow_collection=None

    if G.is_directed() and arrows:

        # a directed graph hack
        # draw thick line segments at head end of edge
        # waiting for someone else to implement arrows that will work 
        arrow_colors = ( colorConverter.to_rgba('k', alpha), )
        a_pos=[]
        p=1.0-0.25 # make head segment 25 percent of edge length
        for src,dst in edge_pos:
            x1,y1=src
            x2,y2=dst
            dx=x2-x1 # x offset
            dy=y2-y1 # y offset
            d=numpy.sqrt(float(dx**2+dy**2)) # length of edge
            if d==0: # source and target at same position
                continue
            if dx==0: # vertical edge
                xa=x2
                ya=dy*p+y1
            elif dy==0: # horizontal edge
                ya=y2
                xa=dx*p+x1
            else:
                theta=numpy.arctan2(dy,dx)
                xa=p*d*numpy.cos(theta)+x1
                ya=p*d*numpy.sin(theta)+y1
                
            a_pos.append(((xa,ya),(x2,y2)))

        arrow_collection = LineCollection(a_pos,
                                colors       = arrow_colors,
                                linewidths   = [4*ww for ww in lw],
                                antialiaseds = (1,),
                                transOffset = ax.transData,             
                                )
        
    # update view        
    minx = numpy.amin(numpy.ravel(edge_pos[:,:,0]))
    maxx = numpy.amax(numpy.ravel(edge_pos[:,:,0]))
    miny = numpy.amin(numpy.ravel(edge_pos[:,:,1]))
    maxy = numpy.amax(numpy.ravel(edge_pos[:,:,1]))

    w = maxx-minx
    h = maxy-miny
    padx, pady = 0.05*w, 0.05*h
    corners = (minx-padx, miny-pady), (maxx+padx, maxy+pady)
    ax.update_datalim( corners)
    ax.autoscale_view()

    edge_collection.set_zorder(1) # edges go behind nodes            
    ax.add_collection(edge_collection)
    if arrow_collection:
        arrow_collection.set_zorder(1) # edges go behind nodes            
        ax.add_collection(arrow_collection)

    return edge_collection
Example #18
def draw_edges(G,
               pos,
               ax,
               edgelist=None,
               width=1.0,
               width_adjuster=50,
               edge_color='k',
               style='solid',
               alpha=None,
               edge_cmap=None,
               edge_vmin=None,
               edge_vmax=None,
               traversal_weight=1.0,
               edge_delengthify=0.15,
               arrows=True,
               label=None,
               zorder=1,
               **kwds):
    """
    Code cleaned-up version of networkx.draw_networkx_edges

    New args:

    width_adjuster - the line width is generated from the weight if present, use this adjuster to thicken the lines (multiply)
    """
    if edgelist is None:
        edgelist = G.edges()

    if not edgelist or len(edgelist) == 0:  # no edges!
        return None

    # set edge positions
    edge_pos = [(pos[e[0]], pos[e[1]]) for e in edgelist]
    new_ep = []
    for e in edge_pos:
        x, y = e[0]
        dx, dy = e[1]

        # Get edge length
        elx = (dx - x) * edge_delengthify
        ely = (dy - y) * edge_delengthify

        x += elx
        y += ely
        dx -= elx
        dy -= ely

        new_ep.append(((x, y), (dx, dy)))
    edge_pos = numpy.asarray(new_ep)

    if not cb.iterable(width):
        #print [G.get_edge_data(n[0], n[1])['weight'] for n in edgelist]
        # see if I can find an edge attribute:
        if 'weight' in G.get_edge_data(edgelist[0][0],
                                       edgelist[0][1]):  # Test an edge
            lw = [
                0.5 +
                ((G.get_edge_data(n[0], n[1])['weight'] - traversal_weight) *
                 width_adjuster) for n in edgelist
            ]
        else:
            lw = (width, )
    else:
        lw = width

    if not is_string_like(edge_color) and cb.iterable(edge_color) and len(
            edge_color) == len(edge_pos):
        if numpy.alltrue([cb.is_string_like(c) for c in edge_color]):
            # (should check ALL elements)
            # list of color letters such as ['k','r','k',...]
            edge_colors = tuple(
                [colorConverter.to_rgba(c, alpha) for c in edge_color])
        elif numpy.alltrue([not cb.is_string_like(c) for c in edge_color]):
            # If color specs are given as (rgb) or (rgba) tuples, we're OK
            if numpy.alltrue(
                [cb.iterable(c) and len(c) in (3, 4) for c in edge_color]):
                edge_colors = tuple(edge_color)
            else:
                # numbers (which are going to be mapped with a colormap)
                edge_colors = None
        else:
            raise ValueError(
                'edge_color must consist of either color names or numbers')
    else:
        if is_string_like(edge_color) or len(edge_color) == 1:
            edge_colors = (colorConverter.to_rgba(edge_color, alpha), )
        else:
            raise ValueError(
                'edge_color must be a single color or list of exactly m colors where m is the number of edges'
            )

    edge_collection = LineCollection(edge_pos,
                                     colors=edge_colors,
                                     linewidths=lw,
                                     antialiaseds=(1, ),
                                     linestyle=style,
                                     transOffset=ax.transData,
                                     zorder=zorder)

    edge_collection.set_label(label)
    ax.add_collection(edge_collection)

    if cb.is_numlike(alpha):
        edge_collection.set_alpha(alpha)

    if edge_colors is None:
        if edge_cmap is not None:
            assert (isinstance(edge_cmap, Colormap))
        edge_collection.set_array(numpy.asarray(edge_color))
        edge_collection.set_cmap(edge_cmap)
        if edge_vmin is not None or edge_vmax is not None:
            edge_collection.set_clim(edge_vmin, edge_vmax)
        else:
            edge_collection.autoscale()

    # update view
    '''
    minx = numpy.amin(numpy.ravel(edge_pos[:,:,0]))
    maxx = numpy.amax(numpy.ravel(edge_pos[:,:,0]))
    miny = numpy.amin(numpy.ravel(edge_pos[:,:,1]))
    maxy = numpy.amax(numpy.ravel(edge_pos[:,:,1]))

    w = maxx-minx
    h = maxy-miny
    padx,  pady = 0.05*w, 0.05*h
    corners = (minx-padx, miny-pady), (maxx+padx, maxy+pady)
    ax.update_datalim(corners)
    ax.autoscale_view()
    '''
    return edge_collection
Example #19
    def __init__(self, fig,
                 rect,
                 nrows_ncols,
                 ngrids = None,
                 direction="row",
                 axes_pad = 0.02,
                 add_all=True,
                 share_all=False,
                 aspect=True,
                 label_mode="L",
                 cbar_mode=None,
                 cbar_location="right",
                 cbar_pad=None,
                 cbar_size="5%",
                 cbar_set_cax=True,
                 axes_class=None,
                 ):
        """
        Build an :class:`ImageGrid` instance with a grid nrows*ncols
        :class:`~matplotlib.axes.Axes` in
        :class:`~matplotlib.figure.Figure` *fig* with
        *rect=[left, bottom, width, height]* (in
        :class:`~matplotlib.figure.Figure` coordinates) or
        the subplot position code (e.g., "121").

        Optional keyword arguments:

          ================  ========  =========================================
          Keyword           Default   Description
          ================  ========  =========================================
          direction         "row"     [ "row" | "column" ]
          axes_pad          0.02      float| pad between axes given in inches
          add_all           True      [ True | False ]
          share_all         False     [ True | False ]
          aspect            True      [ True | False ]
          label_mode        "L"       [ "L" | "1" | "all" ]
          cbar_mode         None      [ "each" | "single" | "edge" ]
          cbar_location     "right"   [ "left" | "right" | "bottom" | "top" ]
          cbar_pad          None
          cbar_size         "5%"
          cbar_set_cax      True      [ True | False ]
          axes_class        None      a type object which must be a subclass
                                      of :class:`~matplotlib.axes.Axes`
          ================  ========  =========================================

        *cbar_set_cax* : if True, each axes in the grid has a cax
          attribute that is bound to the associated cbar_axes.
        """
        self._nrows, self._ncols = nrows_ncols

        if ngrids is None:
            ngrids = self._nrows * self._ncols
        else:
            if (ngrids > self._nrows * self._ncols) or  (ngrids <= 0):
                raise Exception("")

        self.ngrids = ngrids

        self._axes_pad = axes_pad

        self._colorbar_mode = cbar_mode
        self._colorbar_location = cbar_location
        if cbar_pad is None:
            self._colorbar_pad = axes_pad
        else:
            self._colorbar_pad = cbar_pad

        self._colorbar_size = cbar_size

        self._init_axes_pad(axes_pad)

        if direction not in ["column", "row"]:
            raise Exception("")

        self._direction = direction


        if axes_class is None:
            axes_class = self._defaultLocatableAxesClass
            axes_class_args = {}
        else:
            if isinstance(axes_class, maxes.Axes):
                axes_class_args = {}
            else:
                axes_class, axes_class_args = axes_class



        self.axes_all = []
        self.axes_column = [[] for i in range(self._ncols)]
        self.axes_row = [[] for i in range(self._nrows)]

        self.cbar_axes = []

        h = []
        v = []
        if cbook.is_string_like(rect) or cbook.is_numlike(rect):
            self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
                                           aspect=aspect)
        elif isinstance(rect, SubplotSpec):
            self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
                                           aspect=aspect)
        elif len(rect) == 3:
            kw = dict(horizontal=h, vertical=v, aspect=aspect)
            self._divider = SubplotDivider(fig, *rect, **kw)
        elif len(rect) == 4:
            self._divider = Divider(fig, rect, horizontal=h, vertical=v,
                                    aspect=aspect)
        else:
            raise Exception("")


        rect = self._divider.get_position()

        # reference axes
        self._column_refax = [None for i in range(self._ncols)]
        self._row_refax = [None for i in range(self._nrows)]
        self._refax = None

        for i in range(self.ngrids):

            col, row = self._get_col_row(i)

            if share_all:
                sharex = self._refax
                sharey = self._refax
            else:
                sharex = self._column_refax[col]
                sharey = self._row_refax[row]

            ax = axes_class(fig, rect, sharex=sharex, sharey=sharey,
                            **axes_class_args)

            if share_all:
                if self._refax is None:
                    self._refax = ax
            else:
                if sharex is None:
                    self._column_refax[col] = ax
                if sharey is None:
                    self._row_refax[row] = ax

            self.axes_all.append(ax)
            self.axes_column[col].append(ax)
            self.axes_row[row].append(ax)

            cax = self._defaultCbarAxesClass(fig, rect,
                                             orientation=self._colorbar_location)
            self.cbar_axes.append(cax)

        self.axes_llc = self.axes_column[0][-1]

        self._update_locators()

        if add_all:
            for ax in self.axes_all+self.cbar_axes:
                fig.add_axes(ax)

        if cbar_set_cax:
            if self._colorbar_mode == "single":
                for ax in self.axes_all:
                    ax.cax = self.cbar_axes[0]
            else:
                for ax, cax in zip(self.axes_all, self.cbar_axes):
                    ax.cax = cax

        self.set_label_mode(label_mode)
Example #20
    def __init__(self, fig,
                 rect,
                 nrows_ncols,
                 ngrids = None,
                 direction="row",
                 axes_pad = 0.02,
                 add_all=True,
                 share_all=False,
                 share_x=True,
                 share_y=True,
                 #aspect=True,
                 label_mode="L",
                 axes_class=None,
                 ):
        """
        Build an :class:`Grid` instance with a grid nrows*ncols
        :class:`~matplotlib.axes.Axes` in
        :class:`~matplotlib.figure.Figure` *fig* with
        *rect=[left, bottom, width, height]* (in
        :class:`~matplotlib.figure.Figure` coordinates) or
        the subplot position code (e.g., "121").

        Optional keyword arguments:

          ================  ========  =========================================
          Keyword           Default   Description
          ================  ========  =========================================
          direction         "row"     [ "row" | "column" ]
          axes_pad          0.02      float| pad between axes given in inches
          add_all           True      [ True | False ]
          share_all         False     [ True | False ]
          share_x           True      [ True | False ]
          share_y           True      [ True | False ]
          label_mode        "L"       [ "L" | "1" | "all" ]
          axes_class        None      a type object which must be a subclass
                                      of :class:`~matplotlib.axes.Axes`
          ================  ========  =========================================
        """
        self._nrows, self._ncols = nrows_ncols

        if ngrids is None:
            ngrids = self._nrows * self._ncols
        else:
            if (ngrids > self._nrows * self._ncols) or  (ngrids <= 0):
                raise Exception("")

        self.ngrids = ngrids

        self._init_axes_pad(axes_pad)

        if direction not in ["column", "row"]:
            raise Exception("")

        self._direction = direction


        if axes_class is None:
            axes_class = self._defaultLocatableAxesClass
            axes_class_args = {}
        else:
            if (type(axes_class)) == type and \
                   issubclass(axes_class, self._defaultLocatableAxesClass.Axes):
                axes_class_args = {}
            else:
                axes_class, axes_class_args = axes_class

        self.axes_all = []
        self.axes_column = [[] for i in range(self._ncols)]
        self.axes_row = [[] for i in range(self._nrows)]


        h = []
        v = []
        if cbook.is_string_like(rect) or cbook.is_numlike(rect):
            self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
                                           aspect=False)
        elif isinstance(rect, SubplotSpec):
            self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v,
                                           aspect=False)
        elif len(rect) == 3:
            kw = dict(horizontal=h, vertical=v, aspect=False)
            self._divider = SubplotDivider(fig, *rect, **kw)
        elif len(rect) == 4:
            self._divider = Divider(fig, rect, horizontal=h, vertical=v,
                                    aspect=False)
        else:
            raise Exception("")


        rect = self._divider.get_position()

        # reference axes
        self._column_refax = [None for i in range(self._ncols)]
        self._row_refax = [None for i in range(self._nrows)]
        self._refax = None

        for i in range(self.ngrids):

            col, row = self._get_col_row(i)

            if share_all:
                sharex = self._refax
                sharey = self._refax
            else:
                if share_x:
                    sharex = self._column_refax[col]
                else:
                    sharex = None

                if share_y:
                    sharey = self._row_refax[row]
                else:
                    sharey = None

            ax = axes_class(fig, rect, sharex=sharex, sharey=sharey,
                            **axes_class_args)

            if share_all:
                if self._refax is None:
                    self._refax = ax
            else:
                if sharex is None:
                    self._column_refax[col] = ax
                if sharey is None:
                    self._row_refax[row] = ax

            self.axes_all.append(ax)
            self.axes_column[col].append(ax)
            self.axes_row[row].append(ax)

        self.axes_llc = self.axes_column[0][-1]

        self._update_locators()

        if add_all:
            for ax in self.axes_all:
                fig.add_axes(ax)

        self.set_label_mode(label_mode)
Example #21
def draw_nx_tapered_edges(G,
                          pos,
                          edgelist=None,
                          width=0.5,
                          edge_color='k',
                          style='solid',
                          alpha=1.0,
                          edge_cmap=None,
                          edge_vmin=None,
                          edge_vmax=None,
                          ax=None,
                          label=None,
                          highlight=None,
                          tapered=False,
                          **kwds):
    """Draw the edges of the graph G.
    This draws only the edges of the graph G.
    Parameters
    ----------
    G : graph
       A networkx graph
    pos : dictionary
       A dictionary with nodes as keys and positions as values.
       Positions should be sequences of length 2.
    edgelist : collection of edge tuples
       Draw only specified edges(default=G.edges())
    width : float, or array of floats
       Line width of edges (default=1.0)
    edge_color : color string, or array of floats
       Edge color. Can be a single color format string (default='r'),
       or a sequence of colors with the same length as edgelist.
       If numeric values are specified they will be mapped to
       colors using the edge_cmap and edge_vmin,edge_vmax parameters.
    style : string
       Edge line style (default='solid') (solid|dashed|dotted|dashdot)
    alpha : float
       The edge transparency (default=1.0)
    edge_cmap : Matplotlib colormap
       Colormap for mapping intensities of edges (default=None)
    edge_vmin,edge_vmax : floats
       Minimum and maximum for edge colormap scaling (default=None)
    ax : Matplotlib Axes object, optional
       Draw the graph in the specified Matplotlib axes.
    label : [None| string]
       Label for legend
    Returns
    -------
    matplotlib.collection.LineCollection
        `LineCollection` of the edges
    Examples
    --------
    >>> G=nx.dodecahedral_graph()
    >>> edges=nx.draw_networkx_edges(G,pos=nx.spring_layout(G))
    Also see the NetworkX drawing examples at
    http://networkx.github.io/documentation/latest/gallery.html
    See Also
    --------
    draw()
    draw_networkx()
    draw_networkx_nodes()
    draw_networkx_labels()
    draw_networkx_edge_labels()
    """
    if ax is None:
        ax = plt.gca()

    if edgelist is None:
        edgelist = list(G.edges())

    if not edgelist or len(edgelist) == 0:  # no edges!
        return None

    if highlight is not None and (isinstance(edge_color, basestring)
                                  or not cb.iterable(edge_color)):
        idMap = {}
        nodes = G.nodes()
        for i in range(len(nodes)):
            idMap[nodes[i]] = i
        ecol = [edge_color] * len(edgelist)
        eHighlight = [
            highlight[idMap[edge[0]]] or highlight[idMap[edge[1]]]
            for edge in edgelist
        ]
        for i in range(len(eHighlight)):
            if eHighlight[i]:
                ecol[i] = '0.0'
        edge_color = ecol

    # set edge positions
    if not cb.iterable(width):
        lw = np.full(len(edgelist), width)
    else:
        lw = width

    edge_pos = []
    wdScale = 0.01
    for i in range(len(edgelist)):
        e = edgelist[i]
        w = wdScale * lw[i] / 2
        p0 = pos[e[0]]
        p1 = pos[e[1]]
        dx = p1[0] - p0[0]
        dy = p1[1] - p0[1]
        l = math.sqrt(dx * dx + dy * dy)
        edge_pos.append(
            ((p0[0] + w * dy / l, p0[1] - w * dx / l),
             (p0[0] - w * dy / l, p0[1] + w * dx / l), (p1[0], p1[1])))

    edge_vertices = np.asarray(edge_pos)

    if not isinstance(edge_color, basestring) \
           and cb.iterable(edge_color) \
           and len(edge_color) == len(edge_vertices):
        if np.alltrue([isinstance(c, basestring) for c in edge_color]):
            # (should check ALL elements)
            # list of color letters such as ['k','r','k',...]
            edge_colors = tuple(
                [colorConverter.to_rgba(c, alpha) for c in edge_color])
        elif np.alltrue([not isinstance(c, basestring) for c in edge_color]):
            # If color specs are given as (rgb) or (rgba) tuples, we're OK
            if np.alltrue(
                [cb.iterable(c) and len(c) in (3, 4) for c in edge_color]):
                edge_colors = tuple(edge_color)
            else:
                # numbers (which are going to be mapped with a colormap)
                edge_colors = None
        else:
            raise ValueError(
                'edge_color must consist of either color names or numbers')
    else:
        if isinstance(edge_color, basestring) or len(edge_color) == 1:
            edge_colors = (colorConverter.to_rgba(edge_color, alpha), )
        else:
            raise ValueError(
                'edge_color must be a single color or list of exactly m colors where m is the number of edges'
            )

    if tapered:
        edge_collection = PolyCollection(
            edge_vertices,
            facecolors=edge_colors,
            linewidths=0,
            antialiaseds=(1, ),
            transOffset=ax.transData,
        )
    else:
        edge_collection = LineCollection(
            edge_pos,
            colors=edge_colors,
            linewidths=lw,
            antialiaseds=(1, ),
            linestyle=style,
            transOffset=ax.transData,
        )

    edge_collection.set_zorder(1)  # edges go behind nodes
    edge_collection.set_label(label)
    ax.add_collection(edge_collection)

    # Note: there was a bug in mpl regarding the handling of alpha values for
    # each line in a LineCollection.  It was fixed in matplotlib in r7184 and
    # r7189 (June 6 2009).  We should then not set the alpha value globally,
    # since the user can instead provide per-edge alphas now.  Only set it
    # globally if provided as a scalar.
    if cb.is_numlike(alpha):
        edge_collection.set_alpha(alpha)

    if edge_colors is None:
        if edge_cmap is not None:
            assert (isinstance(edge_cmap, Colormap))
        edge_collection.set_array(np.asarray(edge_color))
        edge_collection.set_cmap(edge_cmap)
        if edge_vmin is not None or edge_vmax is not None:
            edge_collection.set_clim(edge_vmin, edge_vmax)
        else:
            edge_collection.autoscale()

    # update view
    minx = np.amin(np.ravel(edge_vertices[:, :, 0]))
    maxx = np.amax(np.ravel(edge_vertices[:, :, 0]))
    miny = np.amin(np.ravel(edge_vertices[:, :, 1]))
    maxy = np.amax(np.ravel(edge_vertices[:, :, 1]))

    w = maxx - minx
    h = maxy - miny
    padx, pady = 0.05 * w, 0.05 * h
    corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
    ax.update_datalim(corners)
    ax.autoscale_view()

    return edge_collection
示例#22
0
    def __init__(self,
                 abu = None,
                 stable = True,
                 xlim = None,
                 ylim = None,
                 y_min = 1e-15,
                 truncate = True,
                 truncate_limit = 1e-23,
                 ax = None,
                 colors = None,
                 linewidth = 1.,
                 markersize = 8.,
                 markerfill = True,
                 markerthick = 1.,
                 fp = None,
                 fontsize = None,
                 pathfont = False,
                 xborder = 0.025,
                 yborder = 0.1,
                 title = None,
                 xtitle = 'mass number',
                 ytitle = 'mass fraction',
                 logy = True,
                 showline = True,
                 showmarker = True,
                 pathmarker = False,
                 align = 'center', # center | first | last
                 showtext = True,
                 stabletext = True,
                 dist = 0.25,
                 norm = None,
                 normtype = 'span',
                 show = None,
                 normrange = 2.):
        """
        Make isotopic abundance plot.

        abu:
            AbuSet instance

        TODO
        - assert ions sorted?
        - Some of the parameters should be renamed for consistency.
        (under development)
        """
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)
            if show is None:
                show = True
        else:
            if show is None:
                show = False

        if colors is None:
            colors=self.colors
        ncolors = len(colors)

        z = abu.Z()
        a = abu.A()
        x = abu.X()

        if stable is not None:
            if stable is True:
                sol = SolAbu()
                stable = sol.contains(abu)
            elif stable is not False:
                assert len(a) == len(stable)
            else:
                stable = None

        if x.min() <= 1.e-15 and logy:
            y_min = 1.e-15
        else:
            y_min = x.min()
        if logy:
            x = np.log10(np.maximum(x, 1e-99))
            y_min = np.log10(y_min)
            truncate_limit = np.log10(truncate_limit)

        step, =  np.where(z[1:] != z[:-1])
        nz = len(step) + 1
        seq = np.zeros(nz + 1, dtype = np.int64)
        seq[1:-1] = step + 1
        seq[-1] = len(z)

        # we need the limits set for alignment, font scaling, etc.
        xlim = self.get_xlim(xlim, xborder, a)

        if ylim is None:
            ylim = np.array([y_min,x.max()])
            ylim += np.array([-1,1]) * yborder * (ylim[1]-ylim[0])
        elif len(ylim) == 1:
            ylim = np.array([ylim, x.max()])
            ylim += np.array([0,1]) * yborder * (ylim[1]-ylim[0])
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_autoscalex_on(False)

        ii = -1
        for iz in range(nz):
            na = seq[iz + 1] - seq[iz]
            az = np.ndarray(na)
            xz = np.ndarray(na)
            fz = np.ndarray(na, dtype = bool)
            for ia in range(na):
                ii += 1
                az[ia] = a[ii]
                xz[ia] = x[ii]
                fz[ia] = markerfill if stable is None else stable[ii]
            if truncate:
                mi, = np.where(xz > truncate_limit)
                if len(mi) == 0:
                    continue
                az, xz, fz = az[mi], xz[mi], fz[mi]
            color = colors[np.mod(iz,ncolors)]
            if showline:
                line = plt.Line2D(
                    az, xz,
                    color = color,
                    linewidth = linewidth)
                ax.add_line(line)
            if showmarker:
                if pathmarker:
                    Pm = Path.MOVETO
                    Pl = Path.LINETO
                    Pc = Path.CLOSEPOLY
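                    # build each marker as an explicit circular Path in
                    # display space, then map it back to data coordinates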
                    for mx,my,mf in zip(az, xz, fz):
                        mpos = ax.transData.transform([[mx,my]])
                        nvert = 64
                        mvert = np.linspace(0,2*np.pi,nvert)
                        p = (np.array((np.sin(mvert),
                                       np.cos(mvert))).transpose()
                             * 0.5*markersize)
                        p += mpos
                        p = ax.transData.inverted().transform(p)
                        c = [Pm] + (nvert-2)*[Pl] + [Pc]

                        path = Path(p, codes = c)
                        if mf:
                            patch = PathPatch(
                                path,
                                clip_on = True,
                                facecolor = color,
                                edgecolor = 'none',
                                linewidth = 0,
                                alpha = 1)
                        else:
                            patch = PathPatch(
                                path,
                                clip_on = True,
                                facecolor = 'none',
                                edgecolor = color,
                                linewidth = markerthick,
                                alpha = 1)
                        ax.add_patch(patch)
                else:
                    for mf in [True, False]:
                        ma = az[fz == mf]
                        mx = xz[fz == mf]
                        if len(ma) == 0:
                            continue
                        if mf:
                            markeredgecolor = 'none'
                            markeredgewidth = 0.
                            markerfacecolor = color
                        else:
                            markeredgecolor = color
                            markeredgewidth = markerthick
                            markerfacecolor = 'none'
                        line = plt.Line2D(
                                ma, mx,
                                marker = 'o',
                                markeredgecolor = markeredgecolor,
                                markeredgewidth = markeredgewidth,
                                linewidth = 0.,
                                markersize = markersize,
                                markerfacecolor = markerfacecolor)
                        ax.add_line(line)

            s = Elements[z[seq[iz]]]

            if fp is None:
                fp = FontProperties(
                    size = fontsize)

            if showtext:
                if stabletext and np.count_nonzero(fz) > 0:
                    mi, = np.nonzero(fz)
                    ms = slice(mi[0], mi[-1]+1)
                    ma, mx = az[ms], xz[ms]
                else:
                    ma, mx = az, xz
                dd,ha,va = self.align(
                    ma, mx, markersize, ax, align, dist,
                    pathfont, pathmarker)
                if pathfont:
                    # These would be useful to paint fonts directly as PathPatch
                    text = self.get_text(s, fp)
                    fxmin, fymin = text.vertices.min(axis=0)
                    fxmax, fymax = text.vertices.max(axis=0)
                    fwidth = fxmax - fxmin
                    fheight = fymax - fymin

                    # dependency on alignment
                    if not is_numlike(va):
                        if va == 'center':
                            va = 0.5
                        elif va == 'top':
                            va = 1.
                        else:
                            va = 0.
                    if not is_numlike(ha):
                        if ha == 'center':
                            ha = 0.5
                        elif ha == 'left':
                            ha = 0.
                        else:
                            ha = 1.
                    y_offset = -fymin - fheight * va
                    x_offset = -fxmin - fwidth  * ha

                    # set target position
                    fx, fy = dd

                    # now we need to find size (assume centered)
                    fext = np.array([[fwidth, fheight]])
                    fpos = ax.transData.transform([[fx,fy]])
                    frange = ax.transData.inverted().transform(
                        np.vstack((fpos - 0.5*fext, fpos + 0.5*fext)))
                    fscale = np.abs(frange[1]-frange[0])/fext.reshape(-1)

                    p = (text.vertices  + [[x_offset, y_offset]]) * [fscale] + [[fx, fy]]
                    c = text.codes

                    path = Path(p, codes = c)
                    patch = PathPatch(
                        path,
                        clip_on = True,
                        facecolor = 'k',
                        edgecolor = 'none',
                        lw = 0,
                        alpha = 1,
                        zorder = 3)
                    ax.add_patch(patch)
                else:
                    text = Text(dd[0],dd[1],s,
                                va = va,
                                ha = ha,
                                clip_on = True,
                                fontproperties = fp,
                                zorder = 3)
                    ax.add_artist(text)


        if ytitle is not None:
            if logy:
                ytitle = 'log( ' + ytitle + ' )'
            ax.set_ylabel(ytitle)
        if xtitle is not None:
            ax.set_xlabel(xtitle)
        if title is not None:
            ax.set_title(title)

        ax.set_xscale('linear')
        ax.set_yscale('linear')

        ax.xaxis.set_minor_locator(AutoMinorLocator())
        ax.yaxis.set_minor_locator(AutoMinorLocator())

        if norm is not None:
            if not is_numlike(norm):
                if isinstance(norm, str):
                    norm = Ion(norm)
                if isinstance(norm, Ion):
                    norm = abu[norm]
            if logy:
                norm = np.log10(norm)

            if normtype == 'line':
                lines = ['--',':']
                colors = ['k','k']
            else:
                lines = ['-','-']
                colors = ['#404040','#E0E0E0']
                if normrange is not None:
                    colors[0] = 'w'

            ax.axhline(norm,
                       linestyle = lines[0],
                       color = colors[0],
                       zorder = -1)
            if normrange is not None:
                if np.size(normrange) == 1:
                    if logy:
                        normrange = norm + np.log10(np.array([1./normrange,normrange]))
                    else:
                        normrange = norm + np.array([-normrange,+normrange])
                    if normtype == 'line':
                        ax.axhline(normrange[0],
                                   linestyle = lines[1],
                                   color = colors[1],
                                   zorder = -1)
                        ax.axhline(normrange[1],
                                   linestyle = lines[1],
                                   color = colors[1],
                                   zorder = -1)
                    else:
                        ax.axhspan(*normrange,
                                    color = colors[1],
                                    zorder = -2)

            #  ax.set_position((0.078,0.07,0.921,0.929))

        self.ax = ax
        self.figure = ax.figure

        if show:
            plt.draw()
示例#23
0
def kepcotrend(infile, bvfile, listbv, outfile=None, fitmethod='llsq',
               fitpower=1, iterate=False, sigma=None, maskfile='',
               scinterp='linear', plot=False, noninteractive=False,
               overwrite=False, verbose=False, logfile='kepcotrend.log'):
    """
    kepcotrend -- Remove systematic trends from Kepler light curves using
    cotrending basis vectors. The cotrending basis vectors files can be found
    here: http://archive.stsci.edu/kepler/cbv.html

    Simple Aperture Photometry (SAP) data often contain systematic trends
    associated with the spacecraft, detector and environment rather than the
    target. See the Kepler data release notes for descriptions of
    systematics and the cadences that they affect. Within the Kepler pipeline
    these contaminants are treated during Pre-search Data Conditioning (PDC)
    and cleaned data are provided in the light curve files archived at MAST
    within the column PDCSAP_FLUX. The Kepler pipeline attempts to remove
    systematics with a combination of data detrending and cotrending against
    engineering telemetry from the spacecraft such as detector temperatures.
    These processes are imperfect but tackled in the spirit of correcting as
    many targets as possible with enough accuracy for the mission to meet
    exoplanet detection specifications.

    The imperfections in the method are most apparent in variable stars, those
    stars that are of most interest for stellar astrophysics. The PDC
    correction can occasionally hamper data analysis or, at worst, destroy
    astrophysical signal from the target. While data filtering (``kepoutlier``,
    ``kepfilter``) and data detrending with analytical functions
    (``kepdetrend``) often provide some mitigation for data artifacts, these
    methods require assumptions and often result in lossy data. An alternative
    viable approach is to identify the photometric variability common to all of
    the stars neighboring the target and subtract those trends from the target.
    In principle, the correct choice, weighting and subtraction of these common
    trends will leave behind a corrected flux time series which better
    represents statistically the true signal from the target.

    While GOs, KASC members and archive users wait for the Kepler project to
    release quarters of data, they do not have access to all the light curve
    data neighboring their targets and so cannot take the ensemble approach
    themselves without help. To mitigate this problem the Kepler Science Office
    have made available ancillary data which describes the systematic trends
    present in the ensemble flux data for each CCD channel. These data are
    known as the Cotrending Basis Vectors (CBVs). More details on the method
    used to generate these basis vectors will be provided in the Kepler Data
    Processing Handbook soon, but until that time, a summary of the method is
    given here. To create the initial basis set, that is the flux time series'
    that are used to make the cotrending basis vectors:

    The time series photometry of each star on a specific detector channel is
    normalized by its own median flux. One (unity) is subtracted from each time
    series so that the median value of the light curve is zero.
    The time series is divided by the root-mean square of the photometry.
    The correlation between each time series on the CCD channel is calculated
    using the median and root-mean square normalized flux.
    The median absolute correlation is then calculated for each star.
    All stars on the channel are sorted into ascending order of correlation.
    The 50 percent most correlated stars are selected.
    The median normalized fluxes only (as opposed to the root-mean square
    normalized fluxes) are now used for the rest of the process Singular Value
    Decomposition is applied to the matrix of correlated sources to create
    orthonormal basis vectors from the U matrix, sorted by their singular
    values.

    The archived cotrending basis vectors are a reduced-rank representation of
    the full set of basis vectors and consist of the 16 leading columns.
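
    As a rough illustration of the recipe above, here is a minimal NumPy
    sketch; ``make_cbvs`` and its variable names are hypothetical, not part
    of the Kepler pipeline:

    .. code-block:: python

        import numpy as np

        def make_cbvs(flux, n_vectors=16):
            # flux: (n_stars, n_cadences) SAP time series on one channel
            norm = flux / np.median(flux, axis=1, keepdims=True) - 1.0
            rms = np.sqrt(np.mean(norm ** 2, axis=1, keepdims=True))
            corr = np.corrcoef(norm / rms)
            med_corr = np.median(np.abs(corr), axis=1)
            keep = med_corr >= np.median(med_corr)  # 50% most correlated
            # SVD of the median-normalized (not RMS-normalized) fluxes
            u, _, _ = np.linalg.svd(norm[keep].T, full_matrices=False)
            return u[:, :n_vectors]  # leading basis vectors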

    To correct a SAP light curve, :math:`F_{sap}`, for systematic features,
    ``kepcotrend`` employs the cotrending basis vectors :math:`CBV_i`. The task
    finds the coefficients :math:`A_i` which minimize

    .. math::

        F_{cbv} = F_{sap} - \sum_{i} A_i \cdot CBV_i

    The corrected light curve, :math:`F_{cbv}`, can be tailored to the needs of the user
    and their scientific objective. The user decides which combination of basis
    vectors best removes systematics from their specific Kepler SAP light
    curve. In principle the user can choose any combination of cotrending basis
    vectors to fit to the data. However, experience suggests that most choices
    will be to decide how many sequential basis vectors to include in the fit,
    starting with the first vector. For example, a user is much more likely
    to choose the vector combination 1, 2, 3, 4, 5, 6 than, e.g., the
    combination 1, 2, 5, 7, 8, 10, 12. The user should always include at
    least the first two
    basis vectors. The number of basis vectors used is directly related to the
    scientific aims of the user and the light curve being analyzed and
    experimental iteration towards a target-specific optimal basis set is
    recommended. Occasionally kepcotrend over-fits the data and removes real
    astrophysical signal. This is particularly prevalent if too many basis
    vectors are used. A good rule of thumb is to start with two basis vectors
    and increase the number until there is no improvement, or signals which are
    thought to be astrophysical start to become distorted.

    The user is given a choice of fitting algorithm to use. For most purposes
    the linear least squares method is both the fastest and the most accurate
    because it gives the exact solution to the least squares problem. However
    we have found a few situations where the best solution, scientifically,
    comes from using the simplex fitting algorithm which performs something
    other than a least squares fit. Performing a least absolute residuals fit
    (fitpower=1.0), for example, is more robust to outliers.
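
    The simplex option corresponds to minimizing a generalized merit
    function; a sketch under the same assumptions:

    .. code-block:: python

        import numpy as np
        from scipy.optimize import fmin

        def fit_simplex(fsap, cbvs, fitpower=1.0):
            # minimize sum(|Fsap - CBV @ A| ** fitpower) with Nelder-Mead
            merit = lambda a: np.sum(np.abs(fsap - cbvs @ a) ** fitpower)
            return fmin(merit, np.zeros(cbvs.shape[1]), disp=False)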

    There are instances when the fit performs sub-optimally due to the presence
    of certain events in the light curve. For this reason we have included two
    options which can be used individually or simultaneously to improve the fit
    - iterative fitting and data masking. Iterative fitting performs the fit
    and rejects data points that are greater than a specified distance from the
    optimal fit before re-fitting. The lower threshold for data clipping is
    provided by the user as the number of sigma from the best fit. The clipping
    threshold is more accurately defined as the number of Median Absolute
    Deviations (MADs) multiplied by 1.4826. For a Gaussian distribution,
    1.4826 times the MAD equals the standard deviation, so the two clipping
    definitions agree. We use MAD because in highly non-Gaussian
    distributions MAD is more robust to outliers than standard deviation.
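
    A sketch of the clipping criterion just described (names are
    illustrative):

    .. code-block:: python

        import numpy as np

        def clip_mask(residuals, sigma):
            # keep points within sigma robust standard deviations of the
            # fit, estimating the standard deviation as 1.4826 * MAD
            med = np.median(residuals)
            mad = np.median(np.abs(residuals - med))
            return np.abs(residuals - med) < sigma * 1.4826 * mad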

    The code will print out the coefficients fit to each basis vector, the
    root-mean square of the fit and the chi-squared value of the fit. The rms
    and the chi-squared value include only the data points included in the fit
    so if an iterative fit is performed these clipped values are not included
    in this calculation.
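
    The reduced chi-squared here follows the usual definition,

    .. math::

        \chi^2_{\nu} = \frac{1}{N - k} \sum_{j=1}^{N}
            \frac{(F_j - M_j)^2}{\sigma_j^2},

    where :math:`N` counts only the fitted (unclipped) points, :math:`k` is
    the number of basis-vector coefficients, :math:`M_j` is the model flux
    and :math:`\sigma_j` the flux error.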

    Parameters
    ----------
    infile : str
        the input file in the FITS format obtained from MAST
    outfile : str
        the output will be a fits file in the same style as the input file but
        with two additional columns: CBVSAP_MODL and CBVSAP_FLUX. The first of
        these is the best fitting linear combination of basis vectors.
        The second is the new flux with the basis vector sum subtracted. This
        is the new flux value.
    bvfile : str
        the name of the FITS file containing the basis vectors
    listbv : list of integers
        the basis vectors to fit to the data
    fitmethod : str
        fit using either the 'llsq' or the 'simplex' method. 'llsq' is
        usually the correct one to use because the basis vectors are
        orthogonal. Simplex gives you the option of using a different merit
        function - i.e. you can minimise the least absolute residual instead
        of the least squares, which weights outliers less
    fitpower : float
        if using simplex you can choose your own power in the merit function
        - i.e. the merit function minimises :math:`abs(Obs - Mod)^P`.
        :math:`P = 2` is least squares, :math:`P = 1` minimises least absolutes
    iterate : bool
        should the program fit the basis vectors to the light curve data then
        remove data points further than 'sigma' from the fit and then refit
    maskfile : str
        this is the name of a mask file which can be used to define regions of
        the flux time series to exclude from the fit. The easiest way to create
        this is by using ``keprange`` from the PyKE set of tools. You can also
        make this yourself with two BJDs on each line in the file specifying
        the beginning and ending date of the region to exclude.
    scinterp : str
        the basis vectors are only calculated for long cadence data, therefore if
        you want to use short cadence data you have to interpolate the basis
        vectors. There are several methods to do this, the best of these probably
        being nearest which picks the value of the nearest long cadence data
        point.
        The options available are:

        * linear
        * nearest
        * zero
        * slinear
        * quadratic
        * cubic
    plot : bool
        Plot the data and result?
    noninteractive : bool
        If True, prevents the matplotlib window from popping up.
    overwrite : bool
        Overwrite the output file?
    verbose : bool
        Print informative messages and warnings to the shell and logfile?
    logfile : str
        Name of the logfile containing error and warning messages.

    Examples
    --------
    .. code-block:: bash

        $ kepcotrend kplr005110407-2009350155506_llc.fits ~/cbv/kplr2009350155506-q03-d25_lcbv.fits
        '1 2 3' --plot --verbose
    """

    if outfile is None:
        outfile = infile.split('.')[0] + "-{}.fits".format(__all__[0])
    # log the call
    hashline = '--------------------------------------------------------------'
    kepmsg.log(logfile, hashline, verbose)
    call = ('KEPCOTREND -- '
            + ' infile={}'.format(infile)
            + ' outfile={}'.format(outfile)
            + ' bvfile={}'.format(bvfile)
            + ' listbv={} '.format(listbv)
            + ' fitmethod={}'.format(fitmethod)
            + ' fitpower={}'.format(fitpower)
            + ' iterate={}'.format(iterate)
            + ' sigma_clip={}'.format(sigma)
            + ' mask_file={}'.format(maskfile)
            + ' scinterp={}'.format(scinterp)
            + ' plot={}'.format(plot)
            + ' overwrite={}'.format(overwrite)
            + ' verbose={}'.format(verbose)
            + ' logfile={}'.format(logfile))
    kepmsg.log(logfile, call+'\n', verbose)

    # start time
    kepmsg.clock('KEPCOTREND started at', logfile, verbose)

    # overwrite output file
    if overwrite:
        kepio.overwrite(outfile, logfile, verbose)
    if kepio.fileexists(outfile):
        errmsg = 'ERROR -- KEPCOTREND: {} exists. Use --overwrite'.format(outfile)
        kepmsg.err(logfile, errmsg, verbose)

    # open input file
    instr = pyfits.open(infile)
    tstart, tstop, bjdref, cadence = kepio.timekeys(instr, infile, logfile,
                                                    verbose)
    # fudge non-compliant FITS keywords with no values
    instr = kepkey.emptykeys(instr, infile, logfile, verbose)

    if not kepio.fileexists(bvfile):
        message = 'ERROR -- KEPCOTREND: ' + bvfile + ' does not exist.'
        kepmsg.err(logfile, message, verbose)
    #lsq_sq - nonlinear least squares fitting and simplex_abs have been
    #removed from the options in PyRAF but they are still in the code!
    if fitmethod not in ['llsq','matrix','lst_sq','simplex_abs','simplex']:
        errmsg = 'Fit method must be one of: llsq, matrix, lst_sq or simplex'
        kepmsg.err(logfile, errmsg, verbose)

    if not is_numlike(fitpower) and fitpower is not None:
        errmsg = 'Fit power must be a real number or None'
        kepmsg.err(logfile, errmsg, verbose)

    if fitpower is None:
        fitpower = 1.

    # input data
    short = False
    try:
        test = str(instr[0].header['FILEVER'])
        version = 2
    except KeyError:
        version = 1

    table = instr[1].data
    if version == 1:
        if str(instr[1].header['DATATYPE']) == 'long cadence':
            quarter = str(instr[1].header['QUARTER'])
            module = str(instr[1].header['MODULE'])
            output = str(instr[1].header['OUTPUT'])
            channel = str(instr[1].header['CHANNEL'])
            lc_cad_o = table.field('cadence_number')
            lc_date_o = table.field('barytime')
            lc_flux_o = table.field('ap_raw_flux') / 1625.3468 #convert to e-/s
            lc_err_o = table.field('ap_raw_err') / 1625.3468 #convert to e-/s
        elif str(instr[1].header['DATATYPE']) == 'short cadence':
            short = True
            quarter = str(instr[1].header['QUARTER'])
            module = str(instr[1].header['MODULE'])
            output = str(instr[1].header['OUTPUT'])
            channel = str(instr[1].header['CHANNEL'])
            lc_cad_o = table.field('cadence_number')
            lc_date_o = table.field('barytime')
            lc_flux_o = table.field('ap_raw_flux') / 54.178 #convert to e-/s
            lc_err_o = table.field('ap_raw_err') / 54.178 #convert to e-/s

    elif version >= 2:
        if str(instr[0].header['OBSMODE']) == 'long cadence':
            quarter = str(instr[0].header['QUARTER'])
            module = str(instr[0].header['MODULE'])
            output = str(instr[0].header['OUTPUT'])
            channel = str(instr[0].header['CHANNEL'])
            lc_cad_o = table.field('CADENCENO')
            lc_date_o = table.field('TIME')
            lc_flux_o = table.field('SAP_FLUX')
            lc_err_o = table.field('SAP_FLUX_ERR')
        elif str(instr[0].header['OBSMODE']) == 'short cadence':
            short = True
            quarter = str(instr[0].header['QUARTER'])
            module = str(instr[0].header['MODULE'])
            output = str(instr[0].header['OUTPUT'])
            channel = str(instr[0].header['CHANNEL'])
            lc_cad_o = table.field('CADENCENO')
            lc_date_o = table.field('TIME')
            lc_flux_o = table.field('SAP_FLUX')
            lc_err_o = table.field('SAP_FLUX_ERR')

    if str(quarter) == str(4) and version == 1:
        lc_cad_o = lc_cad_o[lc_cad_o >= 11914]
        lc_date_o = lc_date_o[lc_cad_o >= 11914]
        lc_flux_o = lc_flux_o[lc_cad_o >= 11914]
        lc_err_o = lc_err_o[lc_cad_o >= 11914]

    if short and scinterp is None:
        errmsg = ('You cannot select None as the interpolation method '
                  'because you are using short cadence data and '
                  'therefore must use some form of interpolation. I '
                  'recommend nearest if you are unsure.')
        kepmsg.err(logfile, errmsg, verbose)

    bvfiledata = pyfits.open(bvfile)
    bvdata = bvfiledata['MODOUT_{0}_{1}'.format(module, output)].data

    if int(bvfiledata[0].header['QUARTER']) != int(quarter):
        errmsg = ('CBV file and light curve file are from different '
                  'quarters. CBV file is from Q{0} and light curve is '
                  'from Q{1}'.format(int(bvfiledata[0].header['QUARTER']),
                                     int(quarter)))
        kepmsg.err(logfile, errmsg, verbose)

    if int(quarter) == 4 and int(module) == 3:
        errmsg = ('Approximately twenty days into Q4 Module 3 failed. '
                  'As a result, Q4 light curves contain these 20 days '
                  'of data. However, we do not calculate CBVs for '
                  'this section of data.')
        kepmsg.err(logfile, errmsg, verbose)

    #cut out infinites and zero flux columns
    lc_cad, lc_date, lc_flux, lc_err, good_data = cut_bad_data(lc_cad_o,
                                      lc_date_o, lc_flux_o, lc_err_o)
    #get a list of basis vectors to use from the list given
    #accept different separators
    if len(listbv) == 1:
        bvlist = [listbv]
    else:
        listbv = listbv.strip()
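        # listbv is e.g. "1 2 3": the first character is a basis-vector
        # number (1-9), so any separator shows up as the second character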
        if listbv[1] in [' ', ',', ':', ';', '|']:
            separator = str(listbv)[1]
        else:
            message = ('You must separate your basis vector numbers to use '
                       'with \' \' \',\' \':\' \';\' or \'|\' and the '
                       'first basis vector to use must be between 1 and 9')
            kepmsg.err(logfile, message, verbose)

        bvlist = np.fromstring(listbv, dtype=int, sep=separator)

    if bvlist[0] == 0:
        errmsg = 'Must use at least one basis vector'
        kepmsg.err(logfile, errmsg, verbose)
    if short:
        bvdata.field('CADENCENO')[:] = ((((bvdata.field('CADENCENO')[:] +
                                        (7.5 / 15.) ) * 30.) - 11540.).round())
    bvectors = get_pcomp_list_newformat(bvdata, bvlist, lc_cad, short, scinterp)
    medflux = np.median(lc_flux)
    n_flux = (lc_flux / medflux) - 1
    n_err = np.sqrt(lc_err * lc_err / (medflux * medflux))

    if maskfile != '':
        domasking = True
        if not kepio.fileexists(maskfile):
            errmsg = 'Maskfile {} does not exist'.format(maskfile)
            kepmsg.err(logfile, errmsg, verbose)
    else:
        domasking = False

    if domasking:
        lc_date_masked = np.copy(lc_date)
        n_flux_masked = np.copy(n_flux)
        lc_cad_masked = np.copy(lc_cad)
        n_err_masked = np.copy(n_err)
        maskdata = np.atleast_2d(np.genfromtxt(maskfile, delimiter=','))
        mask = np.ones(len(lc_date_masked), dtype=bool)
        for maskrange in maskdata:
            if version == 1:
                start = maskrange[0] - 2400000.0
                end = maskrange[1] - 2400000.0
            elif version == 2:
                start = maskrange[0] - 2454833.
                end = maskrange[1] - 2454833.
            masknew = np.logical_xor(lc_date < start, lc_date > end)
            mask = np.logical_and(mask,masknew)
        lc_date_masked = lc_date_masked[mask]
        n_flux_masked = n_flux_masked[mask]
        lc_cad_masked = lc_cad_masked[mask]
        n_err_masked = n_err_masked[mask]
    else:
        lc_date_masked = np.copy(lc_date)
        n_flux_masked = np.copy(n_flux)
        lc_cad_masked = np.copy(lc_cad)
        n_err_masked = np.copy(n_err)

    bvectors_masked = get_pcomp_list_newformat(bvdata, bvlist, lc_cad_masked,
                                               short, scinterp)

    if iterate and sigma is None:
        errmsg = 'If fitting iteratively you must specify a clipping range'
        kepmsg.err(logfile, errmsg, verbose)

    #uses Pvals = yhat * U_transpose
    if iterate:
        coeffs, fittedmask = do_lst_iter(bvectors_masked, lc_cad_masked,
                                         n_flux_masked, sigma, 50., fitmethod,
                                         fitpower)
    else:
        if fitmethod == 'lst_sq':
            coeffs = do_lsq_nlin(bvectors_masked, n_flux_masked)
        elif fitmethod == 'simplex':
            coeffs = do_lsq_fmin_pow(bvectors_masked, n_flux_masked, fitpower)
        else:
            coeffs = do_lsq_uhat(bvectors_masked, n_flux_masked)

    coeffs = np.asarray(coeffs)
    flux_after = medflux * (n_flux + np.dot(coeffs.T, bvectors) + 1).reshape(-1)
    flux_after_masked = medflux * (n_flux_masked + np.dot(coeffs.T, bvectors_masked) + 1).reshape(-1)
    bvsum = np.dot(coeffs.T, bvectors).reshape(-1)
    bvsum_masked = np.dot(coeffs.T, bvectors_masked).reshape(-1)
    bvsum_nans = put_in_nans(good_data, bvsum)
    flux_after_nans = put_in_nans(good_data, flux_after)

    if plot:
        if not domasking:
            maskdata = None
        newmedflux = np.median(flux_after + 1)
        bvsum_un_norm = newmedflux * (1 - bvsum)
        do_plot(lc_date, lc_flux, flux_after, bvsum_un_norm, lc_cad,
                good_data, lc_cad_o, version, maskdata, outfile, noninteractive)

    print("Writing output file {}...".format(outfile))
    make_outfile(instr, outfile, flux_after_nans, bvsum_nans, version)
    # close input file
    instr.close()
    #print some results to screen:
    print('      -----      ')
    if iterate:
        flux_fit = n_flux_masked[fittedmask]
        sum_fit = bvsum_masked[fittedmask]
        err_fit = n_err_masked[fittedmask]
    else:
        flux_fit = n_flux_masked
        sum_fit = bvsum_masked
        err_fit = n_err_masked

    print('reduced chi2: {}'.format(chi2_gtf(flux_fit, sum_fit, err_fit,
                                             len(flux_fit) - len(coeffs))))
    print('rms: {}'.format(medflux * rms(flux_fit, sum_fit)))

    for i in range(len(coeffs)):
        print('Coefficient of CBV #{0}: {1}'.format(i + 1, coeffs[i]))
    print('      -----      ')

    # end time
    kepmsg.clock('KEPCOTREND completed at', logfile, verbose)
示例#24
0
    def __init__(
        self,
        fig,
        rect,
        nrows_ncols,
        ngrids=None,
        direction="row",
        axes_pad=0.02,
        add_all=True,
        share_all=False,
        share_x=True,
        share_y=True,
        # aspect=True,
        label_mode="L",
        axes_class=None,
    ):
        """
        Build an :class:`Grid` instance with a grid nrows*ncols
        :class:`~matplotlib.axes.Axes` in
        :class:`~matplotlib.figure.Figure` *fig* with
        *rect=[left, bottom, width, height]* (in
        :class:`~matplotlib.figure.Figure` coordinates) or
        the subplot position code (e.g., "121").

        Optional keyword arguments:

          ================  ========  =========================================
          Keyword           Default   Description
          ================  ========  =========================================
          direction         "row"     [ "row" | "column" ]
          axes_pad          0.02      float | pad between axes given in inches
          add_all           True      [ True | False ]
          share_all         False     [ True | False ]
          share_x           True      [ True | False ]
          share_y           True      [ True | False ]
          label_mode        "L"       [ "L" | "1" | "all" ]
          axes_class        None      a type object which must be a subclass
                                      of :class:`~matplotlib.axes.Axes`
          ================  ========  =========================================
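
        A minimal usage sketch (assuming this is the ``Grid`` from
        ``mpl_toolkits.axes_grid1``; the plotted data are placeholders)::

          import matplotlib.pyplot as plt
          from mpl_toolkits.axes_grid1 import Grid

          fig = plt.figure()
          grid = Grid(fig, 111, nrows_ncols=(2, 3), axes_pad=0.05)
          for i in range(2 * 3):
              grid[i].plot([0, 1], [0, 1])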
        """
        self._nrows, self._ncols = nrows_ncols

        if ngrids is None:
            ngrids = self._nrows * self._ncols
        else:
            if (ngrids > self._nrows * self._ncols) or (ngrids <= 0):
                raise Exception(
                    "ngrids must be positive and not larger than nrows*ncols")

        self.ngrids = ngrids

        self._init_axes_pad(axes_pad)

        if direction not in ["column", "row"]:
            raise Exception('direction must be "column" or "row"')

        self._direction = direction

        if axes_class is None:
            axes_class = self._defaultLocatableAxesClass
            axes_class_args = {}
        else:
            if (type(axes_class) == type and
                    issubclass(axes_class,
                               self._defaultLocatableAxesClass.Axes)):
                axes_class_args = {}
            else:
                axes_class, axes_class_args = axes_class

        self.axes_all = []
        self.axes_column = [[] for i in range(self._ncols)]
        self.axes_row = [[] for i in range(self._nrows)]

        h = []
        v = []
        if cbook.is_string_like(rect) or cbook.is_numlike(rect):
            self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v, aspect=False)
        elif len(rect) == 3:
            kw = dict(horizontal=h, vertical=v, aspect=False)
            self._divider = SubplotDivider(fig, *rect, **kw)
        elif len(rect) == 4:
            self._divider = Divider(fig, rect, horizontal=h, vertical=v, aspect=False)
        else:
            raise Exception(
                "rect must be a subplot position code, a 3-tuple or a 4-tuple")

        rect = self._divider.get_position()

        # reference axes
        self._column_refax = [None for i in range(self._ncols)]
        self._row_refax = [None for i in range(self._nrows)]
        self._refax = None

        for i in range(self.ngrids):

            col, row = self._get_col_row(i)

            if share_all:
                sharex = self._refax
                sharey = self._refax
            else:
                if share_x:
                    sharex = self._column_refax[col]
                else:
                    sharex = None

                if share_y:
                    sharey = self._row_refax[row]
                else:
                    sharey = None

            ax = axes_class(fig, rect, sharex=sharex, sharey=sharey, **axes_class_args)

            if share_all:
                if self._refax is None:
                    self._refax = ax
            else:
                if sharex is None:
                    self._column_refax[col] = ax
                if sharey is None:
                    self._row_refax[row] = ax

            self.axes_all.append(ax)
            self.axes_column[col].append(ax)
            self.axes_row[row].append(ax)

        self.axes_llc = self.axes_column[0][-1]

        self._update_locators()

        if add_all:
            for ax in self.axes_all:
                fig.add_axes(ax)

        self.set_label_mode(label_mode)
示例#25
0
def kepcotrendsc(infile,outfile,bvfile,listbv,fitmethod,fitpower,iterate,sigma,maskfile,scinterp,plot,clobber,verbose,logfile,
	status,cmdLine=False):
	"""
	Setup the kepcotrend environment

	infile:
	the input file in the FITS format obtained from MAST

	outfile:
	The output will be a fits file in the same style as the input file but with two additional columns: CBVSAP_MODL and CBVSAP_FLUX. The first of these is the best fitting linear combination of basis vectors. The second is the new flux with the basis vector sum subtracted. This is the new flux value.

	plot:
	either True or False if you want to see a plot of the light curve
	The top plot shows the original light curve in blue and the sum of basis vectors in red
	The bottom plot has had the basis vector sum subtracted

	bvfile:
	the name of the FITS file containing the basis vectors

	listbv:
	the basis vectors to fit to the data

	fitmethod:
	fit using either the 'llsq' or the 'simplex' method. 'llsq' is usually the correct one to use because the basis vectors are orthogonal. Simplex gives you the option of using a different merit function - i.e. you can minimise the least absolute residual instead of the least squares, which weights outliers less

	fitpower:
	if using simplex you can choose your own power in the merit function - i.e. the merit function minimises abs(Obs - Mod)^P. P=2 is least squares, P = 1 minimises least absolutes

	iterate:
	should the program fit the basis vectors to the light curve data then remove data points further than 'sigma' from the fit and then refit

	maskfile:
	this is the name of a mask file which can be used to define regions of the flux time series to exclude from the fit. The easiest way to create this is by using keprange from the PyKE set of tools. You can also make this yourself with two BJDs on each line in the file specifying the beginning and ending date of the region to exclude.

	scinterp:
	the basis vectors are only calculated for long cadence data, therefore if you want to use short cadence data you have to interpolate the basis vectors. There are several methods to do this, the best of these probably being nearest which picks the value of the nearest long cadence data point.
	The options available are None|linear|nearest|zero|slinear|quadratic|cubic
	If you are using short cadence data don't choose None
	"""
	# log the call
	hashline = '----------------------------------------------------------------------------'
	kepmsg.log(logfile,hashline,verbose)
	call = 'KEPCOTREND -- '
	call += 'infile='+infile+' '
	call += 'outfile='+outfile+' '
	call += 'bvfile='+bvfile+' '
#	call += 'numpcomp= '+str(numpcomp)+' '
	call += 'listbv= '+str(listbv)+' '
	call += 'fitmethod=' +str(fitmethod)+ ' '
	call += 'fitpower=' + str(fitpower)+ ' '
	iterateit = 'n'
	if (iterate): iterateit = 'y'
	call += 'iterate='+iterateit+ ' '
	call += 'sigma_clip='+str(sigma)+' '
	call += 'mask_file='+maskfile+' '
	call += 'scinterp=' + str(scinterp)+ ' '
	plotit = 'n'
	if (plot): plotit = 'y'
	call += 'plot='+plotit+ ' '
	overwrite = 'n'
	if (clobber): overwrite = 'y'
	call += 'clobber='+overwrite+ ' '
	chatter = 'n'
	if (verbose): chatter = 'y'
	call += 'verbose='+chatter+' '
	call += 'logfile='+logfile
	kepmsg.log(logfile,call+'\n',verbose)

	# start time
	kepmsg.clock('KEPCOTREND started at',logfile,verbose)

	# test log file
	logfile = kepmsg.test(logfile)

	# clobber output file
	if clobber:
		status = kepio.clobber(outfile,logfile,verbose)
	if kepio.fileexists(outfile):
		message = 'ERROR -- KEPCOTREND: ' + outfile + ' exists. Use --clobber'
		status = kepmsg.err(logfile,message,verbose)

	# open input file
	if status == 0:
		instr, status = kepio.openfits(infile,'readonly',logfile,verbose)
		tstart, tstop, bjdref, cadence, status = kepio.timekeys(instr,
			infile,logfile,verbose,status)

	# fudge non-compliant FITS keywords with no values
	if status == 0:
		instr = kepkey.emptykeys(instr,infile,logfile,verbose)

	if status == 0:
		if not kepio.fileexists(bvfile):
			message = 'ERROR -- KEPCOTREND: ' + bvfile + ' does not exist.'
			status = kepmsg.err(logfile,message,verbose)

	#lsq_sq - nonlinear least squares fitting and simplex_abs have been
	#removed from the options in PyRAF but they are still in the code!
	if status == 0:
		if fitmethod not in ['llsq','matrix','lst_sq','simplex_abs','simplex']:
			message = 'Fit method must either: llsq, matrix, lst_sq or simplex'
			status = kepmsg.err(logfile,message,verbose)

	if status == 0:
		if not is_numlike(fitpower) and fitpower is not None:
			message = 'Fit power must be an real number or None'
			status = kepmsg.err(logfile,message,verbose)



	if status == 0:
		if fitpower is None:
			fitpower = 1.

	# input data
	if status == 0:
		short = False
		try:
			test = str(instr[0].header['FILEVER'])
			version = 2
		except KeyError:
			version = 1

		table = instr[1].data
		if version == 1:
			if str(instr[1].header['DATATYPE']) == 'long cadence':
				#print 'Light curve was taken in Long Cadence mode!'
				quarter = str(instr[1].header['QUARTER'])
				module = str(instr[1].header['MODULE'])
				output = str(instr[1].header['OUTPUT'])
				channel = str(instr[1].header['CHANNEL'])

				lc_cad_o = table.field('cadence_number')
				lc_date_o = table.field('barytime')
				lc_flux_o = table.field('ap_raw_flux') / 1625.3468 #convert to e-/s
				lc_err_o = table.field('ap_raw_err') / 1625.3468 #convert to e-/s
			elif str(instr[1].header['DATATYPE']) == 'short cadence':
				short = True
				#print 'Light curve was taken in Short Cadence mode!'
				quarter = str(instr[1].header['QUARTER'])
				module = str(instr[1].header['MODULE'])
				output = str(instr[1].header['OUTPUT'])
				channel = str(instr[1].header['CHANNEL'])

				lc_cad_o = table.field('cadence_number')
				lc_date_o = table.field('barytime')
				lc_flux_o = table.field('ap_raw_flux') / 54.178 #convert to e-/s
				lc_err_o = table.field('ap_raw_err') / 54.178 #convert to e-/s

		elif version >= 2:
			if str(instr[0].header['OBSMODE']) == 'long cadence':
				#print 'Light curve was taken in Long Cadence mode!'

				quarter = str(instr[0].header['QUARTER'])
				module = str(instr[0].header['MODULE'])
				output = str(instr[0].header['OUTPUT'])
				channel = str(instr[0].header['CHANNEL'])

				lc_cad_o = table.field('CADENCENO')
				lc_date_o = table.field('TIME')
				lc_flux_o = table.field('SAP_FLUX')
				lc_err_o = table.field('SAP_FLUX_ERR')
			elif str(instr[0].header['OBSMODE']) == 'short cadence':
				#print 'Light curve was taken in Short Cadence mode!'
				short = True
				quarter = str(instr[0].header['QUARTER'])
				module = str(instr[0].header['MODULE'])
				output = str(instr[0].header['OUTPUT'])
				channel = str(instr[0].header['CHANNEL'])

				lc_cad_o = table.field('CADENCENO')
				lc_date_o = table.field('TIME')
				lc_flux_o = table.field('SAP_FLUX')
				lc_err_o = table.field('SAP_FLUX_ERR')


		if str(quarter) == str(4) and version == 1:
			lc_cad_o = lc_cad_o[lc_cad_o >= 11914]
			lc_date_o = lc_date_o[lc_cad_o >= 11914]
			lc_flux_o = lc_flux_o[lc_cad_o >= 11914]
			lc_err_o = lc_err_o[lc_cad_o >= 11914]

		# bvfilename = '%s/Q%s_%s_%s_map.txt' %(bvfile,quarter,module,output)
		# if str(quarter) == str(5):
		# 	bvdata = genfromtxt(bvfilename)
		# elif str(quarter) == str(3) or str(quarter) == str(4):
		# 	bvdata = genfromtxt(bvfilename,skip_header=22)
		# elif str(quarter) == str(1):
		# 	bvdata = genfromtxt(bvfilename,skip_header=10)
		# else:
		# 	bvdata = genfromtxt(bvfilename,skip_header=13)

		if short and scinterp == 'None':
			message = 'You cannot select None as the interpolation method because you are using short cadence data and therefore must use some form of interpolation. I recommend nearest if you are unsure.'
			status = kepmsg.err(logfile,message,verbose)

		bvfiledata = pyfits.open(bvfile)
		bvdata = bvfiledata['MODOUT_%s_%s' %(module,output)].data


		if int(bvfiledata[0].header['QUARTER']) != int(quarter):
			message = 'CBV file and light curve file are from different quarters. CBV file is from Q%s and light curve is from Q%s' %(int(bvfiledata[0].header['QUARTER']),int(quarter))
			status = kepmsg.err(logfile,message,verbose)

	if status == 0:
		if int(quarter) == 4 and int(module) == 3:
			message = 'Approximately twenty days into Q4 Module 3 failed. As a result, Q4 light curves contain these 20 days of data. However, we do not calculate CBVs for this section of data.'
			status = kepmsg.err(logfile,message,verbose)

	if status == 0:


		#cut out infinites and zero flux columns
		lc_cad,lc_date,lc_flux,lc_err,bad_data = cutBadData(lc_cad_o,
			lc_date_o,lc_flux_o,lc_err_o)

		#get a list of basis vectors to use from the list given
		#accept different separators
		listbv = listbv.strip()
		if listbv[1] in [' ',',',':',';','|']:
			separator = str(listbv)[1]
		else:
			message = 'You must separate your basis vector numbers to use with \' \' \',\' \':\' \';\' or \'|\' and the first basis vector to use must be between 1 and 9'
			status = kepmsg.err(logfile,message,verbose)


	if status == 0:
		bvlist = fromstring(listbv,dtype=int,sep=separator)

		if bvlist[0] == 0:
			message = 'Must use at least one basis vector'
			status = kepmsg.err(logfile,message,verbose)
	if status == 0:
		#pcomps = get_pcomp(pcompdata,n_comps,lc_cad)
		# if str(quarter) == str(5):
		# 	bvectors = get_pcomp_list(bvdata,bvlist,lc_cad)
		# else:
		#	bvectors = get_pcomp_list_newformat(bvdata,bvlist,lc_cad)

		if short:
			bvdata.field('CADENCENO')[:] = (((bvdata.field('CADENCENO')[:] + (7.5/15.) )* 30.) - 11540.).round()

		bvectors,in1derror = get_pcomp_list_newformat(bvdata,bvlist,lc_cad,short,scinterp)

		if in1derror:
			message = 'It seems that you have an old version of numpy which does not have the in1d function included. Please update your version of numpy to version 1.4.0 or later'
			status = kepmsg.err(logfile,message,verbose)
	if status == 0:

		medflux = median(lc_flux)
		n_flux = (lc_flux /medflux)-1
		n_err = sqrt(pow(lc_err,2)/ pow(medflux,2))

		#plt.errorbar(lc_cad,n_flux,yerr=n_err)
		#plt.errorbar(lc_cad,lc_flux,yerr=lc_err)

		#n_err = median(lc_err/lc_flux) * n_flux
		#print n_err

		#does an iterative least squares fit
		#t1 = do_leastsq(pcomps,lc_cad,n_flux)
		#

		if maskfile != '':
			domasking = True
			if not kepio.fileexists(maskfile):
				message = 'Maskfile %s does not exist' %maskfile
				status = kepmsg.err(logfile,message,verbose)
		else:
			domasking = False



	if status == 0:
		if domasking:

			lc_date_masked = copy(lc_date)
			n_flux_masked = copy(n_flux)
			lc_cad_masked = copy(lc_cad)
			n_err_masked = copy(n_err)
			maskdata = atleast_2d(genfromtxt(maskfile,delimiter=','))
			#make a mask of True values in case there are no regions in maskfile to exclude.
			mask = ones(len(lc_date_masked), dtype=bool)
			for maskrange in maskdata:
				if version == 1:
					start = maskrange[0] - 2400000.0
					end = maskrange[1] - 2400000.0
				elif version == 2:
					start = maskrange[0] - 2454833.
					end = maskrange[1] - 2454833.
				masknew = logical_xor(lc_date < start,lc_date > end)
				mask = logical_and(mask,masknew)

			lc_date_masked = lc_date_masked[mask]
			n_flux_masked = n_flux_masked[mask]
			lc_cad_masked = lc_cad_masked[mask]
			n_err_masked = n_err_masked[mask]
		else:
			lc_date_masked = copy(lc_date)
			n_flux_masked = copy(n_flux)
			lc_cad_masked = copy(lc_cad)
			n_err_masked = copy(n_err)


		#pcomps = get_pcomp(pcompdata,n_comps,lc_cad)

		bvectors_masked,hasin1d = get_pcomp_list_newformat(bvdata,bvlist,lc_cad_masked,short,scinterp)


		if (iterate) and sigma is None:
			message = 'If fitting iteratively you must specify a clipping range'
			status = kepmsg.err(logfile,message,verbose)

	if status == 0:
		#uses Pvals = yhat * U_transpose
		if (iterate):
			coeffs,fittedmask = do_lst_iter(bvectors_masked,lc_cad_masked
				,n_flux_masked,sigma,50.,fitmethod,fitpower)
		else:
			if fitmethod == 'matrix' and domasking:
				coeffs = do_lsq_uhat(bvectors_masked,lc_cad_masked,n_flux_masked,False)
			if fitmethod == 'llsq' and domasking:
				coeffs = do_lsq_uhat(bvectors_masked,lc_cad_masked,n_flux_masked,False)
			elif fitmethod == 'lst_sq':
				coeffs = do_lsq_nlin(bvectors_masked,lc_cad_masked,n_flux_masked)
			elif fitmethod == 'simplex_abs':
				coeffs = do_lsq_fmin(bvectors_masked,lc_cad_masked,n_flux_masked)
			elif fitmethod == 'simplex':
				coeffs = do_lsq_fmin_pow(bvectors_masked,lc_cad_masked,n_flux_masked,fitpower)
			else:
				coeffs = do_lsq_uhat(bvectors_masked,lc_cad_masked,n_flux_masked)



		flux_after = (get_newflux(n_flux,bvectors,coeffs) +1) * medflux
		flux_after_masked = (get_newflux(n_flux_masked,bvectors_masked,coeffs) +1) * medflux
		bvsum = get_pcompsum(bvectors,coeffs)

		bvsum_masked =  get_pcompsum(bvectors_masked,coeffs)

		#print 'chi2: ' + str(chi2_gtf(n_flux,bvsum,n_err,2.*len(n_flux)-2))
		#print 'rms: ' + str(rms(n_flux,bvsum))


		bvsum_nans = putInNans(bad_data,bvsum)
		flux_after_nans = putInNans(bad_data,flux_after)


	if plot and status == 0:
		newmedflux = median(flux_after + 1)
		bvsum_un_norm = newmedflux*(1-bvsum)
		#bvsum_un_norm = 0-bvsum
		#lc_flux = n_flux
		do_plot(lc_date,lc_flux,flux_after,
			bvsum_un_norm,lc_cad,bad_data,lc_cad_o,version,cmdLine)

	if status== 0:
		make_outfile(instr,outfile,flux_after_nans,bvsum_nans,version)

	# close input file
	if status == 0:
		status = kepio.closefits(instr,logfile,verbose)

		#print some results to screen:
		print '      -----      '
		if iterate:
			flux_fit = n_flux_masked[fittedmask]
			sum_fit = bvsum_masked[fittedmask]
			err_fit = n_err_masked[fittedmask]
		else:
			flux_fit = n_flux_masked
			sum_fit = bvsum_masked
			err_fit = n_err_masked
		print 'reduced chi2: ' + str(chi2_gtf(flux_fit,sum_fit,err_fit,len(flux_fit)-len(coeffs)))
		print 'rms: ' + str(medflux*rms(flux_fit,sum_fit))
		for i in range(len(coeffs)):
			print 'Coefficient of CBV #%s: %s' %(i+1,coeffs[i])
		print '      -----      '


	# end time
	if (status == 0):
		message = 'KEPCOTREND completed at'
	else:
		message = '\nKEPCOTREND aborted at'
	kepmsg.clock(message,logfile,verbose)

	return
示例#26
0
def find_best_d(ftsfile,d_plt_ini,center=None,caltype=None,obj=None):
# ctrxy: array([x,y]) - pixel location of the object center

    ssw=ftsfile['SSWD4']
    slw=ftsfile['SLWC3']
    sswidx=where(ssw.data['wave'] < fmax_slw)
    sswidx=sswidx[0][1:]
    ssw=ssw.data[sswidx]
    slwidx=where(slw.data['wave'] > fmin_ssw)
    slwidx=slwidx[0][0:-1]
    slw=slw.data[slwidx]

    if not is_numlike(center):
        ctrxy=array([128,128])
    else:
        hdr=ftsfile['SLWC3'].header
        dummywcs=make_dummy_wcs(hdr['RA'],hdr['DEC'])
        xyradec=array([center])
        ctrxy_tmp=dummywcs.wcs_world2pix(xyradec,0)
        ctrxy=copy(ctrxy_tmp[0])

    xs=ssw['wave']
    xl=slw['wave']

    if caltype == 'point':
        ssw_cal=ones(len(cps['SSWD4'].data['pointConv']))
        slw_cal=ones(len(cps['SLWC3'].data['pointConv']))

    if caltype == 'extended':
        ssw_cal=cps['SSWD4'].data['pointConv'].copy()
        slw_cal=cps['SLWC3'].data['pointConv'].copy()

    ssw_cal=ssw_cal[sswidx]
    slw_cal=slw_cal[slwidx]

    if obj in ('m83', 'M83', 'm82', 'M82', 'lmc-n159', 'LMC-N159'):
        scal=cps['SSWD4'].data['pointConv'].copy()
        lcal=cps['SLWC3'].data['pointConv'].copy()
        scal=scal[sswidx]
        lcal=lcal[slwidx]
        ssw['flux'][:]=ssw['flux'].copy()*scal
        ssw['error'][:]=ssw['error'].copy()*scal
        slw['flux'][:]=slw['flux'].copy()*lcal
        slw['error'][:]=slw['error'].copy()*lcal
        sim_func=sim_gauss
    else:
        sim_func=sim_planet


    ssw_wnidx=(wn_ssw[0].data*30. < fmax_slw+5.) & \
        (wn_ssw[0].data*30. > fmin_ssw-5.)
    slw_wnidx=(wn_slw[0].data*30. < fmax_slw+5.) & \
        (wn_slw[0].data*30. > fmin_ssw-5.)
    slw_wnidx[-7]=False
    wn_sfit=wn_ssw[0].data[ssw_wnidx]
    wn_lfit=wn_slw[0].data[slw_wnidx]

    ssw_b=beam_ssw[0].data[ssw_wnidx,:,:].copy()
    slw_b=beam_slw[0].data[slw_wnidx,:,:].copy()
    sum_ssw_b=sum(sum(ssw_b,axis=1),axis=1)
    sum_slw_b=sum(sum(slw_b,axis=1),axis=1)

    refidx=1e9
    chiparam=[]
    chierr=[]
    d_plt_out=None
    d_plt=arange(26)*2.
    dplt=d_plt_ini+d_plt-20.
    dplt=dplt[where(dplt >= 0.)]
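    # for each trial diameter: simulate a source model, compute how much of
    # each frequency's beam couples to it, turn that into per-band
    # correction factors, and score the SSW/SLW overlap mismatch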
    for di in range(len(dplt)):
        d_input=dplt[di]
        planet_mod=sim_func(d_input,ctrxy)*img_mask
        planet_mod=planet_mod.copy()/planet_mod.max()
        planet_area=sum(planet_mod)
        sum_ssw_bp=[]
        sum_slw_bp=[]
        for bi in range(len(sum_ssw_b)):
            sum_ssw_bp.append(sum(ssw_b[bi,:,:]*planet_mod))
        for bi in range(len(sum_slw_b)):
            sum_slw_bp.append(sum(slw_b[bi,:,:]*planet_mod))
        sum_ssw_bp=array(sum_ssw_bp)
        sum_slw_bp=array(sum_slw_bp)
        f_sum_sbp=interpolate.interp1d(wn_sfit*30.,planet_area/sum_ssw_bp, \
                                           bounds_error=False, kind=3, \
                                           fill_value=planet_area/sum_ssw_bp[0])
        f_sum_lbp=interpolate.interp1d(wn_lfit*30.,planet_area/sum_slw_bp, \
                                           bounds_error=False, kind=3, \
                                           fill_value=planet_area/sum_slw_bp[0])
        ssw_corrf=ssw['flux']*ssw_cal*f_sum_sbp(xs)
        slw_corrf=slw['flux']*slw_cal*f_sum_lbp(xl)
        param=sum((slw_corrf-ssw_corrf)**2./100./ \
                      2./((slw['error']*slw_cal*f_sum_lbp(xl))**2.+ \
                              (ssw['error']*ssw_cal*f_sum_sbp(xs))**2.))
        err=sqrt(sum((slw['error']*slw_cal*f_sum_lbp(xl))**2.+ \
                    (ssw['error']*ssw_cal*f_sum_sbp(xs))**2.))
        chiparam.append(param)
        chierr.append(err)
        if param < refidx:
            refidx=param

    chiparam=array(chiparam)
    chierr=array(chierr)
    chiidx=(chiparam == chiparam.min())

    return chiparam,chierr,dplt
示例#27
0
    def compute_aesthetics(self, plot):
        """
        Return a dataframe where the columns match the
        aesthetic mappings.

        Transformations like 'factor(cyl)' and other
        expression evaluations are made in here
        """
        data = self.data
        aesthetics = self.layer_mapping(plot.mapping)

        # Override grouping if set in layer.
        with suppress(KeyError):
            aesthetics['group'] = self.geom.aes_params['group']

        env = EvalEnvironment.capture(eval_env=plot.environment)
        env = env.with_outer_namespace({'factor': pd.Categorical})

        # Using `type` preserves the subclass of pd.DataFrame
        evaled = type(data)(index=data.index)

        # If a column name is not in the data, it is evaluated/transformed
        # in the environment of the call to ggplot
        for ae, col in aesthetics.items():
            if isinstance(col, six.string_types):
                if col in data:
                    evaled[ae] = data[col]
                else:
                    try:
                        new_val = env.eval(col, inner_namespace=data)
                    except Exception as e:
                        raise PlotnineError(
                            _TPL_EVAL_FAIL.format(ae, col, str(e)))

                    try:
                        evaled[ae] = new_val
                    except Exception as e:
                        raise PlotnineError(
                            _TPL_BAD_EVAL_TYPE.format(
                                ae, col, str(type(new_val)), str(e)))
            elif pdtypes.is_list_like(col):
                n = len(col)
                if len(data) and n != len(data) and n != 1:
                    raise PlotnineError(
                        "Aesthetics must either be length one, " +
                        "or the same length as the data")
                # An empty dataframe does not admit a scalar value
                elif len(evaled) and n == 1:
                    col = col[0]
                evaled[ae] = col
            elif not cbook.iterable(col) and cbook.is_numlike(col):
                # An empty dataframe does not admit a scalar value
                if not len(evaled):
                    col = [col]
                evaled[ae] = col
            else:
                msg = "Do not know how to deal with aesthetic '{}'"
                raise PlotnineError(msg.format(ae))

        evaled_aes = aes(**dict((col, col) for col in evaled))
        plot.scales.add_defaults(evaled, evaled_aes)

        if len(data) == 0 and len(evaled) > 0:
            # No data, and vectors supplied to aesthetics
            evaled['PANEL'] = 1
        else:
            evaled['PANEL'] = data['PANEL']

        self.data = add_group(evaled)
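
A self-contained sketch of the evaluation idea above (not plotnine's actual machinery): a mapping that names a column is taken directly from the data, while any other string is evaluated as an expression with the data columns and a `factor` alias in the namespace. The `eval_aesthetic` helper is a hypothetical name.

import pandas as pd

data = pd.DataFrame({'cyl': [4, 6, 4, 8]})
namespace = {'factor': pd.Categorical}
namespace.update({c: data[c] for c in data.columns})

def eval_aesthetic(expr):
    # Column names pass through; other strings are evaluated as expressions.
    return data[expr] if expr in data else eval(expr, namespace)

print(eval_aesthetic('cyl').tolist())  # -> [4, 6, 4, 8]
print(eval_aesthetic('factor(cyl)'))   # a pd.Categorical with categories 4, 6, 8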
Example #28
0
    def _calculate(self, data):
        x = data.pop('x')
        right = self.params['right']

        # y values are not needed
        try:
            del data['y']
        except KeyError:
            pass
        else:
            self._print_warning(_MSG_YVALUE)

        if len(x) > 0 and isinstance(x.get(0), datetime.date):
            def convert(d):
                d = datetime.datetime.combine(d, datetime.datetime.min.time())
                return time.mktime(d.timetuple())
            x = x.apply(convert)
        elif len(x) > 0 and isinstance(x.get(0), datetime.datetime):
            x = x.apply(lambda d: time.mktime(d.timetuple()))
        elif len(x) > 0 and isinstance(x.get(0), datetime.time):
            raise GgplotError("Cannot recognise the type of x")

        # If weight not mapped to, use one (no weight)
        try:
            weights = data.pop('weight')
        except KeyError:
            weights = np.ones(len(x))
        else:
            weights = make_iterable_ntimes(weights, len(x))

        if is_categorical(x.values):
            x_assignments = x
            x = self.labels
            width = make_iterable_ntimes(self.params['width'], self.length)
        elif cbook.is_numlike(x.iloc[0]):
            x_assignments = pd.cut(x, bins=self.breaks, labels=False,
                                           right=right)
            width = np.diff(self.breaks)
            x = [self.breaks[i] + width[i] / 2
                 for i in range(len(self.breaks)-1)]
        else:
            raise GgplotError("Cannot recognise the type of x")

        # Create a dataframe with two columns:
        #   - the bins to which each x is assigned
        #   - the weights of each x value
        # Then create a weighted frequency table
        _df = pd.DataFrame({'assignments': x_assignments,
                            'weights': weights
                            })
        _wfreq_table = pd.pivot_table(_df, values='weights',
                                      rows=['assignments'], aggfunc=np.sum)

        # For numerical x values, empty bins have no entry in the
        # computed frequency table. We need to add the zeros, and since
        # the frequency table is a Series object, we need to keep it ordered
        try:
            empty_bins = set(self.labels) - set(x_assignments)
        except AttributeError:  # self.labels exists only for categorical x
            empty_bins = set(range(len(width))) - set(x_assignments)
        _wfreq_table = _wfreq_table.to_dict()
        for _b in empty_bins:
            _wfreq_table[_b] = 0
        _wfreq_table = pd.Series(_wfreq_table).sort_index()

        y = list(_wfreq_table)
        new_data = pd.DataFrame({'x': x, 'y': y, 'width': width})

        # Copy the other aesthetics into the new dataframe
        n = len(x)
        for ae in data:
            new_data[ae] = make_iterable_ntimes(data[ae].iloc[0], n)
        return new_data
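
The pivot_table call above uses the long-removed `rows=` keyword; here is a sketch of the same weighted frequency table with the modern pandas API (groupby plus reindex to restore the empty bins), using made-up data.

import numpy as np
import pandas as pd

x = pd.Series([1.0, 2.5, 2.7, 4.2])
weights = pd.Series([1.0, 1.0, 2.0, 1.0])
breaks = np.array([0.0, 2.0, 4.0, 6.0])

# Assign each x to a bin, sum the weights per bin, and re-insert
# empty bins as zeros so the result stays ordered and complete.
assignments = pd.cut(x, bins=breaks, labels=False, right=True)
wfreq = weights.groupby(assignments).sum()
wfreq = wfreq.reindex(range(len(breaks) - 1), fill_value=0)
print(wfreq.tolist())  # -> [1.0, 3.0, 1.0]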
Example #29
0
def hist(self, x, bins=10, range=None, normed=False, weights=None,
         cumulative=False, bottom=None, histtype='bar', align='mid',
         orientation='vertical', rwidth=None, log=False,
         color=None, label=None,
         **kwargs):
    """
    call signature::

      hist(x, bins=10, range=None, normed=False, cumulative=False,
           bottom=None, histtype='bar', align='mid',
           orientation='vertical', rwidth=None, log=False, **kwargs)

    Compute and draw the histogram of *x*. The return value is a
    tuple (*n*, *bins*, *patches*) or ([*n0*, *n1*, ...], *bins*,
    [*patches0*, *patches1*,...]) if the input contains multiple
    data.

    Multiple data can be provided via *x* as a list of datasets
    of potentially different length ([*x0*, *x1*, ...]), or as
    a 2-D ndarray in which each column is a dataset.  Note that
    the ndarray form is transposed relative to the list form.

    Masked arrays are not supported at present.

    Keyword arguments:

      *bins*:
        Either an integer number of bins or a sequence giving the
        bins.  If *bins* is an integer, *bins* + 1 bin edges
        will be returned, consistent with :func:`numpy.histogram`
        for numpy version >= 1.3, and with the *new* = True argument
        in earlier versions.
        Unequally spaced bins are supported if *bins* is a sequence.

      *range*:
        The lower and upper range of the bins. Lower and upper outliers
        are ignored. If not provided, *range* is (x.min(), x.max()).
        Range has no effect if *bins* is a sequence.

        If *bins* is a sequence or *range* is specified, autoscaling
        is based on the specified bin range instead of the
        range of x.

      *normed*:
        If *True*, the first element of the return tuple will
        be the counts normalized to form a probability density, i.e.,
        ``n/(len(x)*dbin)``.  In a probability density, the integral of
        the histogram should be 1; you can verify that with a
        trapezoidal integration of the probability density function::

          pdf, bins, patches = ax.hist(...)
          print np.sum(pdf * np.diff(bins))

        .. Note:: Until numpy release 1.5, the underlying numpy
                  histogram function was incorrect with *normed*=*True*
                  if bin sizes were unequal.  MPL inherited that
                  error.  It is now corrected within MPL when using
                  earlier numpy versions

      *weights*
        An array of weights, of the same shape as *x*.  Each value in
        *x* only contributes its associated weight towards the bin
        count (instead of 1).  If *normed* is True, the weights are
        normalized, so that the integral of the density over the range
        remains 1.

      *cumulative*:
        If *True*, then a histogram is computed where each bin
        gives the counts in that bin plus all bins for smaller values.
        The last bin gives the total number of datapoints.  If *normed*
        is also *True* then the histogram is normalized such that the
        last bin equals 1. If *cumulative* evaluates to less than 0
        (e.g. -1), the direction of accumulation is reversed.  In this
        case, if *normed* is also *True*, then the histogram is normalized
        such that the first bin equals 1.

      *histtype*: [ 'bar' | 'barstacked' | 'step' | 'stepfilled' ]
        The type of histogram to draw.

          - 'bar' is a traditional bar-type histogram.  If multiple data
            are given the bars are arranged side by side.

          - 'barstacked' is a bar-type histogram where multiple
            data are stacked on top of each other.

          - 'step' generates a lineplot that is by default
            unfilled.

          - 'stepfilled' generates a lineplot that is by default
            filled.

      *align*: ['left' | 'mid' | 'right' ]
        Controls how the histogram is plotted.

          - 'left': bars are centered on the left bin edges.

          - 'mid': bars are centered between the bin edges.

          - 'right': bars are centered on the right bin edges.

      *orientation*: [ 'horizontal' | 'vertical' ]
        If 'horizontal', :func:`~matplotlib.pyplot.barh` will be
        used for bar-type histograms and the *bottom* kwarg will be
        the left edges.

      *rwidth*:
        The relative width of the bars as a fraction of the bin
        width.  If *None*, automatically compute the width. Ignored
        if *histtype* = 'step' or 'stepfilled'.

      *log*:
        If *True*, the histogram axis will be set to a log scale.
        If *log* is *True* and *x* is a 1D array, empty bins will
        be filtered out and only the non-empty (*n*, *bins*,
        *patches*) will be returned.

      *color*:
        Color spec or sequence of color specs, one per
        dataset.  Default (*None*) uses the standard line
        color sequence.

      *label*:
        String, or sequence of strings to match multiple
        datasets.  Bar charts yield multiple patches per
        dataset, but only the first gets the label, so
        that the legend command will work as expected::

            ax.hist(10+2*np.random.randn(1000), label='men')
            ax.hist(12+3*np.random.randn(1000), label='women', alpha=0.5)
            ax.legend()

    kwargs are used to update the properties of the
    :class:`~matplotlib.patches.Patch` instances returned by *hist*:

    %(Patch)s

    **Example:**

    .. plot:: mpl_examples/pylab_examples/histogram_demo.py
    """
    if not self._hold: self.cla()

    # NOTE: the range keyword overwrites the built-in func range !!!
    #       needs to be fixed in numpy                           !!!

    # Validate string inputs here so we don't have to clutter
    # subsequent code.
    if histtype not in ['bar', 'barstacked', 'step', 'stepfilled']:
        raise ValueError("histtype %s is not recognized" % histtype)

    if align not in ['left', 'mid', 'right']:
        raise ValueError("align kwarg %s is not recognized" % align)

    if orientation not in [ 'horizontal', 'vertical']:
        raise ValueError(
            "orientation kwarg %s is not recognized" % orientation)


    if kwargs.get('width') is not None:
        raise DeprecationWarning(
            'hist now uses the rwidth to give relative width '
            'and not absolute width')

    # Massage 'x' for processing.
    # NOTE: Be sure any changes here are also done below to 'weights'
    if isinstance(x, np.ndarray) or not iterable(x[0]):
        # TODO: support masked arrays;
        x = np.asarray(x)
        if x.ndim == 2:
            x = x.T # 2-D input with columns as datasets; switch to rows
        elif x.ndim == 1:
            x = x.reshape(1, x.shape[0])  # new view, single row
        else:
            raise ValueError("x must be 1D or 2D")
        if x.shape[1] < x.shape[0]:
            warnings.warn('2D hist input should be nsamples x nvariables;\n '
                'this looks transposed (shape is %d x %d)' % x.shape[::-1])
    else:
        # multiple hist with data of different length
        x = [np.array(xi) for xi in x]

    nx = len(x) # number of datasets

    if color is None:
        color = [next(self._get_lines.color_cycle)
                                        for i in range(nx)]
    else:
        color = mcolors.colorConverter.to_rgba_array(color)
        if len(color) != nx:
            raise ValueError("color kwarg must have one color per dataset")

    # We need to do to 'weights' what was done to 'x'
    if weights is not None:
        if isinstance(weights, np.ndarray) or not iterable(weights[0]) :
            w = np.array(weights)
            if w.ndim == 2:
                w = w.T
            elif w.ndim == 1:
                w.shape = (1, w.shape[0])
            else:
                raise ValueError("weights must be 1D or 2D")
        else:
            w = [np.array(wi) for wi in weights]

        if len(w) != nx:
            raise ValueError('weights should have the same shape as x')
        for i in range(nx):
            if len(w[i]) != len(x[i]):
                raise ValueError(
                    'weights should have the same shape as x')
    else:
        w = [None]*nx


    # Save autoscale state for later restoration; turn autoscaling
    # off so we can do it all a single time at the end, instead
    # of having it done by bar or fill and then having to be redone.
    _saved_autoscalex = self.get_autoscalex_on()
    _saved_autoscaley = self.get_autoscaley_on()
    self.set_autoscalex_on(False)
    self.set_autoscaley_on(False)

    # Save the datalimits for the same reason:
    _saved_bounds = self.dataLim.bounds

    # Check whether bins or range are given explicitly. In that
    # case use those values for autoscaling.
    binsgiven = (cbook.iterable(bins) or range is not None)

    # If bins are not specified either explicitly or via range,
    # we need to figure out the range required for all datasets,
    # and supply that to np.histogram.
    if not binsgiven:
        xmin = np.inf
        xmax = -np.inf
        for xi in x:
            xmin = min(xmin, xi.min())
            xmax = max(xmax, xi.max())
        range = (xmin, xmax)

    #hist_kwargs = dict(range=range, normed=bool(normed))
    # We will handle the normed kwarg within mpl until we
    # get to the point of requiring numpy >= 1.5.
    hist_kwargs = dict(range=range)
    if np.__version__ < "1.3": # version 1.1 and 1.2
        hist_kwargs['new'] = True

    n = []
    for i in range(nx):
        # this will automatically overwrite bins,
        # so that each histogram uses the same bins
        m, bins = np.histogram(x[i], bins, weights=w[i], **hist_kwargs)
        if normed:
            db = np.diff(bins)
            m = (m.astype(float) / db) / m.sum()
        n.append(m)
    if normed and db.std() > 0.01 * db.mean():
        warnings.warn("""
        This release fixes a normalization bug in the NumPy histogram
        function prior to version 1.5, occurring with non-uniform
        bin widths. The returned and plotted value is now a density:
            n / (N * bin width),
        where n is the bin count and N the total number of points.
        """)



    if cumulative:
        slc = slice(None)
        if cbook.is_numlike(cumulative) and cumulative < 0:
            slc = slice(None,None,-1)

        if normed:
            n = [(m * np.diff(bins))[slc].cumsum()[slc] for m in n]
        else:
            n = [m[slc].cumsum()[slc] for m in n]

    patches = []

    if histtype.startswith('bar'):
        totwidth = np.diff(bins)

        if rwidth is not None:
            dr = min(1.0, max(0.0, rwidth))
        elif len(n)>1:
            dr = 0.8
        else:
            dr = 1.0

        if histtype=='bar':
            width = dr*totwidth/nx
            dw = width

            if nx > 1:
                boffset = -0.5*dr*totwidth*(1.0-1.0/nx)
            else:
                boffset = 0.0
            stacked = False
        elif histtype=='barstacked':
            width = dr*totwidth
            boffset, dw = 0.0, 0.0
            stacked = True

        if align == 'mid' or align == 'edge':
            boffset += 0.5*totwidth
        elif align == 'right':
            boffset += totwidth

        if orientation == 'horizontal':
            _barfunc = self.barh
        else:  # orientation == 'vertical'
            _barfunc = self.bar

        for m, c in zip(n, color):
            patch = _barfunc(bins[:-1]+boffset, m, width, bottom,
                              align='center', log=log,
                              color=c)
            patches.append(patch)
            if stacked:
                if bottom is None:
                    bottom = 0.0
                bottom += m
            boffset += dw

    elif histtype.startswith('step'):
        x = np.zeros( 2*len(bins), float )
        y = np.zeros( 2*len(bins), float )

        x[0::2], x[1::2] = bins, bins

        # FIX FIX FIX
        # This is the only real change.
        # minimum = min(bins)
        if log is True:
            minimum = 1.0
        elif log:
            minimum = float(log)
        else:
            minimum = 0.0
        # FIX FIX FIX end

        if align == 'left' or align == 'center':
            x -= 0.5*(bins[1]-bins[0])
        elif align == 'right':
            x += 0.5*(bins[1]-bins[0])

        if log:
            y[0],y[-1] = minimum, minimum
            if orientation == 'horizontal':
                self.set_xscale('log')
            else:  # orientation == 'vertical'
                self.set_yscale('log')

        fill = (histtype == 'stepfilled')

        for m, c in zip(n, color):
            y[1:-1:2], y[2::2] = m, m
            if log:
                y[y<minimum]=minimum
            if orientation == 'horizontal':
                x,y = y,x

            if fill:
                patches.append( self.fill(x, y,
                    closed=False, facecolor=c) )
            else:
                patches.append( self.fill(x, y,
                    closed=False, edgecolor=c, fill=False) )

        # adopted from adjust_x/ylim part of the bar method
        if orientation == 'horizontal':
            xmin0 = max(_saved_bounds[0]*0.9, minimum)
            xmax = self.dataLim.intervalx[1]
            for m in n:
                xmin = np.amin(m[m!=0]) # filter out the 0 height bins
            xmin = max(xmin*0.9, minimum)
            xmin = min(xmin0, xmin)
            self.dataLim.intervalx = (xmin, xmax)
        elif orientation == 'vertical':
            ymin0 = max(_saved_bounds[1]*0.9, minimum)
            ymax = self.dataLim.intervaly[1]
            for m in n:
                ymin = np.amin(m[m!=0]) # filter out the 0 height bins
            ymin = max(ymin*0.9, minimum)
            ymin = min(ymin0, ymin)
            self.dataLim.intervaly = (ymin, ymax)

    if label is None:
        labels = ['_nolegend_']
    elif is_string_like(label):
        labels = [label]
    elif is_sequence_of_strings(label):
        labels = list(label)
    else:
        raise ValueError(
            'invalid label: must be string or sequence of strings')
    if len(labels) < nx:
        labels += ['_nolegend_'] * (nx - len(labels))

    for (patch, lbl) in zip(patches, labels):
        for p in patch:
            p.update(kwargs)
            p.set_label(lbl)
            lbl = '_nolegend_'

    if binsgiven:
        if orientation == 'vertical':
            self.update_datalim([(bins[0],0), (bins[-1],0)], updatey=False)
        else:
            self.update_datalim([(0,bins[0]), (0,bins[-1])], updatex=False)

    self.set_autoscalex_on(_saved_autoscalex)
    self.set_autoscaley_on(_saved_autoscaley)
    self.autoscale_view()

    if nx == 1:
        return n[0], bins, cbook.silent_list('Patch', patches[0])
    else:
        return n, bins, cbook.silent_list('Lists of Patches', patches)
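
A quick check of the *normed*/*cumulative* semantics described in the docstring. Note this sketch uses the modern *density* keyword, which replaced *normed* in Matplotlib 3.x; everything else follows the behavior documented above.

import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
data = np.random.randn(1000)

# The density integrates to ~1 (the trapezoid check from the docstring).
pdf, bins, patches = ax.hist(data, bins=30, density=True)
print(np.sum(pdf * np.diff(bins)))  # ~1.0

# cumulative=-1 reverses the accumulation direction, so the first
# bin of the normalized cumulative histogram equals 1.
n, bins, patches = ax.hist(data, bins=30, density=True, cumulative=-1)
print(n[0])  # ~1.0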
Example #30
0
def kepcotrendsc(infile, outfile, bvfile, listbv, fitmethod, fitpower, iterate,
                 sigma, maskfile, scinterp, plot, clobber, verbose, logfile,
                 status):
    """
	Setup the kepcotrend environment
	
	infile: 
	the input file in the FITS format obtained from MAST
	
	outfile:
	The output will be a FITS file in the same style as the input file but with two additional columns: CBVSAP_MODL and CBVSAP_FLUX. The first of these is the best-fitting linear combination of basis vectors. The second is the new flux with the basis vector sum subtracted. This is the new flux value.
	
	plot:
	either True or False if you want to see a plot of the light curve.
	The top plot shows the original light curve in blue and the sum of basis vectors in red.
	The bottom plot has had the basis vector sum subtracted.
	
	bvfile:
	the name of the FITS file containing the basis vectors

	listbv:
	the basis vectors to fit to the data
	
	fitmethod:
	fit using either the 'llsq' or the 'simplex' method. 'llsq' is usually the correct one to use because the basis vectors are orthogonal. Simplex gives you the option of using a different merit function - i.e. you can minimise the least absolute residual instead of the least squares, which weights outliers less
	
	fitpower:
	if using a simplex you can choose your own power in the merit function - i.e. the merit function minimises abs(Obs - Mod)^P. P=2 is least squares, P=1 minimises least absolutes
	
	iterate:
	should the program fit the basis vectors to the light curve data, then remove data points further than 'sigma' from the fit, and then refit
	
	maskfile:
	this is the name of a mask file which can be used to define regions of the flux time series to exclude from the fit. The easiest way to create this is by using keprange from the PyKE set of tools. You can also make this yourself with two BJDs on each line in the file specifying the beginning and ending date of the region to exclude.
	
	scinterp:
	the basis vectors are only calculated for long cadence data, therefore if you want to use short cadence data you have to interpolate the basis vectors. There are several methods to do this, the best of these probably being 'nearest', which picks the value of the nearest long cadence data point.
	The options available are None|linear|nearest|zero|slinear|quadratic|cubic.
	If you are using short cadence data don't choose None.
	"""
    # log the call
    hashline = '----------------------------------------------------------------------------'
    kepmsg.log(logfile, hashline, verbose)
    call = 'KEPCOTREND -- '
    call += 'infile=' + infile + ' '
    call += 'outfile=' + outfile + ' '
    call += 'bvfile=' + bvfile + ' '
    #	call += 'numpcomp= '+str(numpcomp)+' '
    call += 'listbv= ' + str(listbv) + ' '
    call += 'fitmethod=' + str(fitmethod) + ' '
    call += 'fitpower=' + str(fitpower) + ' '
    iterateit = 'n'
    if (iterate): iterateit = 'y'
    call += 'iterate=' + iterateit + ' '
    call += 'sigma_clip=' + str(sigma) + ' '
    call += 'mask_file=' + maskfile + ' '
    call += 'scinterp=' + str(scinterp) + ' '
    plotit = 'n'
    if (plot): plotit = 'y'
    call += 'plot=' + plotit + ' '
    overwrite = 'n'
    if (clobber): overwrite = 'y'
    call += 'clobber=' + overwrite + ' '
    chatter = 'n'
    if (verbose): chatter = 'y'
    call += 'verbose=' + chatter + ' '
    call += 'logfile=' + logfile
    kepmsg.log(logfile, call + '\n', verbose)

    # start time
    kepmsg.clock('KEPCOTREND started at', logfile, verbose)

    # test log file
    logfile = kepmsg.test(logfile)

    # clobber output file
    if clobber:
        status = kepio.clobber(outfile, logfile, verbose)
    if kepio.fileexists(outfile):
        message = 'ERROR -- KEPCOTREND: ' + outfile + ' exists. Use --clobber'
        status = kepmsg.err(logfile, message, verbose)

    # open input file
    if status == 0:
        instr, status = kepio.openfits(infile, 'readonly', logfile, verbose)
        tstart, tstop, bjdref, cadence, status = kepio.timekeys(
            instr, infile, logfile, verbose, status)

    # fudge non-compliant FITS keywords with no values
    if status == 0:
        instr = kepkey.emptykeys(instr, infile, logfile, verbose)

    if status == 0:
        if not kepio.fileexists(bvfile):
            message = 'ERROR -- KEPCOTREND: ' + bvfile + ' does not exist.'
            status = kepmsg.err(logfile, message, verbose)

    #lst_sq (nonlinear least squares fitting) and simplex_abs have been removed from the options in PyRAF but they are still in the code!
    if status == 0:
        if fitmethod not in [
                'llsq', 'matrix', 'lst_sq', 'simplex_abs', 'simplex'
        ]:
            message = 'Fit method must be either: llsq, matrix, lst_sq or simplex'
            status = kepmsg.err(logfile, message, verbose)

    if status == 0:
        if not is_numlike(fitpower) and fitpower is not None:
            message = 'Fit power must be a real number or None'
            status = kepmsg.err(logfile, message, verbose)

    if status == 0:
        if fitpower is None:
            fitpower = 1.

    # input data
    if status == 0:
        short = False
        try:
            test = str(instr[0].header['FILEVER'])
            version = 2
        except KeyError:
            version = 1

        table = instr[1].data
        if version == 1:
            if str(instr[1].header['DATATYPE']) == 'long cadence':
                #print 'Light curve was taken in Long Cadence mode!'
                quarter = str(instr[1].header['QUARTER'])
                module = str(instr[1].header['MODULE'])
                output = str(instr[1].header['OUTPUT'])
                channel = str(instr[1].header['CHANNEL'])

                lc_cad_o = table.field('cadence_number')
                lc_date_o = table.field('barytime')
                lc_flux_o = table.field(
                    'ap_raw_flux') / 1625.3468  #convert to e-/s
                lc_err_o = table.field(
                    'ap_raw_err') / 1625.3468  #convert to e-/s
            elif str(instr[1].header['DATATYPE']) == 'short cadence':
                short = True
                #print 'Light curve was taken in Short Cadence mode!'
                quarter = str(instr[1].header['QUARTER'])
                module = str(instr[1].header['MODULE'])
                output = str(instr[1].header['OUTPUT'])
                channel = str(instr[1].header['CHANNEL'])

                lc_cad_o = table.field('cadence_number')
                lc_date_o = table.field('barytime')
                lc_flux_o = table.field(
                    'ap_raw_flux') / 54.178  #convert to e-/s
                lc_err_o = table.field('ap_raw_err') / 54.178  #convert to e-/s

        elif version == 2:
            if str(instr[0].header['OBSMODE']) == 'long cadence':
                #print 'Light curve was taken in Long Cadence mode!'

                quarter = str(instr[0].header['QUARTER'])
                module = str(instr[0].header['MODULE'])
                output = str(instr[0].header['OUTPUT'])
                channel = str(instr[0].header['CHANNEL'])

                lc_cad_o = table.field('CADENCENO')
                lc_date_o = table.field('TIME')
                lc_flux_o = table.field('SAP_FLUX')
                lc_err_o = table.field('SAP_FLUX_ERR')
            elif str(instr[0].header['OBSMODE']) == 'short cadence':
                #print 'Light curve was taken in Short Cadence mode!'
                short = True
                quarter = str(instr[0].header['QUARTER'])
                module = str(instr[0].header['MODULE'])
                output = str(instr[0].header['OUTPUT'])
                channel = str(instr[0].header['CHANNEL'])

                lc_cad_o = table.field('CADENCENO')
                lc_date_o = table.field('TIME')
                lc_flux_o = table.field('SAP_FLUX')
                lc_err_o = table.field('SAP_FLUX_ERR')

        if str(quarter) == str(4) and version == 1:
            lc_cad_o = lc_cad_o[lc_cad_o >= 11914]
            lc_date_o = lc_date_o[lc_cad_o >= 11914]
            lc_flux_o = lc_flux_o[lc_cad_o >= 11914]
            lc_err_o = lc_err_o[lc_cad_o >= 11914]

        # bvfilename = '%s/Q%s_%s_%s_map.txt' %(bvfile,quarter,module,output)
        # if str(quarter) == str(5):
        # 	bvdata = genfromtxt(bvfilename)
        # elif str(quarter) == str(3) or str(quarter) == str(4):
        # 	bvdata = genfromtxt(bvfilename,skip_header=22)
        # elif str(quarter) == str(1):
        # 	bvdata = genfromtxt(bvfilename,skip_header=10)
        # else:
        # 	bvdata = genfromtxt(bvfilename,skip_header=13)

        if short and scinterp == 'None':
            message = 'You cannot select None as the interpolation method because you are using short cadence data and therefore must use some form of interpolation. I recommend nearest if you are unsure.'
            status = kepmsg.err(logfile, message, verbose)

        bvfiledata = pyfits.open(bvfile)
        bvdata = bvfiledata['MODOUT_%s_%s' % (module, output)].data

        if int(bvfiledata[0].header['QUARTER']) != int(quarter):
            message = 'CBV file and light curve file are from different quarters. CBV file is from Q%s and light curve is from Q%s' % (
                int(bvfiledata[0].header['QUARTER']), int(quarter))
            status = kepmsg.err(logfile, message, verbose)

    if status == 0:
        if int(quarter) == 4 and int(module) == 3:
            message = 'Approximately twenty days into Q4, Module 3 failed. As a result, Q4 light curves contain these 20 days of data. However, we do not calculate CBVs for this section of data.'
            status = kepmsg.err(logfile, message, verbose)

    if status == 0:

        #cut out infinites and zero flux columns
        lc_cad, lc_date, lc_flux, lc_err, bad_data = cutBadData(
            lc_cad_o, lc_date_o, lc_flux_o, lc_err_o)

        #get a list of basis vectors to use from the list given
        #accept different separators
        listbv = listbv.strip()
        if listbv[1] in [' ', ',', ':', ';', '|', ', ']:
            separator = str(listbv)[1]
        else:
            message = 'You must separate your basis vector numbers to use with \' \' \',\' \':\' \';\' or \'|\' and the first basis vector to use must be between 1 and 9'
            status = kepmsg.err(logfile, message, verbose)

    if status == 0:
        bvlist = fromstring(listbv, dtype=int, sep=separator)

        if bvlist[0] == 0:
            message = 'Must use at least one basis vector'
            status = kepmsg.err(logfile, message, verbose)
    if status == 0:
        #pcomps = get_pcomp(pcompdata,n_comps,lc_cad)
        # if str(quarter) == str(5):
        # 	bvectors = get_pcomp_list(bvdata,bvlist,lc_cad)
        # else:
        #	bvectors = get_pcomp_list_newformat(bvdata,bvlist,lc_cad)

        if short:
            bvdata.field('CADENCENO')[:] = (((bvdata.field('CADENCENO')[:] +
                                              (7.5 / 15.)) * 30.) -
                                            11540.).round()

        bvectors, in1derror = get_pcomp_list_newformat(bvdata, bvlist, lc_cad,
                                                       short, scinterp)

        if in1derror:
            message = 'It seems that you have an old version of numpy which does not have the in1d function included. Please update your version of numpy to a version 1.4.0 or later'
            status = kepmsg.err(logfile, message, verbose)
    if status == 0:

        medflux = median(lc_flux)
        n_flux = (lc_flux / medflux) - 1
        n_err = sqrt(pow(lc_err, 2) / pow(medflux, 2))

        #plt.errorbar(lc_cad,n_flux,yerr=n_err)
        #plt.errorbar(lc_cad,lc_flux,yerr=lc_err)

        #n_err = median(lc_err/lc_flux) * n_flux
        #print n_err

        #does an iterative least squares fit
        #t1 = do_leastsq(pcomps,lc_cad,n_flux)
        #

        if maskfile != '':
            domasking = True
            if not kepio.fileexists(maskfile):
                message = 'Maskfile %s does not exist' % maskfile
                status = kepmsg.err(logfile, message, verbose)
        else:
            domasking = False

    if status == 0:
        if domasking:

            lc_date_masked = copy(lc_date)
            n_flux_masked = copy(n_flux)
            lc_cad_masked = copy(lc_cad)
            n_err_masked = copy(n_err)
            maskdata = atleast_2d(genfromtxt(maskfile, delimiter=','))
            #make a mask of True values in case there are no regions in maskfile to exclude.
            mask = zeros(len(lc_date_masked)) == 0.
            for maskrange in maskdata:
                if version == 1:
                    start = maskrange[0] - 2400000.0
                    end = maskrange[1] - 2400000.0
                elif version == 2:
                    start = maskrange[0] - 2454833.
                    end = maskrange[1] - 2454833.
                masknew = logical_xor(lc_date < start, lc_date > end)
                mask = logical_and(mask, masknew)

            lc_date_masked = lc_date_masked[mask]
            n_flux_masked = n_flux_masked[mask]
            lc_cad_masked = lc_cad_masked[mask]
            n_err_masked = n_err_masked[mask]
        else:
            lc_date_masked = copy(lc_date)
            n_flux_masked = copy(n_flux)
            lc_cad_masked = copy(lc_cad)
            n_err_masked = copy(n_err)

        #pcomps = get_pcomp(pcompdata,n_comps,lc_cad)

        bvectors_masked, hasin1d = get_pcomp_list_newformat(
            bvdata, bvlist, lc_cad_masked, short, scinterp)

        if (iterate) and sigma is None:
            message = 'If fitting iteratively you must specify a clipping range'
            status = kepmsg.err(logfile, message, verbose)

    if status == 0:
        #uses Pvals = yhat * U_transpose
        if (iterate):
            coeffs, fittedmask = do_lst_iter(bvectors_masked, lc_cad_masked,
                                             n_flux_masked, sigma, 50.,
                                             fitmethod, fitpower)
        else:
            if fitmethod == 'matrix' and domasking:
                coeffs = do_lsq_uhat(bvectors_masked, lc_cad_masked,
                                     n_flux_masked, False)
            if fitmethod == 'llsq' and domasking:
                coeffs = do_lsq_uhat(bvectors_masked, lc_cad_masked,
                                     n_flux_masked, False)
            elif fitmethod == 'lst_sq':
                coeffs = do_lsq_nlin(bvectors_masked, lc_cad_masked,
                                     n_flux_masked)
            elif fitmethod == 'simplex_abs':
                coeffs = do_lsq_fmin(bvectors_masked, lc_cad_masked,
                                     n_flux_masked)
            elif fitmethod == 'simplex':
                coeffs = do_lsq_fmin_pow(bvectors_masked, lc_cad_masked,
                                         n_flux_masked, fitpower)
            else:
                coeffs = do_lsq_uhat(bvectors_masked, lc_cad_masked,
                                     n_flux_masked)

        flux_after = (get_newflux(n_flux, bvectors, coeffs) + 1) * medflux
        flux_after_masked = (
            get_newflux(n_flux_masked, bvectors_masked, coeffs) + 1) * medflux
        bvsum = get_pcompsum(bvectors, coeffs)

        bvsum_masked = get_pcompsum(bvectors_masked, coeffs)

        #print 'chi2: ' + str(chi2_gtf(n_flux,bvsum,n_err,2.*len(n_flux)-2))
        #print 'rms: ' + str(rms(n_flux,bvsum))

        bvsum_nans = putInNans(bad_data, bvsum)
        flux_after_nans = putInNans(bad_data, flux_after)

    if plot and status == 0:
        bvsum_un_norm = medflux * (1 - bvsum)
        #bvsum_un_norm = 0-bvsum
        #lc_flux = n_flux
        do_plot(lc_date, lc_flux, flux_after, bvsum_un_norm, lc_cad, bad_data,
                lc_cad_o, version)

    if status == 0:
        make_outfile(instr, outfile, flux_after_nans, bvsum_nans, version)

    # close input file
    if status == 0:
        status = kepio.closefits(instr, logfile, verbose)

        #print some results to screen:
        print('      -----      ')
        if iterate:
            flux_fit = n_flux_masked[fittedmask]
            sum_fit = bvsum_masked[fittedmask]
            err_fit = n_err_masked[fittedmask]
        else:
            flux_fit = n_flux_masked
            sum_fit = bvsum_masked
            err_fit = n_err_masked
        print('reduced chi2: ' + str(
            chi2_gtf(flux_fit, sum_fit, err_fit,
                     len(flux_fit) - len(coeffs))))
        print('rms: ' + str(medflux * rms(flux_fit, sum_fit)))
        for i in range(len(coeffs)):
            print('Coefficient of CBV #%s: %s' % (i + 1, coeffs[i]))
        print('      -----      ')

    # end time
    if (status == 0):
        message = 'KEPCOTREND completed at'
    else:
        message = '\nKEPCOTREND aborted at'
    kepmsg.clock(message, logfile, verbose)

    return
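
The heart of the 'llsq' path above is an ordinary linear least-squares fit of the basis vectors to the normalized flux. This sketch shows the underlying technique with np.linalg.lstsq on synthetic data; it is not PyKE's actual do_lsq_uhat implementation.

import numpy as np

rng = np.random.default_rng(0)
n_cad, n_bv = 500, 3
B = rng.normal(size=(n_cad, n_bv))        # columns = cotrending basis vectors
true_c = np.array([0.5, -1.2, 0.3])
flux = B @ true_c + 0.01 * rng.normal(size=n_cad)

# Solve min_c ||B c - flux||^2 and subtract the fitted trend
coeffs, *_ = np.linalg.lstsq(B, flux, rcond=None)
detrended = flux - B @ coeffs             # analogue of CBVSAP_FLUX
print(np.round(coeffs, 2))                # ~ [ 0.5 -1.2  0.3]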
Example #31
0
def draw_networkx_edges(G, pos,
                        edgelist=None,
                        width=1.0,
                        edge_color='k',
                        style='solid',
                        alpha=1.0,
                        arrowstyle='-|>',
                        arrowsize=10,
                        edge_cmap=None,
                        edge_vmin=None,
                        edge_vmax=None,
                        ax=None,
                        arrows=True,
                        label=None,
                        node_size=300,
                        nodelist=None,
                        node_shape="o",
                        connectionstyle='arc3',
                        **kwds):
    """Draw the edges of the graph G.

    This draws only the edges of the graph G.

    Parameters
    ----------
    G : graph
       A networkx graph

    pos : dictionary
       A dictionary with nodes as keys and positions as values.
       Positions should be sequences of length 2.

    edgelist : collection of edge tuples
       Draw only specified edges (default=G.edges())

    width : float, or array of floats
       Line width of edges (default=1.0)

    edge_color : color string, or array of floats
       Edge color. Can be a single color format string (default='k'),
       or a sequence of colors with the same length as edgelist.
       If numeric values are specified they will be mapped to
       colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    style : string
       Edge line style (default='solid') (solid|dashed|dotted|dashdot)

    alpha : float
       The edge transparency (default=1.0)

    edge_cmap : Matplotlib colormap
       Colormap for mapping intensities of edges (default=None)

    edge_vmin,edge_vmax : floats
       Minimum and maximum for edge colormap scaling (default=None)

    ax : Matplotlib Axes object, optional
       Draw the graph in the specified Matplotlib axes.

    arrows : bool, optional (default=True)
       For directed graphs, if True draw arrowheads.
       Note: Arrows will be the same color as edges.

    arrowstyle : str, optional (default='-|>')
       For directed graphs, choose the style of the arrow heads.
       See :py:class: `matplotlib.patches.ArrowStyle` for more
       options.

    arrowsize : int, optional (default=10)
       For directed graphs, choose the size of the arrow head's length and
       width. See :py:class: `matplotlib.patches.FancyArrowPatch` for attribute
       `mutation_scale` for more info.

    label : [None| string]
       Label for legend

    Returns
    -------
    matplotlib.collection.LineCollection
        `LineCollection` of the edges

    list of matplotlib.patches.FancyArrowPatch
        `FancyArrowPatch` instances of the directed edges

    Depending on whether the drawing includes arrows or not.

    Notes
    -----
    For directed graphs, arrows are drawn at the head end.  Arrows can be
    turned off with keyword arrows=False. Be sure to include `node_size` as a
    keyword argument; arrows are drawn considering the size of nodes.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))

    >>> G = nx.DiGraph()
    >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)])
    >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
    >>> alphas = [0.3, 0.4, 0.5]
    >>> for i, arc in enumerate(arcs):  # change alpha values of arcs
    ...     arc.set_alpha(alphas[i])

    Also see the NetworkX drawing examples at
    https://networkx.github.io/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw()
    draw_networkx()
    draw_networkx_nodes()
    draw_networkx_labels()
    draw_networkx_edge_labels()
    """
    try:
        import matplotlib
        import matplotlib.pyplot as plt
        import matplotlib.cbook as cb
        from matplotlib.colors import colorConverter, Colormap, Normalize
        from matplotlib.collections import LineCollection
        from matplotlib.patches import FancyArrowPatch, ConnectionStyle
        import numpy as np
    except ImportError:
        raise ImportError("Matplotlib required for draw()")
    except RuntimeError:
        print("Matplotlib unable to open display")
        raise

    if ax is None:
        ax = plt.gca()

    if edgelist is None:
        edgelist = list(G.edges())

    if not edgelist or len(edgelist) == 0:  # no edges!
        return None

    if nodelist is None:
        nodelist = list(G.nodes())

    # set edge positions
    edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])

    if not cb.iterable(width):
        lw = (width,)
    else:
        lw = width

    if not is_string_like(edge_color) \
            and cb.iterable(edge_color) \
            and len(edge_color) == len(edge_pos):
        if np.alltrue([is_string_like(c) for c in edge_color]):
            # (should check ALL elements)
            # list of color letters such as ['k','r','k',...]
            edge_colors = tuple([colorConverter.to_rgba(c, alpha)
                                 for c in edge_color])
        elif np.alltrue([not is_string_like(c) for c in edge_color]):
            # If color specs are given as (rgb) or (rgba) tuples, we're OK
            if np.alltrue([cb.iterable(c) and len(c) in (3, 4)
                           for c in edge_color]):
                edge_colors = tuple(edge_color)
            else:
                # numbers (which are going to be mapped with a colormap)
                edge_colors = None
        else:
            raise ValueError('edge_color must contain color names or numbers')
    else:
        if is_string_like(edge_color) or len(edge_color) == 1:
            edge_colors = (colorConverter.to_rgba(edge_color, alpha), )
        else:
            msg = 'edge_color must be a color or list of one color per edge'
            raise ValueError(msg)

    if (not G.is_directed() or not arrows):
        edge_collection = LineCollection(edge_pos,
                                         colors=edge_colors,
                                         linewidths=lw,
                                         antialiaseds=(1,),
                                         linestyle=style,
                                         transOffset=ax.transData,
                                         )

        edge_collection.set_zorder(1)  # edges go behind nodes
        edge_collection.set_label(label)
        ax.add_collection(edge_collection)

        # Note: there was a bug in mpl regarding the handling of alpha values
        # for each line in a LineCollection. It was fixed in matplotlib by
        # r7184 and r7189 (June 6 2009). We should then not set the alpha
        # value globally, since the user can instead provide per-edge alphas
        # now.  Only set it globally if provided as a scalar.
        if cb.is_numlike(alpha):
            edge_collection.set_alpha(alpha)

        if edge_colors is None:
            if edge_cmap is not None:
                assert(isinstance(edge_cmap, Colormap))
            edge_collection.set_array(np.asarray(edge_color))
            edge_collection.set_cmap(edge_cmap)
            if edge_vmin is not None or edge_vmax is not None:
                edge_collection.set_clim(edge_vmin, edge_vmax)
            else:
                edge_collection.autoscale()
        return edge_collection

    arrow_collection = None

    if G.is_directed() and arrows:
        # Note: Waiting for someone to implement arrow to intersection with
        # marker.  Meanwhile, this works well for polygons with more than 4
        # sides and circle.

        def to_marker_edge(marker_size, marker):
            if marker in "s^>v<d":  # `large` markers need extra space
                return np.sqrt(2 * marker_size) / 2
            else:
                return np.sqrt(marker_size) / 2

        # Draw arrows with `matplotlib.patches.FancyArrowPatch`
        arrow_collection = []
        mutation_scale = arrowsize  # scale factor of arrow head
        arrow_colors = edge_colors
        if arrow_colors is None:
            if edge_cmap is not None:
                assert(isinstance(edge_cmap, Colormap))
            else:
                edge_cmap = plt.get_cmap()  # default matplotlib colormap
            if edge_vmin is None:
                edge_vmin = min(edge_color)
            if edge_vmax is None:
                edge_vmax = max(edge_color)
            color_normal = Normalize(vmin=edge_vmin, vmax=edge_vmax)

        for i, (src, dst) in enumerate(edge_pos):
            x1, y1 = src
            x2, y2 = dst
            arrow_color = None
            line_width = None
            shrink_source = 0  # space from source to tail
            shrink_target = 0  # space from head to target
            if cb.iterable(node_size):  # many node sizes
                src_node, dst_node = edgelist[i]
                index_node = nodelist.index(dst_node)
                marker_size = node_size[index_node]
                shrink_target = to_marker_edge(marker_size, node_shape)
            else:
                shrink_target = to_marker_edge(node_size, node_shape)
            if arrow_colors is None:
                arrow_color = edge_cmap(color_normal(edge_color[i]))
            elif len(arrow_colors) > 1:
                arrow_color = arrow_colors[i]
            else:
                arrow_color = arrow_colors[0]
            if len(lw) > 1:
                line_width = lw[i]
            else:
                line_width = lw[0]
            arrow = FancyArrowPatch((x1, y1), (x2, y2),
                                    arrowstyle=arrowstyle,
                                    shrinkA=shrink_source,
                                    shrinkB=shrink_target,
                                    mutation_scale=mutation_scale,
                                    connectionstyle=connectionstyle,
                                    color=arrow_color,
                                    linewidth=line_width,
                                    zorder=1)  # arrows go behind nodes

            # There seems to be a bug in matplotlib to make collections of
            # FancyArrowPatch instances. Until fixed, the patches are added
            # individually to the axes instance.
            arrow_collection.append(arrow)
            ax.add_patch(arrow)

    # update view
    minx = np.amin(np.ravel(edge_pos[:, :, 0]))
    maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
    miny = np.amin(np.ravel(edge_pos[:, :, 1]))
    maxy = np.amax(np.ravel(edge_pos[:, :, 1]))

    w = maxx - minx
    h = maxy - miny
    padx,  pady = 0.05 * w, 0.05 * h
    corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
    ax.update_datalim(corners)
    ax.autoscale_view()

    return arrow_collection
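
A usage sketch for the numeric edge-color path above, against a NetworkX 2.x-style API: numbers in edge_color are mapped through edge_cmap between edge_vmin and edge_vmax, as described in the docstring.

import matplotlib.pyplot as plt
import networkx as nx

G = nx.path_graph(5)
pos = nx.spring_layout(G, seed=42)
weights = [1.0, 2.0, 3.0, 4.0]            # one number per edge

nx.draw_networkx_nodes(G, pos)
# Numeric colors are mapped through the colormap and clim below.
nx.draw_networkx_edges(G, pos, edge_color=weights,
                       edge_cmap=plt.cm.viridis,
                       edge_vmin=0.0, edge_vmax=4.0, width=2.0)
plt.show()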
Example #32
0
def draw_networkx_edges(G, pos,
                        edgelist=None,
                        width=1.0,
                        edge_color='k',
                        style='solid',
                        alpha=None,
                        edge_cmap=None,
                        edge_vmin=None,
                        edge_vmax=None, 
                        ax=None,
                        arrows=True,
                        **kwds):
    """Draw the edges of the graph G

    This draws only the edges of the graph G.

    pos is a dictionary keyed by vertex with a two-tuple
    of x-y positions as the value.
    See networkx.layout for functions that compute node positions.

    edgelist is an optional list of the edges in G to be drawn.
    If provided, only the edges in edgelist will be drawn. 

    edge_color can be a list of matplotlib color letters such as 'k' or
    'b' that lists the color of each edge; the list must be ordered in
    the same way as the edge list. Alternatively, this list can contain
    numbers and those number are mapped to a color scale using the color
    map edge_cmap.  Finally, it can also be a list of (r,g,b) or (r,g,b,a)
    tuples, in which case these will be used directly to color the edges.  If
    the latter mode is used, you should not provide a value for alpha, as it
    would be applied globally to all lines.
    
    For directed graphs, "arrows" (actually just thicker stubs) are drawn
    at the head end.  Arrows can be turned off with keyword arrows=False.

    See draw_networkx for the list of other optional parameters.

    """
    try:
        import matplotlib
        import matplotlib.pylab as pylab
        import matplotlib.cbook as cb  # needed for cb.iterable/cb.is_numlike below
        import numpy as np
        from matplotlib.colors import colorConverter,Colormap
        from matplotlib.collections import LineCollection
    except ImportError:
        raise ImportError("Matplotlib required for draw()")
    except RuntimeError:
        pass # unable to open display

    if ax is None:
        ax=pylab.gca()

    if edgelist is None:
        edgelist=G.edges()

    if not edgelist or len(edgelist)==0: # no edges!
        return None

    # set edge positions
    edge_pos=np.asarray([(pos[e[0]],pos[e[1]]) for e in edgelist])
    
    if not cb.iterable(width):
        lw = (width,)
    else:
        lw = width

    if not cb.is_string_like(edge_color) \
           and cb.iterable(edge_color) \
           and len(edge_color)==len(edge_pos):
        if np.alltrue([cb.is_string_like(c) 
                         for c in edge_color]):
            # (should check ALL elements)
            # list of color letters such as ['k','r','k',...]
            edge_colors = tuple([colorConverter.to_rgba(c,alpha) 
                                 for c in edge_color])
        elif np.alltrue([not cb.is_string_like(c) 
                           for c in edge_color]):
            # If color specs are given as (rgb) or (rgba) tuples, we're OK
            if np.alltrue([cb.iterable(c) and len(c) in (3,4)
                             for c in edge_color]):
                edge_colors = tuple(edge_color)
                alpha=None
            else:
                # numbers (which are going to be mapped with a colormap)
                edge_colors = None
        else:
            raise ValueError('edge_color must consist of either color names or numbers')
    else:
        if cb.is_string_like(edge_color) or len(edge_color)==1:
            edge_colors = ( colorConverter.to_rgba(edge_color, alpha), )
        else:
            raise ValueError('edge_color must be a single color or list of exactly m colors where m is the number of edges')
    edge_collection = LineCollection(edge_pos,
                                     colors       = edge_colors,
                                     linewidths   = lw,
                                     antialiaseds = (1,),
                                     linestyle    = style,     
                                     transOffset = ax.transData,             
                                     )

    # Note: there was a bug in mpl regarding the handling of alpha values for
    # each line in a LineCollection.  It was fixed in matplotlib in r7184 and
    # r7189 (June 6 2009).  We should then not set the alpha value globally,
    # since the user can instead provide per-edge alphas now.  Only set it
    # globally if provided as a scalar.
    if cb.is_numlike(alpha):
        edge_collection.set_alpha(alpha)

    # need 0.87.7 or greater for edge colormaps.  No checks done, this will
    # just not work with an older mpl
    if edge_colors is None:
        if edge_cmap is not None: assert(isinstance(edge_cmap, Colormap))
        edge_collection.set_array(np.asarray(edge_color))
        edge_collection.set_cmap(edge_cmap)
        if edge_vmin is not None or edge_vmax is not None:
            edge_collection.set_clim(edge_vmin, edge_vmax)
        else:
            edge_collection.autoscale()
        pylab.sci(edge_collection)

    arrow_collection=None

    if G.is_directed() and arrows:

        # a directed graph hack
        # draw thick line segments at head end of edge
        # waiting for someone else to implement arrows that will work 
        arrow_colors = ( colorConverter.to_rgba('k', alpha), )
        a_pos=[]
        p=1.0-0.25 # make head segment 25 percent of edge length
        for src,dst in edge_pos:
            x1,y1=src
            x2,y2=dst
            dx=x2-x1 # x offset
            dy=y2-y1 # y offset
            d=np.sqrt(float(dx**2+dy**2)) # length of edge
            if d==0: # source and target at same position
                continue
            if dx==0: # vertical edge
                xa=x2
                ya=dy*p+y1
            elif dy==0: # horizontal edge
                ya=y2
                xa=dx*p+x1
            else:
                theta=np.arctan2(dy,dx)
                xa=p*d*np.cos(theta)+x1
                ya=p*d*np.sin(theta)+y1
                
            a_pos.append(((xa,ya),(x2,y2)))

        arrow_collection = LineCollection(a_pos,
                                colors       = arrow_colors,
                                linewidths   = [4*ww for ww in lw],
                                antialiaseds = (1,),
                                transOffset = ax.transData,             
                                )
        
    # update view        
    minx = np.amin(np.ravel(edge_pos[:,:,0]))
    maxx = np.amax(np.ravel(edge_pos[:,:,0]))
    miny = np.amin(np.ravel(edge_pos[:,:,1]))
    maxy = np.amax(np.ravel(edge_pos[:,:,1]))

    w = maxx-minx
    h = maxy-miny
    padx, pady = 0.05*w, 0.05*h
    corners = (minx-padx, miny-pady), (maxx+padx, maxy+pady)
    ax.update_datalim( corners)
    ax.autoscale_view()

    edge_collection.set_zorder(1) # edges go behind nodes            
    ax.add_collection(edge_collection)
    if arrow_collection:
        arrow_collection.set_zorder(1) # edges go behind nodes            
        ax.add_collection(arrow_collection)
        
    return edge_collection
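
The trigonometry in the arrow-stub loop above reduces to linear interpolation along the edge: the stub starts at fraction p = 0.75 of the way from source to target, which also makes the vertical/horizontal special cases unnecessary. A minimal sketch with a hypothetical helper name:

def head_segment(src, dst, p=0.75):
    # Start the "arrow" stub at fraction p along the edge; this equals
    # x1 + p*d*cos(theta), y1 + p*d*sin(theta) from the code above.
    (x1, y1), (x2, y2) = src, dst
    return (x1 + p * (x2 - x1), y1 + p * (y2 - y1)), (x2, y2)

print(head_segment((0.0, 0.0), (4.0, 0.0)))  # ((3.0, 0.0), (4.0, 0.0))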
Example #33
0
File: test.py  Project: untra/kgraphs
def draw_networkx_edges(G,
                        pos,
                        edgelist=None,
                        width=1.0,
                        edge_color='k',
                        style='solid',
                        alpha=1.0,
                        edge_cmap=None,
                        edge_vmin=None,
                        edge_vmax=None,
                        ax=None,
                        arrows=True,
                        arrowstyle='thick',
                        label=None,
                        **kwds):
    try:
        import matplotlib
        import matplotlib.pyplot as plt
        import matplotlib.cbook as cb
        import matplotlib.patches as patches
        from matplotlib.colors import colorConverter, Colormap
        from matplotlib.collections import LineCollection
        from matplotlib.path import Path
        import numpy
    except ImportError:
        raise ImportError("Matplotlib required for draw()")
    except RuntimeError:
        print("Matplotlib unable to open display")
        raise
    # print "drawing_edges"

    if ax is None:
        ax = plt.gca()

    if edgelist is None:
        edgelist = G.edges()

    if not edgelist or len(edgelist) == 0:  # no edges!
        return None

    # set edge positions
    edge_pos = numpy.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])
    # for e in edge_pos:
    #   print e
    if not cb.iterable(width):
        lw = (width, )
    else:
        lw = width

    if not cb.is_string_like(edge_color) \
           and cb.iterable(edge_color) \
           and len(edge_color) == len(edge_pos):
        if numpy.alltrue([cb.is_string_like(c) for c in edge_color]):
            # (should check ALL elements)
            # list of color letters such as ['k','r','k',...]
            edge_colors = tuple(
                [colorConverter.to_rgba(c, alpha) for c in edge_color])
        elif numpy.alltrue([not cb.is_string_like(c) for c in edge_color]):
            # If color specs are given as (rgb) or (rgba) tuples, we're OK
            if numpy.alltrue(
                [cb.iterable(c) and len(c) in (3, 4) for c in edge_color]):
                edge_colors = tuple(edge_color)
            else:
                # numbers (which are going to be mapped with a colormap)
                edge_colors = None
        else:
            raise ValueError(
                'edge_color must consist of either color names or numbers')
    else:
        if cb.is_string_like(edge_color) or len(edge_color) == 1:
            edge_colors = (colorConverter.to_rgba(edge_color, alpha), )
        else:
            raise ValueError(
                'edge_color must be a single color or list of exactly m colors where m is the number of edges'
            )

    edge_collection = LineCollection(
        edge_pos,
        colors=edge_colors,
        linewidths=lw,
        antialiaseds=(1, ),
        linestyle=style,
        transOffset=ax.transData,
    )

    # print type(edge_collection)

    edge_collection.set_zorder(1)  # edges go behind nodes
    edge_collection.set_label(label)
    ax.add_collection(edge_collection)

    # Note: there was a bug in mpl regarding the handling of alpha values for
    # each line in a LineCollection.  It was fixed in matplotlib in r7184 and
    # r7189 (June 6 2009).  We should then not set the alpha value globally,
    # since the user can instead provide per-edge alphas now.  Only set it
    # globally if provided as a scalar.
    if cb.is_numlike(alpha):
        edge_collection.set_alpha(alpha)

    if edge_colors is None:
        if edge_cmap is not None:
            assert (isinstance(edge_cmap, Colormap))
        edge_collection.set_array(numpy.asarray(edge_color))
        edge_collection.set_cmap(edge_cmap)
        if edge_vmin is not None or edge_vmax is not None:
            edge_collection.set_clim(edge_vmin, edge_vmax)
        else:
            edge_collection.autoscale()

    arrow_collection = None

    if G.is_directed() and arrows:

        # a directed graph hack-fix
        # draws arrows at each
        # waiting for someone else to implement arrows that will work
        arrow_colors = edge_colors
        a_pos = []
        p = .1  # make arrows 10% of total length
        angle = 2.7  #angle for arrows
        for src, dst in edge_pos:
            x1, y1 = src
            x2, y2 = dst
            dx = x2 - x1  # x offset
            dy = y2 - y1  # y offset
            d = numpy.sqrt(float(dx**2 + dy**2))  # length of edge
            theta = numpy.arctan2(dy, dx)
            if d == 0:  # source and target at same position
                continue
            # pull the endpoint slightly back along the edge so the
            # arrow head is not hidden under the node marker
            x2 -= .04 * numpy.cos(theta)
            y2 -= .04 * numpy.sin(theta)
            # the two barbs of the arrow head, rotated +/- `angle` around
            # the (shifted) target; the polar form covers vertical and
            # horizontal edges too, so no axis-aligned special cases are
            # needed (the old ones left lx1..ly2 undefined for dy == 0)
            lx1 = p * d * numpy.cos(theta + angle) + x2
            lx2 = p * d * numpy.cos(theta - angle) + x2
            ly1 = p * d * numpy.sin(theta + angle) + y2
            ly2 = p * d * numpy.sin(theta - angle) + y2

            a_pos.append(((lx1, ly1), (x2, y2)))
            a_pos.append(((lx2, ly2), (x2, y2)))

        arrow_collection = LineCollection(
            a_pos,
            colors=arrow_colors,
            linewidths=[1 * ww for ww in lw],
            antialiaseds=(1, ),
            transOffset=ax.transData,
        )

        arrow_collection.set_zorder(1)  # edges go behind nodes
        arrow_collection.set_label(label)
        # print type(ax)
        ax.add_collection(arrow_collection)

    # draw self-loops as small octagons of radius 0.1*d, offset to the
    # left of the node; c = 0.1/sqrt(2) is the 45-degree coordinate

    d = 1
    c = 0.0707
    selfedges = []
    verts = [
        (0.1 * d - 0.1 * d, 0.0),    # P0: rightmost point
        (c * d - 0.1 * d, c * d),    # P1
        (0.0 - 0.1 * d, 0.1 * d),    # P2: top
        (-c * d - 0.1 * d, c * d),   # P3
        (-0.1 * d - 0.1 * d, 0.0),   # P4: leftmost
        (-c * d - 0.1 * d, -c * d),  # P5
        (0.0 - 0.1 * d, -0.1 * d),   # P6: bottom
        (c * d - 0.1 * d, -c * d),   # P7
        (0.1 * d - 0.1 * d, 0.0)     # P8: close the loop
    ]
    # print verts

    codes = [
        Path.MOVETO,
        Path.LINETO,
        Path.LINETO,
        Path.LINETO,
        Path.LINETO,
        Path.LINETO,
        Path.LINETO,
        Path.LINETO,
        Path.LINETO,
    ]

    for e in edge_pos:
        if (numpy.array_equal(e[0], e[1])):
            nodes = verts[:]
            for i in range(len(nodes)):
                nodes[i] += e[0]
            # print nodes
            path = Path(nodes, codes)
            patch = patches.PathPatch(path,
                                      color=None,
                                      facecolor=None,
                                      edgecolor=edge_colors[0],
                                      fill=False,
                                      lw=4)
            ax.add_patch(patch)

    # update view
    minx = numpy.amin(numpy.ravel(edge_pos[:, :, 0]))
    maxx = numpy.amax(numpy.ravel(edge_pos[:, :, 0]))
    miny = numpy.amin(numpy.ravel(edge_pos[:, :, 1]))
    maxy = numpy.amax(numpy.ravel(edge_pos[:, :, 1]))

    w = maxx - minx
    h = maxy - miny
    padx, pady = 0.05 * w, 0.05 * h
    corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
    # print ax
    ax.update_datalim(corners)
    ax.autoscale_view()

    #    if arrow_collection:

    return edge_collection
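
The self-loop block above approximates a circle of radius 0.1*d, centered 0.1*d to the left of the node, with an 8-sided polygon (c = 0.0707 is 0.1/sqrt(2), the 45-degree coordinate). A stand-alone sketch of the same construction; self_loop_patch is a hypothetical helper, not part of this project:

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.path import Path
import numpy as np

def self_loop_patch(center, d=1.0, **kwargs):
    # octagonal loop of radius 0.1*d, shifted 0.1*d left of `center`
    r, c = 0.1 * d, 0.1 * d / np.sqrt(2)
    verts = np.array([(r, 0.0), (c, c), (0.0, r), (-c, c),
                      (-r, 0.0), (-c, -c), (0.0, -r), (c, -c), (r, 0.0)])
    verts[:, 0] -= r                  # offset so the loop hangs off the node
    verts += np.asarray(center)
    codes = [Path.MOVETO] + [Path.LINETO] * 8
    return patches.PathPatch(Path(verts, codes), fill=False, **kwargs)

fig, ax = plt.subplots()
ax.add_patch(self_loop_patch((0.5, 0.5), edgecolor='k', lw=2))
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
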
示例#34
0
    def __init__(
        self,
        fig,
        rect,
        nrows_ncols,
        ngrids=None,
        direction="row",
        axes_pad=0.02,
        add_all=True,
        share_all=False,
        aspect=True,
        label_mode="L",
        cbar_mode=None,
        cbar_location="right",
        cbar_pad=None,
        cbar_size="5%",
        cbar_set_cax=True,
        axes_class=None,
    ):
        """
        Build an :class:`ImageGrid` instance with a grid nrows*ncols
        :class:`~matplotlib.axes.Axes` in
        :class:`~matplotlib.figure.Figure` *fig* with
        *rect=[left, bottom, width, height]* (in
        :class:`~matplotlib.figure.Figure` coordinates) or
        the subplot position code (e.g., "121").

        Optional keyword arguments:

          ================  ========  =========================================
          Keyword           Default   Description
          ================  ========  =========================================
          direction         "row"     [ "row" | "column" ]
          axes_pad          0.02      float, pad between axes given in inches
          add_all           True      [ True | False ]
          share_all         False     [ True | False ]
          aspect            True      [ True | False ]
          label_mode        "L"       [ "L" | "1" | "all" ]
          cbar_mode         None      [ "each" | "single" ]
          cbar_location     "right"   [ "right" | "top" ]
          cbar_pad          None
          cbar_size         "5%"
          cbar_set_cax      True      [ True | False ]
          axes_class        None      a type object which must be a subclass
                                      of :class:`~matplotlib.axes.Axes`
          ================  ========  =========================================

        *cbar_set_cax* : if True, each axes in the grid has a cax
          attribute that is bound to its associated cbar_axes.
        """
        self._nrows, self._ncols = nrows_ncols

        if ngrids is None:
            ngrids = self._nrows * self._ncols
        else:
            if (ngrids > self._nrows * self._ncols) or (ngrids <= 0):
                raise Exception("ngrids must be positive and no larger than nrows*ncols")

        self.ngrids = ngrids

        self._axes_pad = axes_pad

        self._colorbar_mode = cbar_mode
        self._colorbar_location = cbar_location
        if cbar_pad is None:
            self._colorbar_pad = axes_pad
        else:
            self._colorbar_pad = cbar_pad

        self._colorbar_size = cbar_size

        self._init_axes_pad(axes_pad)

        if direction not in ["column", "row"]:
            raise Exception("direction must be 'column' or 'row'")

        self._direction = direction

        if axes_class is None:
            axes_class = self._defaultLocatableAxesClass
            axes_class_args = {}
        else:
            if isinstance(axes_class, type):
                # a bare Axes subclass was passed
                axes_class_args = {}
            else:
                # a (class, kwargs_dict) pair was passed
                axes_class, axes_class_args = axes_class

        self.axes_all = []
        self.axes_column = [[] for i in range(self._ncols)]
        self.axes_row = [[] for i in range(self._nrows)]

        self.cbar_axes = []

        h = []
        v = []
        if cbook.is_string_like(rect) or cbook.is_numlike(rect):
            self._divider = SubplotDivider(fig, rect, horizontal=h, vertical=v, aspect=aspect)
        elif len(rect) == 3:
            kw = dict(horizontal=h, vertical=v, aspect=aspect)
            self._divider = SubplotDivider(fig, *rect, **kw)
        elif len(rect) == 4:
            self._divider = Divider(fig, rect, horizontal=h, vertical=v, aspect=aspect)
        else:
            raise Exception("rect must be a subplot position code, a 3-tuple of subplot args, or a 4-tuple [left, bottom, width, height]")

        rect = self._divider.get_position()

        # reference axes
        self._column_refax = [None for i in range(self._ncols)]
        self._row_refax = [None for i in range(self._nrows)]
        self._refax = None

        for i in range(self.ngrids):

            col, row = self._get_col_row(i)

            if share_all:
                sharex = self._refax
                sharey = self._refax
            else:
                sharex = self._column_refax[col]
                sharey = self._row_refax[row]

            ax = axes_class(fig, rect, sharex=sharex, sharey=sharey, **axes_class_args)

            if share_all:
                if self._refax is None:
                    self._refax = ax
            else:
                if sharex is None:
                    self._column_refax[col] = ax
                if sharey is None:
                    self._row_refax[row] = ax

            self.axes_all.append(ax)
            self.axes_column[col].append(ax)
            self.axes_row[row].append(ax)

            cax = self._defaultCbarAxesClass(fig, rect, orientation=self._colorbar_location)
            self.cbar_axes.append(cax)

        self.axes_llc = self.axes_column[0][-1]

        self._update_locators()

        if add_all:
            for ax in self.axes_all + self.cbar_axes:
                fig.add_axes(ax)

        if cbar_set_cax:
            if self._colorbar_mode == "single":
                for ax in self.axes_all:
                    ax.cax = self.cbar_axes[0]
            else:
                for ax, cax in zip(self.axes_all, self.cbar_axes):
                    ax.cax = cax

        self.set_label_mode(label_mode)
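
This constructor matches the ImageGrid class shipped in Matplotlib's mpl_toolkits.axes_grid1; assuming that public packaging, typical usage looks like this sketch:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

fig = plt.figure()
grid = ImageGrid(fig, 111,              # rect as a subplot position code
                 nrows_ncols=(2, 2),    # 2x2 grid of axes
                 axes_pad=0.1,          # pad between axes, in inches
                 share_all=True,
                 cbar_mode="single",    # one colorbar for the whole grid
                 cbar_location="right")

data = np.random.rand(10, 10)
for ax in grid:
    im = ax.imshow(data)
grid.cbar_axes[0].colorbar(im)
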
示例#35
0
    def draw_animation_edges(G,
                             pos,
                             edgelist=None,
                             width=1.0,
                             edge_color='k',
                             style='solid',
                             alpha=1.0,
                             edge_cmap=None,
                             edge_vmin=None,
                             edge_vmax=None,
                             ax=None,
                             arrows=True,
                             label=None,
                             **kwds):
        try:
            import matplotlib
            import matplotlib.pyplot as plt
            import matplotlib.cbook as cb
            from matplotlib.colors import colorConverter, Colormap
            from matplotlib.collections import LineCollection
            import numpy
        except ImportError:
            raise ImportError("Matplotlib required for draw()")
        except RuntimeError:
            print("Matplotlib unable to open display")
            raise

        if ax is None:
            ax = plt.gca()

        if edgelist is None:
            edgelist = list(G.edges())

        if not edgelist or len(edgelist) == 0:  # no edges!
            return None

        # set edge positions

        box_pos = numpy.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])
        p = 0.25
        edge_pos = []
        for edge in edgelist:
            src, dst = numpy.array(pos[edge[0]]), numpy.array(pos[edge[1]])
            s = dst - src
            # src = src + p * s  # Box at beginning
            # dst = src + (1-p) * s   # Box at the end
            dst = src  # No edge at all
            edge_pos.append((src, dst))
        edge_pos = numpy.asarray(edge_pos)

        if not cb.iterable(width):
            lw = (width, )
        else:
            lw = width

        if not cb.is_scalar_or_string(edge_color) \
                and cb.iterable(edge_color) \
                and len(edge_color) == len(edge_pos):
            if numpy.alltrue([cb.is_scalar_or_string(c) for c in edge_color]):
                # (should check ALL elements)
                # list of color letters such as ['k','r','k',...]
                edge_colors = tuple(
                    [colorConverter.to_rgba(c, alpha) for c in edge_color])
            elif numpy.alltrue(
                [not cb.is_scalar_or_string(c) for c in edge_color]):
                # If color specs are given as (rgb) or (rgba) tuples, we're OK
                if numpy.alltrue(
                    [cb.iterable(c) and len(c) in (3, 4) for c in edge_color]):
                    edge_colors = tuple(edge_color)
                else:
                    # numbers (which are going to be mapped with a colormap)
                    edge_colors = None
            else:
                raise ValueError(
                    'edge_color must consist of either color names or numbers')
        else:
            if cb.is_scalar_or_string(edge_color) or len(edge_color) == 1:
                edge_colors = (colorConverter.to_rgba(edge_color, alpha), )
            else:
                raise ValueError(
                    'edge_color must be a single color or list of exactly m colors where m is the number of edges'
                )
        '''
        modEdgeColors = list(edge_colors)
        modEdgeColors = tuple(modEdgeColors + [colorConverter.to_rgba('w', alpha)
                                     for c in edge_color])
        #print(modEdgeColors)
        edge_collection = LineCollection(np.asarray(list(edge_pos)*2),
                                         colors=modEdgeColors,
                                         linewidths=[6]*len(list(edge_colors))+[4]*len(list(edge_colors)),
                                         antialiaseds=(1,),
                                         linestyle=style,
                                         transOffset=ax.transData,
                                         )
        '''
        edge_collection = LineCollection(
            edge_pos,
            colors=edge_colors,
            linewidths=6,
            antialiaseds=(1, ),
            linestyle=style,
            transOffset=ax.transData,
        )

        edge_collection.set_zorder(1)  # edges go behind nodes
        edge_collection.set_label(label)
        ax.add_collection(edge_collection)

        tube_collection = LineCollection(
            edge_pos,
            colors=tuple([
                colorConverter.to_rgba('lightgrey', alpha) for c in edge_color
            ]),
            linewidths=4,
            antialiaseds=(1, ),
            linestyle=style,
            transOffset=ax.transData,
        )

        tube_collection.set_zorder(1)  # edges go behind nodes
        tube_collection.set_label(label)
        ax.add_collection(tube_collection)
        # Note: there was a bug in mpl regarding the handling of alpha values for
        # each line in a LineCollection.  It was fixed in matplotlib in r7184 and
        # r7189 (June 6 2009).  We should then not set the alpha value globally,
        # since the user can instead provide per-edge alphas now.  Only set it
        # globally if provided as a scalar.
        if cb.is_numlike(alpha):
            edge_collection.set_alpha(alpha)

        if edge_colors is None:
            if edge_cmap is not None:
                assert (isinstance(edge_cmap, Colormap))
            edge_collection.set_array(numpy.asarray(edge_color))
            edge_collection.set_cmap(edge_cmap)
            if edge_vmin is not None or edge_vmax is not None:
                edge_collection.set_clim(edge_vmin, edge_vmax)
            else:
                edge_collection.autoscale()

        box_collection = Utilities.get_boxes(edge_colors=edge_colors,
                                             edge_pos=box_pos)
        box_collection.set_zorder(1)  # edges go behind nodes
        box_collection.set_label(label)
        ax.add_collection(box_collection)

        arrow_collection = Utilities.get_arrows_on_edges(
            edge_colors=edge_colors, edge_pos=box_pos)
        arrow_collection.set_zorder(0)

        if arrows:
            # Visualize them only if wanted
            ax.add_collection(arrow_collection)

        return edge_collection, box_collection, tube_collection, arrow_collection
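
The zero-length segments above (dst = src, "No edge at all") only make sense for animation: the caller is expected to grow each edge toward its true endpoint frame by frame. A stand-alone sketch of that idea with a plain LineCollection; the names here are illustrative and not this project's Utilities API:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.animation import FuncAnimation

fig, ax = plt.subplots()
src = np.array([[0.1, 0.1], [0.1, 0.9]])   # one row per edge source
dst = np.array([[0.9, 0.9], [0.9, 0.1]])   # one row per edge target
coll = LineCollection(np.stack([src, src], axis=1), colors='k', linewidths=6)
ax.add_collection(coll)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

def update(frame):
    t = frame / 50.0                        # fraction of each edge revealed
    tips = src + t * (dst - src)
    coll.set_segments(np.stack([src, tips], axis=1))
    return (coll,)

anim = FuncAnimation(fig, update, frames=51, interval=40, blit=True)
plt.show()
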
示例#36
0
    def draw_edges(G,
                   segments,
                   pos=None,
                   edgelist=None,
                   width=1.0,
                   color='k',
                   style='solid',
                   alpha=None,
                   ax=None,
                   **kwds):
        """Draw the edges of the graph G.

        This draws the edge segments given by a segmentation of the links
        of the graph G.

        Parameters
        ----------
        G : graph
           A networkx graph

        segments : L x M array
           The segment fractions for each link; each row must sum to 1,
           i.e. (segments.sum(axis=1) == 1).all()

        pos : dictionary
           A dictionary with nodes as keys and positions as values.
           Positions should be sequences of length 2.
           (default=nx.get_node_attributes(G, 'pos'))

        edgelist : collection of edge tuples
           Draw only specified edges(default=G.edges())

        width : float or array of floats
           Line width of edges (default =1.0)

        color : tuple of color strings
           Edge segment colors. Can be a single color format string (default='k'),
           or a sequence of colors with the same length as segments.shape[1].

        style : string
           Edge line style (default='solid') (solid|dashed|dotted,dashdot)

        alpha : float
           The edge transparency (default=1.0)

        ax : Matplotlib Axes object, optional
           Draw the graph in the specified Matplotlib axes.

        Returns
        -------
        matplotlib.collection.LineCollection
            `LineCollection` of the edge segments

        """
        if not np.allclose(segments.sum(axis=1), 1):
            segments = segments / segments.sum(axis=1, keepdims=True)

        if ax is None:
            ax = plt.gca()

        if pos is None:
            pos = nx.get_node_attributes(G, 'pos')

        if edgelist is None:
            edgelist = G.edges()

        if not edgelist or len(edgelist) == 0:  # no edges!
            return None

        if not cb.iterable(width):
            lw = (width, )
        else:
            lw = width

        if cb.iterable(color) \
               and len(color) == segments.shape[1]:
            if np.alltrue([cb.is_string_like(c) for c in color]):
                # (should check ALL elements)
                # list of color letters such as ['k','r','k',...]
                edge_colors = tuple(
                    [colorConverter.to_rgba(c, alpha) for c in color])
            elif (np.alltrue([not cb.is_string_like(c) for c in color])
                  and np.alltrue(
                      [cb.iterable(c) and len(c) in (3, 4) for c in color])):
                edge_colors = tuple(color)
            else:
                raise ValueError(
                    'color must consist of either color names or numbers')
        else:
            if cb.is_string_like(color) or len(color) == 1:
                edge_colors = (colorConverter.to_rgba(color, alpha), )
            else:
                raise ValueError(
                    'color must be a single color or list of exactly m colors where m is the number of segments'
                )

        assert len(edgelist) == segments.shape[0], \
            "Number of edges and segments have to line up"

        # set edge positions
        edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])

        src = edge_pos[:, 0]
        dest = edge_pos[:, 1]

        positions = src[:, np.newaxis] + np.cumsum(
            np.hstack((np.zeros((len(segments), 1)), segments)),
            axis=1)[:, :, np.newaxis] * (dest - src)[:, np.newaxis]

        linecolls = []
        for s in range(segments.shape[1]):
            coll = LineCollection(positions[:, s:s + 2],
                                  colors=edge_colors[s:s + 1],
                                  linewidths=lw,
                                  antialiaseds=(1, ),
                                  linestyle=style,
                                  transOffset=ax.transData)

            coll.set_zorder(1)  # edges go behind nodes
            # coll.set_label(label)

            if cb.is_numlike(alpha):
                coll.set_alpha(alpha)

            ax.add_collection(coll)
            linecolls.append(coll)

        # update view
        minx = np.amin(np.ravel(edge_pos[:, :, 0]))
        maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
        miny = np.amin(np.ravel(edge_pos[:, :, 1]))
        maxy = np.amax(np.ravel(edge_pos[:, :, 1]))

        w = maxx - minx
        h = maxy - miny
        padx, pady = 0.05 * w, 0.05 * h
        corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
        ax.update_datalim(corners)
        ax.autoscale_view()

        return linecolls
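
The `positions` expression in draw_edges packs a lot into one statement: it turns the cumulative segment fractions of every edge into breakpoint coordinates, so that positions[:, s:s + 2] is exactly the s-th sub-segment of every edge. A small worked example under the same shapes:

import numpy as np

segments = np.array([[0.3, 0.7],
                     [0.3, 0.7]])          # L=2 edges, M=2 segments per edge
src = np.array([[0.0, 0.0], [0.0, 1.0]])
dest = np.array([[1.0, 0.0], [1.0, 1.0]])

# cumulative fractions 0, 0.3, 1.0 along each edge...
fracs = np.cumsum(np.hstack((np.zeros((len(segments), 1)), segments)), axis=1)
# ...turned into breakpoint coordinates by linear interpolation
positions = (src[:, np.newaxis]
             + fracs[:, :, np.newaxis] * (dest - src)[:, np.newaxis])
print(positions[0])   # edge 0 breakpoints: (0, 0), (0.3, 0), (1, 0)
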
示例#37
0
    def _calculate(self, data):
        x = data.pop('x')
        breaks = self.params['breaks']
        right = self.params['right']
        binwidth = self.params['binwidth']

        # y values are not needed
        try:
            del data['y']
        except KeyError:
            pass
        else:
            self._print_warning(_MSG_YVALUE)

        # If weight not mapped to, use one (no weight)
        try:
            weights = data.pop('weight')
        except KeyError:
            weights = np.ones(len(x))
        else:
            weights = make_iterable_ntimes(weights, len(x))

        categorical = is_categorical(x.values)
        if categorical:
            x_assignments = x
            x = sorted(set(x))
            width = make_iterable_ntimes(self.params['width'], len(x))
        elif cbook.is_numlike(x.iloc[0]):
            if breaks is None and binwidth is None:
                _bin_count = 30
                self._print_warning(_MSG_BINWIDTH)
            if binwidth:
                _bin_count = int(np.ceil(np.ptp(x) / binwidth))

            # Breaks have a higher precedence and,
            # pandas accepts either the breaks or the number of bins
            _bins_info = breaks or _bin_count
            x_assignments, breaks = pd.cut(x, bins=_bins_info, labels=False,
                                           right=right, retbins=True)
            width = np.diff(breaks)
            x = [breaks[i] + width[i] / 2
                 for i in range(len(breaks)-1)]
        else:
            raise GgplotError("Cannot recognise the type of x")

        # Create a dataframe with two columns:
        #   - the bins to which each x is assigned
        #   - the weights of each x value
        # Then create a weighted frequency table
        _df = pd.DataFrame({'assignments': x_assignments,
                            'weights': weights
                            })
        _wfreq_table = _df.groupby('assignments')['weights'].sum()

        # For numerical x values, empty bins have no entry in the
        # computed frequency table. We need to add the zeros, and
        # since the frequency table is a Series, keep it ordered.
        if len(_wfreq_table) < len(x):
            empty_bins = set(range(len(x))) - set(x_assignments)
            for _b in empty_bins:
                _wfreq_table[_b] = 0
            _wfreq_table = _wfreq_table.sort_index()

        y = list(_wfreq_table)
        new_data = pd.DataFrame({'x': x, 'y': y, 'width': width})

        # Copy the other aesthetics into the new dataframe
        n = len(x)
        for ae in data:
            new_data[ae] = make_iterable_ntimes(data[ae].iloc[0], n)
        return new_data
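
The binning logic above (pd.cut into bin assignments, a weighted frequency table, then zero-filling the empty bins) condenses to a groupby on modern pandas. A minimal sketch of the same computation on toy data:

import pandas as pd

x = pd.Series([1.0, 1.2, 2.8, 3.1, 3.3])
weights = pd.Series([1, 1, 2, 1, 1])

# assign each x to one of 3 bins, keeping the bin edges
assignments, breaks = pd.cut(x, bins=3, labels=False, right=True, retbins=True)

# weighted frequency per bin; reindex puts the empty middle bin back as 0
wfreq = weights.groupby(assignments).sum().reindex(range(3), fill_value=0)
print(wfreq.tolist())   # [2, 0, 4]
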
示例#38
0
def draw_networkx_edges(G, pos,
                        edgelist=None,
                        width=1.0,
                        edge_color='k',
                        style='solid',
                        alpha=1.0,
                        edge_cmap=None,
                        edge_vmin=None,
                        edge_vmax=None,
                        ax=None,
                        arrows=True,
                        label=None,
                        **kwds):
    """Draw the edges of the graph G.

    This draws only the edges of the graph G.

    Parameters
    ----------
    G : graph
       A networkx graph

    pos : dictionary
       A dictionary with nodes as keys and positions as values.
       Positions should be sequences of length 2.

    edgelist : collection of edge tuples
       Draw only specified edges(default=G.edges())

    width : float, or array of floats
       Line width of edges (default=1.0)

    edge_color : color string, or array of floats
       Edge color. Can be a single color format string (default='k'),
       or a sequence of colors with the same length as edgelist.
       If numeric values are specified they will be mapped to
       colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    style : string
       Edge line style (default='solid') (solid|dashed|dotted,dashdot)

    alpha : float
       The edge transparency (default=1.0)

    edge_cmap : Matplotlib colormap
       Colormap for mapping intensities of edges (default=None)

    edge_vmin,edge_vmax : floats
       Minimum and maximum for edge colormap scaling (default=None)

    ax : Matplotlib Axes object, optional
       Draw the graph in the specified Matplotlib axes.

    arrows : bool, optional (default=True)
       For directed graphs, if True draw arrowheads.

    label : [None | string]
       Label for legend

    Returns
    -------
    matplotlib.collection.LineCollection
        `LineCollection` of the edges

    Notes
    -----
    For directed graphs, "arrows" (actually just thicker stubs) are drawn
    at the head end.  Arrows can be turned off with keyword arrows=False.
    Yes, it is ugly but drawing proper arrows with Matplotlib this
    way is tricky.

    Examples
    --------
    >>> G=nx.dodecahedral_graph()
    >>> edges=nx.draw_networkx_edges(G,pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    http://networkx.github.io/documentation/latest/gallery.html

    See Also
    --------
    draw()
    draw_networkx()
    draw_networkx_nodes()
    draw_networkx_labels()
    draw_networkx_edge_labels()
    """
    try:
        import matplotlib
        import matplotlib.pyplot as plt
        import matplotlib.cbook as cb
        from matplotlib.colors import colorConverter, Colormap
        from matplotlib.collections import LineCollection
        import numpy
    except ImportError:
        raise ImportError("Matplotlib required for draw()")
    except RuntimeError:
        print("Matplotlib unable to open display")
        raise

    if ax is None:
        ax = plt.gca()

    if edgelist is None:
        edgelist = list(G.edges())

    if not edgelist or len(edgelist) == 0:  # no edges!
        return None

    # set edge positions
    edge_pos = numpy.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])

    if not cb.iterable(width):
        lw = (width,)
    else:
        lw = width

    if not cb.is_string_like(edge_color) \
           and cb.iterable(edge_color) \
           and len(edge_color) == len(edge_pos):
        if numpy.alltrue([cb.is_string_like(c)
                         for c in edge_color]):
            # (should check ALL elements)
            # list of color letters such as ['k','r','k',...]
            edge_colors = tuple([colorConverter.to_rgba(c, alpha)
                                 for c in edge_color])
        elif numpy.alltrue([not cb.is_string_like(c)
                           for c in edge_color]):
            # If color specs are given as (rgb) or (rgba) tuples, we're OK
            if numpy.alltrue([cb.iterable(c) and len(c) in (3, 4)
                             for c in edge_color]):
                edge_colors = tuple(edge_color)
            else:
                # numbers (which are going to be mapped with a colormap)
                edge_colors = None
        else:
            raise ValueError('edge_color must consist of either color names or numbers')
    else:
        if cb.is_string_like(edge_color) or len(edge_color) == 1:
            edge_colors = (colorConverter.to_rgba(edge_color, alpha), )
        else:
            raise ValueError('edge_color must be a single color or list of exactly m colors where m is the number of edges')

    edge_collection = LineCollection(edge_pos,
                                     colors=edge_colors,
                                     linewidths=lw,
                                     antialiaseds=(1,),
                                     linestyle=style,
                                     transOffset = ax.transData,
                                     )

    edge_collection.set_zorder(1)  # edges go behind nodes
    edge_collection.set_label(label)
    ax.add_collection(edge_collection)

    # Note: there was a bug in mpl regarding the handling of alpha values for
    # each line in a LineCollection.  It was fixed in matplotlib in r7184 and
    # r7189 (June 6 2009).  We should then not set the alpha value globally,
    # since the user can instead provide per-edge alphas now.  Only set it
    # globally if provided as a scalar.
    if cb.is_numlike(alpha):
        edge_collection.set_alpha(alpha)

    if edge_colors is None:
        if edge_cmap is not None:
            assert(isinstance(edge_cmap, Colormap))
        edge_collection.set_array(numpy.asarray(edge_color))
        edge_collection.set_cmap(edge_cmap)
        if edge_vmin is not None or edge_vmax is not None:
            edge_collection.set_clim(edge_vmin, edge_vmax)
        else:
            edge_collection.autoscale()

    arrow_collection = None

    if G.is_directed() and arrows:

        # a directed graph hack
        # draw thick line segments at head end of edge
        # waiting for someone else to implement arrows that will work
        arrow_colors = edge_colors
        a_pos = []
        p = 1.0-0.25  # make head segment 25 percent of edge length
        for src, dst in edge_pos:
            x1, y1 = src
            x2, y2 = dst
            dx = x2-x1   # x offset
            dy = y2-y1   # y offset
            d = numpy.sqrt(float(dx**2 + dy**2))  # length of edge
            if d == 0:   # source and target at same position
                continue
            if dx == 0:  # vertical edge
                xa = x2
                ya = dy*p+y1
            elif dy == 0:  # horizontal edge
                ya = y2
                xa = dx*p+x1
            else:
                theta = numpy.arctan2(dy, dx)
                xa = p*d*numpy.cos(theta)+x1
                ya = p*d*numpy.sin(theta)+y1

            a_pos.append(((xa, ya), (x2, y2)))

        arrow_collection = LineCollection(a_pos,
                                colors=arrow_colors,
                                linewidths=[4*ww for ww in lw],
                                antialiaseds=(1,),
                                transOffset = ax.transData,
                                )

        arrow_collection.set_zorder(1)  # edges go behind nodes
        arrow_collection.set_label(label)
        ax.add_collection(arrow_collection)

    # update view
    minx = numpy.amin(numpy.ravel(edge_pos[:, :, 0]))
    maxx = numpy.amax(numpy.ravel(edge_pos[:, :, 0]))
    miny = numpy.amin(numpy.ravel(edge_pos[:, :, 1]))
    maxy = numpy.amax(numpy.ravel(edge_pos[:, :, 1]))

    w = maxx-minx
    h = maxy-miny
    padx,  pady = 0.05*w, 0.05*h
    corners = (minx-padx, miny-pady), (maxx+padx, maxy+pady)
    ax.update_datalim(corners)
    ax.autoscale_view()

#    if arrow_collection:

    return edge_collection
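
The numeric branch above (edge_colors is None, followed by set_array/set_cmap/set_clim) is what makes colormap-scaled edges work. A hedged usage sketch, assuming a networkx version matching this snippet:

import matplotlib.pyplot as plt
import networkx as nx

G = nx.path_graph(5)                     # 4 edges
pos = nx.spring_layout(G)
weights = [0.1, 0.4, 0.7, 1.0]           # one number per edge

# numbers rather than color names trigger the colormap code path
edges = nx.draw_networkx_edges(G, pos, edge_color=weights,
                               edge_cmap=plt.cm.viridis,
                               edge_vmin=0.0, edge_vmax=1.0, width=2.0)
plt.colorbar(edges)
plt.show()
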
示例#39
0
def draw_networkx_edges(G, pos,
                        edgelist=None,
                        width=1.0,
                        edge_color='k',
                        style='solid',
                        alpha=1.0,
                        arrowstyle='-|>',
                        arrowsize=10,
                        edge_cmap=None,
                        edge_vmin=None,
                        edge_vmax=None,
                        ax=None,
                        arrows=True,
                        label=None,
                        node_size=300,
                        nodelist=None,
                        node_shape="o",
                        **kwds):
    """Draw the edges of the graph G.

    This draws only the edges of the graph G.

    Parameters
    ----------
    G : graph
       A networkx graph

    pos : dictionary
       A dictionary with nodes as keys and positions as values.
       Positions should be sequences of length 2.

    edgelist : collection of edge tuples
       Draw only specified edges(default=G.edges())

    width : float, or array of floats
       Line width of edges (default=1.0)

    edge_color : color string, or array of floats
       Edge color. Can be a single color format string (default='k'),
       or a sequence of colors with the same length as edgelist.
       If numeric values are specified they will be mapped to
       colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    style : string
       Edge line style (default='solid') (solid|dashed|dotted,dashdot)

    alpha : float
       The edge transparency (default=1.0)

    edge_cmap : Matplotlib colormap
       Colormap for mapping intensities of edges (default=None)

    edge_vmin,edge_vmax : floats
       Minimum and maximum for edge colormap scaling (default=None)

    ax : Matplotlib Axes object, optional
       Draw the graph in the specified Matplotlib axes.

    arrows : bool, optional (default=True)
       For directed graphs, if True draw arrowheads.
       Note: Arrows will be the same color as edges.

    arrowstyle : str, optional (default='-|>')
       For directed graphs, choose the style of the arrow heads.
       See :py:class: `matplotlib.patches.ArrowStyle` for more
       options.

    arrowsize : int, optional (default=10)
       For directed graphs, choose the size of the arrow head's length and
       width. See :py:class: `matplotlib.patches.FancyArrowPatch` for the
       attribute `mutation_scale` for more info.

    label : [None | string]
       Label for legend

    Returns
    -------
    matplotlib.collection.LineCollection
        `LineCollection` of the edges

    list of matplotlib.patches.FancyArrowPatch
        `FancyArrowPatch` instances of the directed edges

    Depending on whether the drawing includes arrows or not.

    Notes
    -----
    For directed graphs, arrows are drawn at the head end.  Arrows can be
    turned off with keyword arrows=False. Be sure to include `node_size` as a
    keyword argument; arrows are drawn considering the size of nodes.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))

    >>> G = nx.DiGraph()
    >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)])
    >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
    >>> alphas = [0.3, 0.4, 0.5]
    >>> for i, arc in enumerate(arcs):  # change alpha values of arcs
    ...     arc.set_alpha(alphas[i])

    Also see the NetworkX drawing examples at
    https://networkx.github.io/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw()
    draw_networkx()
    draw_networkx_nodes()
    draw_networkx_labels()
    draw_networkx_edge_labels()
    """
    try:
        import matplotlib
        import matplotlib.pyplot as plt
        import matplotlib.cbook as cb
        from matplotlib.colors import colorConverter, Colormap, Normalize
        from matplotlib.collections import LineCollection
        from matplotlib.patches import FancyArrowPatch
        import numpy as np
    except ImportError:
        raise ImportError("Matplotlib required for draw()")
    except RuntimeError:
        print("Matplotlib unable to open display")
        raise

    if ax is None:
        ax = plt.gca()

    if edgelist is None:
        edgelist = list(G.edges())

    if not edgelist or len(edgelist) == 0:  # no edges!
        return None

    if nodelist is None:
        nodelist = list(G.nodes())

    # set edge positions
    edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])

    if not cb.iterable(width):
        lw = (width,)
    else:
        lw = width

    if not is_string_like(edge_color) \
            and cb.iterable(edge_color) \
            and len(edge_color) == len(edge_pos):
        if np.alltrue([is_string_like(c) for c in edge_color]):
            # (should check ALL elements)
            # list of color letters such as ['k','r','k',...]
            edge_colors = tuple([colorConverter.to_rgba(c, alpha)
                                 for c in edge_color])
        elif np.alltrue([not is_string_like(c) for c in edge_color]):
            # If color specs are given as (rgb) or (rgba) tuples, we're OK
            if np.alltrue([cb.iterable(c) and len(c) in (3, 4)
                          for c in edge_color]):
                edge_colors = tuple(edge_color)
            else:
                # numbers (which are going to be mapped with a colormap)
                edge_colors = None
        else:
            raise ValueError('edge_color must contain color names or numbers')
    else:
        if is_string_like(edge_color) or len(edge_color) == 1:
            edge_colors = (colorConverter.to_rgba(edge_color, alpha), )
        else:
            msg = 'edge_color must be a color or list of one color per edge'
            raise ValueError(msg)

    if (not G.is_directed() or not arrows):
        edge_collection = LineCollection(edge_pos,
                                         colors=edge_colors,
                                         linewidths=lw,
                                         antialiaseds=(1,),
                                         linestyle=style,
                                         transOffset=ax.transData,
                                         )

        edge_collection.set_zorder(1)  # edges go behind nodes
        edge_collection.set_label(label)
        ax.add_collection(edge_collection)

        # Note: there was a bug in mpl regarding the handling of alpha values
        # for each line in a LineCollection. It was fixed in matplotlib by
        # r7184 and r7189 (June 6 2009). We should then not set the alpha
        # value globally, since the user can instead provide per-edge alphas
        # now.  Only set it globally if provided as a scalar.
        if cb.is_numlike(alpha):
            edge_collection.set_alpha(alpha)

        if edge_colors is None:
            if edge_cmap is not None:
                assert(isinstance(edge_cmap, Colormap))
            edge_collection.set_array(np.asarray(edge_color))
            edge_collection.set_cmap(edge_cmap)
            if edge_vmin is not None or edge_vmax is not None:
                edge_collection.set_clim(edge_vmin, edge_vmax)
            else:
                edge_collection.autoscale()
        return edge_collection

    arrow_collection = None

    if G.is_directed() and arrows:
        # Note: Waiting for someone to implement arrow to intersection with
        # marker.  Meanwhile, this works well for polygons with more than 4
        # sides and circle.

        def to_marker_edge(marker_size, marker):
            if marker in "s^>v<d":  # `large` markers need extra space
                return np.sqrt(2 * marker_size) / 2
            else:
                return np.sqrt(marker_size) / 2

        # Draw arrows with `matplotlib.patches.FancyArrowPatch`
        arrow_collection = []
        mutation_scale = arrowsize  # scale factor of arrow head
        arrow_colors = edge_colors
        if arrow_colors is None:
            if edge_cmap is not None:
                assert(isinstance(edge_cmap, Colormap))
            else:
                edge_cmap = plt.get_cmap()  # default matplotlib colormap
            if edge_vmin is None:
                edge_vmin = min(edge_color)
            if edge_vmax is None:
                edge_vmax = max(edge_color)
            color_normal = Normalize(vmin=edge_vmin, vmax=edge_vmax)

        for i, (src, dst) in enumerate(edge_pos):
            x1, y1 = src
            x2, y2 = dst
            arrow_color = None
            line_width = None
            shrink_source = 0  # space from source to tail
            shrink_target = 0  # space from  head to target
            if cb.iterable(node_size):  # many node sizes
                src_node, dst_node = edgelist[i]
                index_node = nodelist.index(dst_node)
                marker_size = node_size[index_node]
                shrink_target = to_marker_edge(marker_size, node_shape)
            else:
                shrink_target = to_marker_edge(node_size, node_shape)
            if arrow_colors is None:
                arrow_color = edge_cmap(color_normal(edge_color[i]))
            elif len(arrow_colors) > 1:
                arrow_color = arrow_colors[i]
            else:
                arrow_color = arrow_colors[0]
            if len(lw) > 1:
                line_width = lw[i]
            else:
                line_width = lw[0]
            arrow = FancyArrowPatch((x1, y1), (x2, y2),
                                    arrowstyle=arrowstyle,
                                    shrinkA=shrink_source,
                                    shrinkB=shrink_target,
                                    mutation_scale=mutation_scale,
                                    color=arrow_color,
                                    linewidth=line_width,
                                    zorder=1)  # arrows go behind nodes

            # There seems to be a bug in matplotlib to make collections of
            # FancyArrowPatch instances. Until fixed, the patches are added
            # individually to the axes instance.
            arrow_collection.append(arrow)
            ax.add_patch(arrow)

    # update view
    minx = np.amin(np.ravel(edge_pos[:, :, 0]))
    maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
    miny = np.amin(np.ravel(edge_pos[:, :, 1]))
    maxy = np.amax(np.ravel(edge_pos[:, :, 1]))

    w = maxx - minx
    h = maxy - miny
    padx,  pady = 0.05 * w, 0.05 * h
    corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
    ax.update_datalim(corners)
    ax.autoscale_view()

    return arrow_collection
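
The shrink logic matters because node_size is a marker *area* in points squared, so sqrt(node_size)/2 is roughly the marker radius in points; shrinking the arrow by that amount keeps the head visible at the node boundary. The helper's arithmetic, checked in isolation:

import numpy as np

def to_marker_edge(marker_size, marker):
    if marker in "s^>v<d":          # wide markers need extra clearance
        return np.sqrt(2 * marker_size) / 2
    else:
        return np.sqrt(marker_size) / 2

print(to_marker_edge(300, "o"))     # default node_size=300 -> ~8.66 points
print(to_marker_edge(300, "s"))     # square marker         -> ~12.25 points
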
示例#40
0
def draw_networkx_edges(G, pos,
                        edgelist=None,
                        width=1.0,
                        edge_color='k',
                        style='solid',
                        alpha=1.0,
                        edge_cmap=None,
                        edge_vmin=None,
                        edge_vmax=None,
                        ax=None,
                        arrows=True,
                        label=None,
                        **kwds):
    """Draw the edges of the graph G.

    This draws only the edges of the graph G.

    Parameters
    ----------
    G : graph
       A networkx graph

    pos : dictionary
       A dictionary with nodes as keys and positions as values.
       Positions should be sequences of length 2.

    edgelist : collection of edge tuples
       Draw only specified edges(default=G.edges())

    width : float, or array of floats
       Line width of edges (default=1.0)

    edge_color : color string, or array of floats
       Edge color. Can be a single color format string (default='k'),
       or a sequence of colors with the same length as edgelist.
       If numeric values are specified they will be mapped to
       colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    style : string
       Edge line style (default='solid') (solid|dashed|dotted,dashdot)

    alpha : float
       The edge transparency (default=1.0)

    edge_cmap : Matplotlib colormap
       Colormap for mapping intensities of edges (default=None)

    edge_vmin,edge_vmax : floats
       Minimum and maximum for edge colormap scaling (default=None)

    ax : Matplotlib Axes object, optional
       Draw the graph in the specified Matplotlib axes.

    arrows : bool, optional (default=True)
       For directed graphs, if True draw arrowheads.

    label : [None | string]
       Label for legend

    Returns
    -------
    matplotlib.collection.LineCollection
        `LineCollection` of the edges

    Notes
    -----
    For directed graphs, "arrows" (actually just thicker stubs) are drawn
    at the head end.  Arrows can be turned off with keyword arrows=False.
    Yes, it is ugly but drawing proper arrows with Matplotlib this
    way is tricky.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.github.io/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw()
    draw_networkx()
    draw_networkx_nodes()
    draw_networkx_labels()
    draw_networkx_edge_labels()
    """
    try:
        import matplotlib
        import matplotlib.pyplot as plt
        import matplotlib.cbook as cb
        from matplotlib.colors import colorConverter, Colormap
        from matplotlib.collections import LineCollection
        import numpy as np
    except ImportError:
        raise ImportError("Matplotlib required for draw()")
    except RuntimeError:
        print("Matplotlib unable to open display")
        raise

    if ax is None:
        ax = plt.gca()

    if edgelist is None:
        edgelist = list(G.edges())

    if not edgelist or len(edgelist) == 0:  # no edges!
        return None

    # set edge positions
    edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])

    if not cb.iterable(width):
        lw = (width,)
    else:
        lw = width

    if not cb.is_string_like(edge_color) \
            and cb.iterable(edge_color) \
            and len(edge_color) == len(edge_pos):
        if np.alltrue([cb.is_string_like(c)
                      for c in edge_color]):
            # (should check ALL elements)
            # list of color letters such as ['k','r','k',...]
            edge_colors = tuple([colorConverter.to_rgba(c, alpha)
                                 for c in edge_color])
        elif np.alltrue([not cb.is_string_like(c)
                        for c in edge_color]):
            # If color specs are given as (rgb) or (rgba) tuples, we're OK
            if np.alltrue([cb.iterable(c) and len(c) in (3, 4)
                          for c in edge_color]):
                edge_colors = tuple(edge_color)
            else:
                # numbers (which are going to be mapped with a colormap)
                edge_colors = None
        else:
            raise ValueError('edge_color must consist of either color names or numbers')
    else:
        if cb.is_string_like(edge_color) or len(edge_color) == 1:
            edge_colors = (colorConverter.to_rgba(edge_color, alpha), )
        else:
            raise ValueError(
                'edge_color must be a single color or list of exactly m colors where m is the number of edges')

    edge_collection = LineCollection(edge_pos,
                                     colors=edge_colors,
                                     linewidths=lw,
                                     antialiaseds=(1,),
                                     linestyle=style,
                                     transOffset=ax.transData,
                                     )

    edge_collection.set_zorder(1)  # edges go behind nodes
    edge_collection.set_label(label)
    ax.add_collection(edge_collection)

    # Note: there was a bug in mpl regarding the handling of alpha values for
    # each line in a LineCollection.  It was fixed in matplotlib in r7184 and
    # r7189 (June 6 2009).  We should then not set the alpha value globally,
    # since the user can instead provide per-edge alphas now.  Only set it
    # globally if provided as a scalar.
    if cb.is_numlike(alpha):
        edge_collection.set_alpha(alpha)

    if edge_colors is None:
        if edge_cmap is not None:
            assert(isinstance(edge_cmap, Colormap))
        edge_collection.set_array(np.asarray(edge_color))
        edge_collection.set_cmap(edge_cmap)
        if edge_vmin is not None or edge_vmax is not None:
            edge_collection.set_clim(edge_vmin, edge_vmax)
        else:
            edge_collection.autoscale()

    arrow_collection = None

    if G.is_directed() and arrows:

        # a directed graph hack
        # draw thick line segments at head end of edge
        # waiting for someone else to implement arrows that will work
        arrow_colors = edge_colors
        a_pos = []
        p = 1.0 - 0.25  # make head segment 25 percent of edge length
        for src, dst in edge_pos:
            x1, y1 = src
            x2, y2 = dst
            dx = x2 - x1   # x offset
            dy = y2 - y1   # y offset
            d = np.sqrt(float(dx**2 + dy**2))  # length of edge
            if d == 0:   # source and target at same position
                continue
            if dx == 0:  # vertical edge
                xa = x2
                ya = dy * p + y1
            elif dy == 0:  # horizontal edge
                ya = y2
                xa = dx * p + x1
            else:
                theta = np.arctan2(dy, dx)
                xa = p * d * np.cos(theta) + x1
                ya = p * d * np.sin(theta) + y1

            a_pos.append(((xa, ya), (x2, y2)))

        arrow_collection = LineCollection(a_pos,
                                          colors=arrow_colors,
                                          linewidths=[4 * ww for ww in lw],
                                          antialiaseds=(1,),
                                          transOffset=ax.transData,
                                          )

        arrow_collection.set_zorder(1)  # edges go behind nodes
        arrow_collection.set_label(label)
        ax.add_collection(arrow_collection)

    # update view
    minx = np.amin(np.ravel(edge_pos[:, :, 0]))
    maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
    miny = np.amin(np.ravel(edge_pos[:, :, 1]))
    maxy = np.amax(np.ravel(edge_pos[:, :, 1]))

    w = maxx - minx
    h = maxy - miny
    padx,  pady = 0.05 * w, 0.05 * h
    corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
    ax.update_datalim(corners)
    ax.autoscale_view()

#    if arrow_collection:

    return edge_collection
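
The note about matplotlib r7184/r7189 is why alpha is applied globally only when it is a scalar: per-edge transparency is expressed by passing RGBA tuples directly, which land in the tuple-of-colors branch above. A sketch, assuming a networkx version matching this snippet:

import matplotlib.pyplot as plt
import networkx as nx

G = nx.cycle_graph(4)                    # 4 edges
pos = nx.circular_layout(G)

# one (r, g, b, a) tuple per edge: alpha varies per edge, so the
# global set_alpha() path in the function above is never taken
edge_rgba = [(0, 0, 0, a) for a in (0.2, 0.5, 0.8, 1.0)]
nx.draw_networkx_edges(G, pos, edge_color=edge_rgba, width=2.0)
plt.show()
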
示例#41
0
def run():

    # Which timing of plot: 'weatherband[1-3]', 'seasonal', 'interannual-winter', 'interannual-summer'
    # 'interannual-01' through 'interannual-12', 'monthly-2004' through 'monthly-2014'
    # 'interannual-mean' 'monthly-mean'
    whichtime = 'seasonal'
    # Which type of plot: 'cross', 'coastCH', 'coastMX', 'coastLA', 
    #  'coastNTX', 'coastSTX', 'fsle', 'D2'
    whichtype = 'cross'
    # 'forward' or 'back' in time
    whichdir = 'forward'

    # If doing D2, also choose the initial separation distance, with a little
    # extra slack for roundoff and projection error. Now given as an interval.
    r = [4.75, 5.25] # [0.95, 1.05] # [4.75, 5.25] # 5.25 and 1.05

    #levels = np.linspace(0,1,11)

    shelf_depth = 100  # run for 100, 50, and 20
    ishelf_depth = 2  # index into the cross array: 2, 1, or 0
    numdays = 15  # or 30. Number of analysis days to consider

    # Whether to overlay previously-calculated wind stress arrows
    # from projects/txla_plots/plot_mean_wind.py on Rainier
    addwind = 0  # for adding the wind on
    years = np.arange(2004,2015) # this is just for the wind I think

    # Number of bins to use in histogram
    bins = (100,100) #(30,30)

    # Load in Files to read from based on which type of plot is being run
    Files, cmap = init(whichtime, whichtype, whichdir)

    # Grid info
    loc = 'http://barataria.tamu.edu:8080/thredds/dodsC/NcML/txla_nesting6.nc'
    # loc = '/Users/kthyng/Documents/research/postdoc/grid.nc'
    grid = tracpy.inout.readgrid(loc, usebasemap=True)
    # grid_filename = '/atch/raid1/zhangxq/Projects/txla_nesting6/txla_grd_v4_new.nc'
    # vert_filename='/atch/raid1/zhangxq/Projects/txla_nesting6/ocean_his_0001.nc'
    # grid = tracpy.inout.readgrid(grid_filename, vert_filename=vert_filename, usebasemap=True)
    # grid_filename = '../../grid.nc'
    # grid = tracpy.inout.readgrid(grid_filename, usebasemap=True)

    if whichtype == 'D2':
        Hfilename = 'figures/' + whichtype + '/r' + str(int(r[1])) + '/' + whichtime + 'H.npz'
    elif 'coast' in whichtype:
        Hfilename = 'figures/' + whichtype + '/' + whichtime + 'H.npz'
    elif whichtype == 'cross':
        Hfilename = 'figures/' + whichtype + '/' + whichtime + str(shelf_depth) + 'advection' + str(numdays) + 'days-' + 'H.npz'
        # Hfilename = 'calcs/shelfconn/' + whichtime + str(shelf_depth) + 'H.npz'

    if not os.path.exists(Hfilename): 

        ## Calculate starting position histogram just once ##
        # Read in connectivity info (previously calculated). 
        # Drifters always start in the same place.
        # pdb.set_trace()
        if (whichtype == 'cross') or ('coast' in whichtype and whichdir == 'forward'):
            d = np.load(Files[0][0])

            # Calculate xrange and yrange for histograms
            Xrange = [grid['xpsi'].min(), grid['xpsi'].max()]
            Yrange = [grid['ypsi'].min(), grid['ypsi'].max()]

            # Histogram of starting locations
            if whichtype == 'cross': # results are in xg, yg
                xp0, yp0, _ = tracpy.tools.interpolate2d(d['xg0'], d['yg0'], grid, 'm_ij2xy')
            elif 'coast' in whichtype:  # results are in xp, yp
                xp0 = d['xp0']; yp0 = d['yp0']

            d.close()

        elif ('coast' in whichtype and whichdir == 'back'):

            # Calculate xrange and yrange for histograms
            Xrange = [grid['xpsi'].min(), grid['xpsi'].max()]
            Yrange = [grid['ypsi'].min(), grid['ypsi'].max()]

        elif whichtype == 'D2': # results are in xg, yg

            # Histogram of starting locations
            Hstartfile = 'calcs/dispersion/hist/Hstart_bins' + str(bins[0]) + '.npz'

            if not os.path.exists(Hstartfile): # just read in info
                d = netCDF.Dataset(Files[0][0])

                # Calculate xrange and yrange for histograms in ll
                Xrange = [grid['lonpsi'].min(), grid['lonpsi'].max()]
                Yrange = [grid['latpsi'].min(), grid['latpsi'].max()]
            
                # Histogram of starting locations
                # xp0, yp0 are lonp, latp in this case
                xp0, yp0, _ = tracpy.tools.interpolate2d(d.variables['xg'][:,0], d.variables['yg'][:,0], grid, 'm_ij2ll')

                d.close()

        if 'Hstartfile' in locals() and os.path.exists(Hstartfile): # just read in info
            Hstartf = np.load(Hstartfile)
            xe = Hstartf['xe']; ye = Hstartf['ye']
            Hstart = Hstartf['Hstart']
        # elif 'coast' in whichtype and whichdir == 'back':
        #     continue
        else:
            if whichdir == 'back': # aren't dividing at the end by the starting number
                Hstart = np.ones(bins)
            else:
                # For D2 and fsle, Hstart contains indices of drifters seeded in bins
                Hstart, xe, ye = calc_histogram(xp0, yp0, whichtype, bins=bins, Xrange=Xrange, Yrange=Yrange)

                if whichtype == 'D2':
                    xe, ye = grid['basemap'](xe, ye) # change from lon/lat
                    np.savez(Hstartfile, Hstart=Hstart, xe=xe, ye=ye) 

        # For D2 and fsle, H contains the metric calculation averaged over that bin
        if whichtype == 'D2' or whichtype == 'fsle':
            H = np.zeros((len(Files), Hstart.shape[0], Hstart.shape[1], 901))
            nnans = np.zeros((len(Files), Hstart.shape[0], Hstart.shape[1], 901))
        elif ('coast' in whichtype) or (whichtype=='cross'):
            H = np.zeros((len(Files), Hstart.shape[0], Hstart.shape[1]))
            nnans = np.zeros((len(Files), Hstart.shape[0], Hstart.shape[1]))

        # Loop through calculation files to calculate overall histograms
        # pdb.set_trace()
        for i, files in enumerate(Files): # Files has multiple entries, 1 for each subplot

            if whichtype == 'cross' or 'coast' in whichtype:
                Hcross = np.zeros(bins) # initialize
                #pdb.set_trace()
                HstartUse = Hstart*len(files) # multiply to account for each simulation


            for File in files: # now loop through the files for this subplot
                print(File)

                if whichtype == 'cross': # results are in xg, yg
                    # cross is [number of depths, number of tracks]: time of crossing, or nan if the drifter doesn't cross
                    # Read in info
                    d = np.load(File)
                    xg0 = d['xg0']; yg0 = d['yg0']
                    cross = d['cross']
                    ind = (~np.isnan(cross[ishelf_depth,:])) * (cross[ishelf_depth, :] < numdays)
                    # ind = ~np.isnan(cross[ishelf_depth,:])  # this is for 30 days (the whole run time)
                    d.close()
                    xp, yp, _ = tracpy.tools.interpolate2d(xg0[ind], yg0[ind], grid, 'm_ij2xy')
                elif ('coast' in whichtype) and (whichdir=='forward'):  # results are in xp, yp
                    # Read in info
                    d = np.load(File)
                    xp = d['xp0']; yp = d['yp0']
                    conn = d['conn'] 
                    ind = ~np.isnan(conn)
                    xp = xp[ind]; yp = yp[ind] # pick out the drifters that reach the coastline
                    d.close()
                elif ('coast' in whichtype) and (whichdir=='back'):  # results are in xp, yp
                    # Read in info
                    d = np.load(File)
                    # backward tracks in time of drifters starting in the specified coastal area
                    xp = d['xp0']; yp = d['yp0']
                    conn = d['conn'] # indices of the drifters that started in the zone
                    # can't send nan's to histogram calculation
                    ind = ~np.isnan(xp)
                    xp = xp[ind]; yp = yp[ind] # pick out the drifters that reach the coastline
                    d.close()
                elif whichtype == 'D2' or whichtype == 'fsle':
                    sfile = 'calcs/dispersion/hist/D2/r' + str(int(r[1])) + '/' + File.split('/')[-1][:-5] + '_bins' + str(bins[0]) + '.npz'
                    if os.path.exists(sfile): # just read in info
                        already_calculated = 1
                    else:
                        already_calculated = 0
                    # # This is for distributing the workload to different processors
                    # if not '2004' in File:
                    #     continue

                    if not already_calculated:
                        print('working on D2', sfile)
                        d = netCDF.Dataset(File)
                        xg = d.variables['xg'][:]; yg = d.variables['yg'][:]
                        # eliminate entries equal to -1
                        ind = xg==-1
                        xg[ind] = np.nan; yg[ind] = np.nan
                        xp, yp, _ = tracpy.tools.interpolate2d(xg, yg, grid, 'm_ij2ll')
                        d.close()

                # Count the drifters for the shelf_depth that have a non-nan entry
                # pdb.set_trace() 
                if whichtype == 'cross' or 'coast' in whichtype:
                    # Calculate and accumulate histograms of starting locations of drifters that cross shelf
                    Hcrosstemp, xe, ye = calc_histogram(xp, yp, whichtype, bins=bins, Xrange=Xrange, Yrange=Yrange)
                    Hcross = np.nansum( np.vstack((Hcross[np.newaxis,:,:], Hcrosstemp[np.newaxis,:,:])), axis=0)
                elif whichtype == 'D2' or whichtype == 'fsle':
                    if not already_calculated:
                        # Calculate the metric in each bin and combine for all files
                        metric_temp, nnanstemp = calc_metric(xp, yp, Hstart, whichtype, r=r)
                        # Save calculations by bins for each file
                        # pdb.set_trace()
                        np.savez(sfile, D2=metric_temp, nnans=nnanstemp) 
                        print('saving D2 file', sfile)
                        # metric_temp is in time, but want to show a single value for each bin in space.
                        # Take the value at the final time.
                        # pdb.set_trace()
                    else:
                        d = np.load(sfile)
                        metric_temp = d['D2']; nnanstemp = d['nnans']
                        # pdb.set_trace()
                        # filter out boxes with too few available drifters
                        ind = nnanstemp<30
                        metric_temp[ind] = np.nan

                    H[i,:] = np.nansum( np.vstack((H[np.newaxis,i,:,:,:],metric_temp[np.newaxis,:,:,:]*nnanstemp[np.newaxis,:,:,:])), axis=0) # un-average (multiply by drifter count) before accumulating
                    # H[i,:] = H[i,:] + metric_temp[:,:,-1]*nnanstemp[:,:,-1] # need to un-average before combining
                    nnans[i,:] = nnans[i,:] + nnanstemp[:,:,:] # accumulate drifter counts for the re-average below

            # Calculate overall histogram
            if whichtype == 'cross' or 'coast' in whichtype:
                H[i,:] = (Hcross/HstartUse)*100  # percent of the seeded drifters in each bin that cross or connect
            elif whichtype == 'D2':
                # xe, ye = grid['basemap'](xe, ye) # change from lon/lat
                H[i,:] = H[i,:]/nnans[i,:]  # finish the drifter-weighted average: sum(D2*n)/sum(n)
                # np.savez('calcs/dispersion/hist/' + File.split('/')[-1][:-5] + '_bins' + str(bins[0])) 
            elif whichtype == 'fsle':
                H[i,:] = 1./H[i,:]/nnans[i,:]
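            # At this point H[i] holds the combined per-bin field for this subplot:
            # percent connectivity for the cross/coast cases, or the drifter-weighted
            # average metric for the D2/fsle cases.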

        # save H
        if not os.path.exists('figures/' + whichtype): 
            os.makedirs('figures/' + whichtype)

        np.savez(Hfilename, H=H, xe=xe, ye=ye)

    else: # H has already been calculated

        Hfile = np.load(Hfilename)
        H = Hfile['H']; xe = Hfile['xe']; ye = Hfile['ye']
        #levels = np.linspace(0, np.nanmax(H), 11)

    # which time index to plot
    # a number or 'mean' or 'none' (for coast and cross)
    if whichtype == 'D2':
        itind = 100
    elif (whichtype == 'cross') or ('coast' in whichtype):
        itind = 'none'
    # Choose consistent levels to plot
    locator = ticker.MaxNLocator(11)
    locator.create_dummy_axis()
    # don't use highest max since everything is washed out then
    # pdb.set_trace()
    # 12000 for mean interannual-summer, 20000 for mean, interannual-winter, 1400 for 100 seasonal
    # 1800 for 100 interannual-winter, 1800 for 100 interannual-summer
    if whichtype == 'D2':
        if itind == 30:
            locator.set_bounds(0, 10)
        elif itind == 100: 
            locator.set_bounds(0, 160) 
        elif itind == 150: 
            locator.set_bounds(0, 450) 
        elif itind == 300: 
            locator.set_bounds(0, 2200) 
        elif itind == 600: 
            locator.set_bounds(0, 8000) 
        elif itind == 900: 
            locator.set_bounds(0, 15000) 
        # locator.set_bounds(0, 0.2*np.nanmax(H[:,:,:,itind]))
        #locator.set_bounds(0, 0.75*np.nanmax(np.nanmax(H[:,:,:,itind], axis=1), axis=1).mean())
        levels = locator()
    elif 'coast' in whichtype and whichdir == 'back':
        hist, bin_edges = np.histogram(H.flat, bins=100)  # occurrences of histogram bin values
        n = np.cumsum(hist)
        # use the bin value at 70% of the cumulative histogram as the effective max,
        # since the actual max is too high and washes everything out
        Hmax = bin_edges[np.where(n < (n.max() - n.min())*0.7 + n.min())[0][-1]]
        locator.set_bounds(0, 1)
        levels = locator()
        extend = 'max'
        H = H/Hmax  # normalize so the levels span [0, 1]
    else:
        extend = 'neither'


    # Set up overall plot, now that everything is calculated
    fig, axarr = plot_setup(whichtime, grid) # depends on which plot we're doing

    # Loop through calculation files to calculate overall histograms
    # pdb.set_trace()
    for i in range(H.shape[0]): # Files has multiple entries, 1 for each subplot

        # Do subplot
        # pdb.set_trace()
        # which time index to plot?
        #itind = 100
        if cbook.is_numlike(itind): # plot a particular time
            mappable = plot_stuff(xe, ye, H[i,:,:,itind], cmap, grid, shelf_depth, axarr.flatten()[i], levels=levels)
        elif itind=='mean': # plot the mean over time
            mappable = plot_stuff(xe, ye, np.nansum(H[i,:,:,:], axis=-1)/np.sum(~np.isnan(H[i,:,:,:]), axis=-1), cmap, grid, shelf_depth, axarr.flatten()[i], levels=levels)
        elif itind=='none': # just plot what is there
            if 'levels' in locals():
                mappable = plot_stuff(xe, ye, H[i,:,:].T, cmap, grid, shelf_depth, axarr.flatten()[i], extend=extend, levels=levels)
            else:
                mappable = plot_stuff(xe, ye, H[i,:,:].T, cmap, grid, shelf_depth, axarr.flatten()[i], extend=extend)
        #axarr.flatten()[i].set_title(np.nanmax(H[i,:,:,itind]))
        # Add coastline area if applicable
        if 'coast' in whichtype:
            coastloc = whichtype.split('coast')[-1]
            pts = np.load('calcs/' + coastloc + 'pts.npz')[coastloc]
            axarr.flatten()[i].plot(pts[:,0], pts[:,1], color='0.0', lw=3)
            # verts = np.vstack((pts[:,0], pts[:,1]))
            # # Form path
            # path = Path(verts.T)
            # if not path.contains_point(np.vstack((xp[jd,it],yp[jd,it]))):

        # Overlay mean wind arrows
        if addwind:
            # Right now is just for cross, interannual, winter
            year = years[i]
            # year = File.split('/')[-1].split('-')[0]
            season = whichtime.split('-')[-1]
            wind = np.load('../txla_plots/calcs/wind_stress/1st/jfm/' + str(year) + season +  '.npz')
            x = wind['x']; y = wind['y']; u = wind['u']; v = wind['v']
            q = axarr.flatten()[i].quiver(x, y, u, v, color = '0.3',
                        pivot='middle', zorder=1e35, width=0.003)
                        # scale=1.0/scale, pivot='middle', zorder=1e35, width=0.003)

            # if year == 2008:
            #     plt.quiverkey(q, 0.85, 0.07, 0.1, label=r'0.1 N m$^{2}$', coordinates='axes')



    # Add colorbar
    plot_colorbar(fig, mappable, whichtype, whichdir=whichdir, whichtime=whichtime)
    # pdb.set_trace()

    # save and close
    plot_finish(fig, whichtype, whichtime, shelf_depth, itind, r, numdays)
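
The script above delegates the binning to a calc_histogram helper defined elsewhere in the source file. As a point of reference, here is a minimal sketch of a helper with the same (H, xe, ye) return convention for the cross/coast counting case, assuming it simply wraps numpy.histogram2d; the name matches the call sites above, but the body is an assumption, not the author's implementation (and the D2 path evidently returns drifter indices per bin rather than counts):

import numpy as np

def calc_histogram(xp, yp, whichtype, bins=(100, 100), Xrange=None, Yrange=None):
    # Hypothetical sketch: bin drifter positions onto a 2D grid and return
    # the per-bin counts plus the bin edges, mirroring how the results are
    # consumed above. The real helper in the source file may differ.
    rng = [Xrange, Yrange] if (Xrange is not None and Yrange is not None) else None
    H, xe, ye = np.histogram2d(np.asarray(xp).ravel(), np.asarray(yp).ravel(),
                               bins=bins, range=rng)
    return H, xe, ye
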
示例#42
0
def lucky_frame(
	im, 							# In electron counts/s.
	psf, 							# Normalised.
	scale_factor, 					
	t_exp, 
	final_sz,
	tt = np.array([0, 0]),
	im_star = None,					# In electron counts/s.					
	noise_frame_gain_multiplied = 0,		# Noise injected into the system that is multiplied up by the detector gain after conversion to counts via a Poisson distribution, e.g. sky background, emission from telescope, etc. Must have shape final_sz. It is assumed that this noise frame has already been multiplied up by the detector gain!
	noise_frame_post_gain = 0,		# Noise injected into the system after gain multiplication, e.g. read noise. Must have shape final_sz.
	gain = 1,						# Detector gain.
	detector_saturation=np.inf,		# Detector saturation.
	plate_scale_as_px_conv = 1,		# Only used for plotting.
	plate_scale_as_px = 1,			# Only used for plotting.
	plotit=False):
	""" 
		This function can be used to generate a short-exposure 'lucky' image that can be input to the Lucky Imaging algorithms.
			Input: 	one 'raw' countrate image of a galaxy; one PSF with which to convolve it (at the same plate scale)
			Output: a 'Lucky' exposure. 			
			Process: convolve with PSF --> resize to detector --> add tip and tilt (from a premade vector of tip/tilt values) --> convert to counts --> add noise --> subtract the master sky/dark current. 
	"""	
	# Convolve with PSF.
	im_raw = im
	im_convolved = obssim.convolve_psf(im_raw, psf)

	# Add a star to the field. We need to add the star at the convolution plate scale BEFORE we resize down because of the tip-tilt adding step!
	if is_numlike(im_star):
		if im_star.shape != im_convolved.shape:
			print("ERROR: the input image of the star MUST have the same size and plate scale as the image of the galaxy after convolution!")
			raise UserWarning
		im_convolved += im_star

	# Resize to detector (+ edge buffer).
	im_resized = imutils.fourier_resize(
		im = im_convolved,
		scale_factor = scale_factor,
		conserve_pixel_sum = True)

	# Add tip and tilt. To avoid edge effects, max(tt) should be less than or equal to the edge buffer.
	edge_buffer_px = (im.shape[0] - final_sz[0]) / 2
	if edge_buffer_px > 0 and max(tt) > edge_buffer_px:
		print("WARNING: the edge buffer is less than the supplied tip and tilt by a margin of {:.2f} pixels! Shifted image will be clipped.".format(np.abs(edge_buffer_px - max(tt))))
	im_tt = obssim.add_tt(image = im_resized, tt_idxs = tt)[0]	
	# Crop back down to the detector size.
	if edge_buffer_px > 0:
		im_tt = imutils.centre_crop(im_tt, final_sz)	
	# Convert to counts. Note that we apply the gain AFTER we convert to integer
	# counts.
	im_counts = etcutils.expected_count_to_count(im_tt, t_exp = t_exp) * gain
	# Add the pre-gain noise. Here, we assume that the noise frame has already 
	# been multiplied by the gain before being passed into this function.
	im_noisy = im_counts + noise_frame_gain_multiplied
	# Add the post-gain noise (i.e. read noise)
	im_noisy += noise_frame_post_gain
	# Account for detector saturation
	im_noisy = np.clip(im_noisy, a_min=0, a_max=detector_saturation)

	if plotit:
		plate_scale_as_px = plate_scale_as_px_conv * scale_factor
		# Plotting
		mu.newfigure(1,3)
		plt.suptitle('Convolving input image with PSF and resizing to detector')
		mu.astroimshow(im=im_raw, 
			title='Truth image (electrons/s)', 
			plate_scale_as_px = plate_scale_as_px_conv, 
			colorbar_on=True, 
			subplot=131)
		mu.astroimshow(im=psf, 
			title='Point spread function (normalised)', 
			plate_scale_as_px = plate_scale_as_px_conv, 
			colorbar_on=True, 
			subplot=132)
		# mu.astroimshow(im=im_convolved, 
		# 	title='Star added, convolved with PSF (electrons/s)', 
		# 	plate_scale_as_px = plate_scale_as_px_conv, 
		# 	colorbar_on=True, 
		# 	subplot=143)
		mu.astroimshow(im=im_resized, 
			title='Resized to detector plate scale (electrons/s)', 
			plate_scale_as_px=plate_scale_as_px, 
			colorbar_on=True, 
			subplot=133)

		# Zooming in on the galaxy
		# mu.newfigure(1,4)
		# plt.suptitle('Convolving input image with PSF and resizing to detector')
		# mu.astroimshow(im=imutils.centre_crop(im=im_raw, units='arcsec', plate_scale_as_px=plate_scale_as_px_conv, sz_final=(6, 6)), title='Raw input image (electrons/s)', plate_scale_as_px = plate_scale_as_px_conv, colorbar_on=True, subplot=141)
		# mu.astroimshow(im=psf, title='Point spread function (normalised)', plate_scale_as_px = plate_scale_as_px_conv, colorbar_on=True, subplot=142)
		# mu.astroimshow(im=imutils.centre_crop(im=im_convolved, units='arcsec', plate_scale_as_px=plate_scale_as_px_conv, sz_final=(6, 6)), title='Star added, convolved with PSF (electrons/s)', plate_scale_as_px = plate_scale_as_px_conv, colorbar_on=True, subplot=143)
		# mu.astroimshow(im=imutils.centre_crop(im=im_resized, units='arcsec', plate_scale_as_px=plate_scale_as_px, sz_final=(6, 6)), title='Resized to detector plate scale (electrons/s)', plate_scale_as_px=plate_scale_as_px, colorbar_on=True, subplot=144)

		mu.newfigure(1,3)
		plt.suptitle('Adding tip and tilt, converting to integer counts and adding noise')
		mu.astroimshow(im=im_tt, 
			title='Atmospheric tip and tilt added (electrons/s)', 
			plate_scale_as_px=plate_scale_as_px, 
			colorbar_on=True,
			subplot=131)
		mu.astroimshow(im=im_counts, 
			title=r'Converted to integer counts and gain-multiplied by %d (electrons)' % gain, 
			plate_scale_as_px=plate_scale_as_px, 
			colorbar_on=True, 
			subplot=132)
		mu.astroimshow(im=im_noisy, 
			title='Noise added (electrons)', 
			plate_scale_as_px=plate_scale_as_px, 
			colorbar_on=True, 
			subplot=133)

		# plt.subplot(1,4,4)
		plt.figure()
		x = np.linspace(-im_tt.shape[0]/2, +im_tt.shape[0]/2, im_tt.shape[0]) * plate_scale_as_px
		plt.plot(x, im_tt[:, im_tt.shape[1]//2], 'g', label='Electron count rate')
		plt.plot(x, im_counts[:, im_tt.shape[1]//2], 'b', label='Converted to integer counts ($t_{exp} = %.2f$ s)' % t_exp)
		plt.plot(x, im_noisy[:, im_tt.shape[1]//2], 'r', label='Noise added')
		plt.xlabel('arcsec')
		plt.ylabel('Pixel value (electrons)')
		plt.title('Linear profiles')
		plt.axis('tight')
		plt.legend(loc='lower left')
		mu.show_plot()

	return im_noisy
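
The conversion-to-counts step above goes through etcutils.expected_count_to_count. As a hedged sketch, and only an assumption about that helper, the conversion might draw per-pixel integer counts from a Poisson distribution whose mean is the expected count (rate times exposure time):

import numpy as np

def expected_count_to_count(im, t_exp):
    # Hypothetical sketch of the etcutils helper used above: the expected
    # count per pixel is rate (electrons/s) * exposure time (s), and the
    # realised integer count is Poisson-distributed about that expectation.
    expected = np.clip(im * t_exp, 0, None)  # Poisson mean must be non-negative
    return np.random.poisson(expected)

lucky_frame then multiplies the realised counts by the detector gain, consistent with its comment that the gain is applied after the conversion to integer counts.
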
示例#43
0
def lucky_frame(
        im,  # In electron counts/s.
        psf,  # Normalised.
        scale_factor,
        t_exp,
        final_sz,
        tt=np.array([0, 0]),
        im_star=None,  # In electron counts/s.					
        noise_frame_gain_multiplied=0,  # Noise injected into the system that is multiplied up by the detector gain after conversion to counts via a Poisson distribution, e.g. sky background, emission from telescope, etc. Must have shape final_sz. It is assumed that this noise frame has already been multiplied up by the detector gain!
        noise_frame_post_gain=0,  # Noise injected into the system after gain multiplication, e.g. read noise. Must have shape final_sz.
        gain=1,  # Detector gain.
        detector_saturation=np.inf,  # Detector saturation.
        plate_scale_as_px_conv=1,  # Only used for plotting.
        plate_scale_as_px=1,  # Only used for plotting.
        plotit=False):
    """ 
		This function can be used to generate a short-exposure 'lucky' image that can be input to the Lucky Imaging algorithms.
			Input: 	one 'raw' countrate image of a galaxy; one PSF with which to convolve it (at the same plate scale)
			Output: a 'Lucky' exposure. 			
			Process: convolve with PSF --> resize to detector --> add tip and tilt (from a premade vector of tip/tilt values) --> convert to counts --> add noise --> subtract the master sky/dark current. 
	"""
    # Convolve with PSF.
    im_raw = im
    im_convolved = obssim.convolve_psf(im_raw, psf)

    # Add a star to the field. We need to add the star at the convolution plate scale BEFORE we resize down because of the tip-tilt adding step!
    if is_numlike(im_star):
        if im_star.shape != im_convolved.shape:
            print(
                "ERROR: the input image of the star MUST have the same size and plate scale as the image of the galaxy after convolution!"
            )
            raise UserWarning
        im_convolved += im_star

    # Resize to detector (+ edge buffer).
    im_resized = imutils.fourier_resize(im=im_convolved,
                                        scale_factor=scale_factor,
                                        conserve_pixel_sum=True)

    # Add tip and tilt. To avoid edge effects, max(tt) should be less than or equal to the edge buffer.
    edge_buffer_px = (im.shape[0] - final_sz[0]) / 2
    if edge_buffer_px > 0 and max(tt) > edge_buffer_px:
        print(
            "WARNING: the edge buffer is less than the supplied tip and tilt by a margin of {:.2f} pixels! Shifted image will be clipped."
            .format(np.abs(edge_buffer_px - max(tt))))
    im_tt = obssim.add_tt(image=im_resized, tt_idxs=tt)[0]
    # Crop back down to the detector size.
    if edge_buffer_px > 0:
        im_tt = imutils.centre_crop(im_tt, final_sz)
    # Convert to counts. Note that we apply the gain AFTER we convert to integer
    # counts.
    im_counts = etcutils.expected_count_to_count(im_tt, t_exp=t_exp) * gain
    # Add the pre-gain noise. Here, we assume that the noise frame has already
    # been multiplied by the gain before being passed into this function.
    im_noisy = im_counts + noise_frame_gain_multiplied
    # Add the post-gain noise (i.e. read noise)
    im_noisy += noise_frame_post_gain
    # Account for detector saturation
    im_noisy = np.clip(im_noisy, a_min=0, a_max=detector_saturation)

    if plotit:
        plate_scale_as_px = plate_scale_as_px_conv * scale_factor
        # Plotting
        mu.newfigure(1, 3)
        plt.suptitle(
            'Convolving input image with PSF and resizing to detector')
        mu.astroimshow(im=im_raw,
                       title='Truth image (electrons/s)',
                       plate_scale_as_px=plate_scale_as_px_conv,
                       colorbar_on=True,
                       subplot=131)
        mu.astroimshow(im=psf,
                       title='Point spread function (normalised)',
                       plate_scale_as_px=plate_scale_as_px_conv,
                       colorbar_on=True,
                       subplot=132)
        # mu.astroimshow(im=im_convolved,
        # 	title='Star added, convolved with PSF (electrons/s)',
        # 	plate_scale_as_px = plate_scale_as_px_conv,
        # 	colorbar_on=True,
        # 	subplot=143)
        mu.astroimshow(im=im_resized,
                       title='Resized to detector plate scale (electrons/s)',
                       plate_scale_as_px=plate_scale_as_px,
                       colorbar_on=True,
                       subplot=133)

        # Zooming in on the galaxy
        # mu.newfigure(1,4)
        # plt.suptitle('Convolving input image with PSF and resizing to detector')
        # mu.astroimshow(im=imutils.centre_crop(im=im_raw, units='arcsec', plate_scale_as_px=plate_scale_as_px_conv, sz_final=(6, 6)), title='Raw input image (electrons/s)', plate_scale_as_px = plate_scale_as_px_conv, colorbar_on=True, subplot=141)
        # mu.astroimshow(im=psf, title='Point spread function (normalised)', plate_scale_as_px = plate_scale_as_px_conv, colorbar_on=True, subplot=142)
        # mu.astroimshow(im=imutils.centre_crop(im=im_convolved, units='arcsec', plate_scale_as_px=plate_scale_as_px_conv, sz_final=(6, 6)), title='Star added, convolved with PSF (electrons/s)', plate_scale_as_px = plate_scale_as_px_conv, colorbar_on=True, subplot=143)
        # mu.astroimshow(im=imutils.centre_crop(im=im_resized, units='arcsec', plate_scale_as_px=plate_scale_as_px, sz_final=(6, 6)), title='Resized to detector plate scale (electrons/s)', plate_scale_as_px=plate_scale_as_px, colorbar_on=True, subplot=144)

        mu.newfigure(1, 3)
        plt.suptitle(
            'Adding tip and tilt, converting to integer counts and adding noise'
        )
        mu.astroimshow(im=im_tt,
                       title='Atmospheric tip and tilt added (electrons/s)',
                       plate_scale_as_px=plate_scale_as_px,
                       colorbar_on=True,
                       subplot=131)
        mu.astroimshow(
            im=im_counts,
            title=
            r'Converted to integer counts and gain-multiplied by %d (electrons)'
            % gain,
            plate_scale_as_px=plate_scale_as_px,
            colorbar_on=True,
            subplot=132)
        mu.astroimshow(im=im_noisy,
                       title='Noise added (electrons)',
                       plate_scale_as_px=plate_scale_as_px,
                       colorbar_on=True,
                       subplot=133)

        # plt.subplot(1,4,4)
        plt.figure()
        x = np.linspace(-im_tt.shape[0] / 2, +im_tt.shape[0] / 2,
                        im_tt.shape[0]) * plate_scale_as_px
        plt.plot(x,
                 im_tt[:, im_tt.shape[1] // 2],
                 'g',
                 label='Electron count rate')
        plt.plot(x,
                 im_counts[:, im_tt.shape[1] // 2],
                 'b',
                 label='Converted to integer counts ($t_{exp} = %.2f$ s)' %
                 t_exp)
        plt.plot(x, im_noisy[:, im_tt.shape[1] // 2], 'r', label='Noise added')
        plt.xlabel('arcsec')
        plt.ylabel('Pixel value (electrons)')
        plt.title('Linear profiles')
        plt.axis('tight')
        plt.legend(loc='lower left')
        mu.show_plot()

    return im_noisy
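
For orientation, a usage sketch of lucky_frame. All input values below are invented for illustration, and the call assumes the project's obssim, imutils and etcutils modules are importable:

import numpy as np

# Invented inputs: a 128x128 point-source 'galaxy' and a flat placeholder PSF
# at the convolution plate scale (a real PSF would be peaked and normalised).
im = np.zeros((128, 128))
im[64, 64] = 1e4                                   # electron counts/s
psf = np.full((128, 128), 1.0 / 128**2)            # sums to 1
read_noise = np.random.normal(0.0, 4.0, (64, 64))  # post-gain read noise frame

frame = lucky_frame(im, psf,
                    scale_factor=0.5,              # resize down to the detector plate scale
                    t_exp=0.1,                     # exposure time in seconds
                    final_sz=(64, 64),
                    tt=np.array([2, -3]),          # tip/tilt offset in pixels
                    noise_frame_post_gain=read_noise,
                    gain=100)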