Example #1
 def get_epsilon_nu(self):
     '''
     Generate the sequences of nu and epsilon according to the configured
     continuation rule.
     '''
     str_method = self.get_val('nuepsilonmethod', False)
     decay = self.get_val('decay', True)
     epsilon_start = self.get_val('epsilonstart', True)
     epsilon_stop = self.get_val('epsilonstop', True)
     nu_start = self.get_val('nustart', True)
     nu_stop = self.get_val('nustop', True)
     if str_method == 'geometric':
         epsilon = np.asarray([nmax(epsilon_start * (decay ** i),epsilon_stop) \
                               for i in arange(self.int_iterations+1)])
         nu = np.asarray([nmax(nu_start * (decay ** i),nu_stop) \
                               for i in arange(self.int_iterations+1)])
     elif str_method == 'exponential':
         epsilon = np.asarray([epsilon_start * exp(-i / decay) + epsilon_stop \
                               for i in arange(self.int_iterations+1)])
         nu = np.asarray([nu_start * exp(-i / decay) + nu_stop \
                               for i in arange(self.int_iterations+1)])
     elif str_method == 'fixed':
         epsilon = epsilon_start * np.ones(self.int_iterations + 1, )
         nu = nu_start * np.ones(self.int_iterations + 1, )
     else:
         raise Exception('no such continuation parameter rule')
     return epsilon, nu
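All of the snippets below call NumPy through short aliases (nmax, nmin, nmean, nmea, nmed, nstd, nvar, nabs, nwhere) whose import lines were stripped from the excerpts. A preamble along the following lines is presumably in scope in each source file; the exact alias list is an assumption, and a few snippets additionally rely on project-specific helpers (e.g. emean, evar, crop_center, chunkeval) that are not shown here.

# Presumed import preamble for these examples (assumption: the source
# projects alias NumPy functions to short names like nmax/nmin):
import numpy as np
from numpy import (abs as nabs, arange, exp, max as nmax, mean as nmean,
                   mean as nmea, median as nmed, min as nmin, std as nstd,
                   var as nvar, where as nwhere)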
Example #2
def geweke_plot(data, name, format='png', suffix='-diagnostic', path='./',
                fontmap=None, verbose=1):
    # Generate Geweke (1992) diagnostic plots

    if fontmap is None: fontmap = {1:10, 2:8, 3:6, 4:5, 5:4}

    # Generate new scatter plot
    figure()
    x, y = transpose(data)
    scatter(x.tolist(), y.tolist())

    # Plot options
    xlabel('First iteration', fontsize='x-small')
    ylabel('Z-score for %s' % name, fontsize='x-small')

    # Plot lines at +/- 2 sd from zero
    pyplot((nmin(x), nmax(x)), (2, 2), '--')
    pyplot((nmin(x), nmax(x)), (-2, -2), '--')

    # Set plot bound
    ylim(min(-2.5, nmin(y)), max(2.5, nmax(y)))
    xlim(0, nmax(x))

    # Save to file
    if not os.path.exists(path):
        os.mkdir(path)
    if not path.endswith('/'):
        path += '/'
    savefig("%s%s%s.%s" % (path, name, suffix, format))
Example #3
File: Matplot.py  Project: shfengcj/pymc
def geweke_plot(data,
                name,
                format='png',
                suffix='-diagnostic',
                path='./',
                fontmap=None):
    '''
    Generate Geweke (1992) diagnostic plots.
    
    :Arguments:
        data: list
            List (or list of lists for vector-valued variables) of Geweke
            diagnostics, output from the `pymc.diagnostics.geweke` function.

        name: string
            The name of the plot.

        format (optional): string
            Graphic output format (defaults to png).

        suffix (optional): string
            Filename suffix (defaults to "-diagnostic").

        path (optional): string
            Specifies location for saving plots (defaults to local directory).

        fontmap (optional): dict
            Font map for plot.
    
    '''

    if fontmap is None:
        fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}

    # Generate new scatter plot
    figure()
    x, y = transpose(data)
    scatter(x.tolist(), y.tolist())

    # Plot options
    xlabel('First iteration', fontsize='x-small')
    ylabel('Z-score for %s' % name, fontsize='x-small')

    # Plot lines at +/- 2 sd from zero
    pyplot((nmin(x), nmax(x)), (2, 2), '--')
    pyplot((nmin(x), nmax(x)), (-2, -2), '--')

    # Set plot bound
    ylim(min(-2.5, nmin(y)), max(2.5, nmax(y)))
    xlim(0, nmax(x))

    # Save to file
    if not os.path.exists(path):
        os.mkdir(path)
    if not path.endswith('/'):
        path += '/'
    savefig("%s%s%s.%s" % (path, name, suffix, format))
Example #4
    def csv_output(self):
        
        """This method is used to report the results of
        a subscription test to a csv file"""

        # determine the file name
        csv_filename = "subscription-%s-%siter-%s-%s.csv" % (self.subscriptiontype,
                                                      self.iterations,
                                                      self.chart_type.lower(),
                                                      self.testdatetime)

        # initialize the csv file
        csvfile_stream = open(csv_filename, "w")
        csvfile_writer = csv.writer(csvfile_stream, delimiter=',', quoting=csv.QUOTE_MINIMAL)

        # iterate over the SIBs
        for sib in self.results.keys():                    
                                     
            row = [sib]
            
            # add all the times
            for value in self.results[sib]:
                row.append(value)

            # add the mean, min, max and variance value of the times to the row
            row.append(round(nmean(self.results[sib]),3))                
            row.append(round(nmin(self.results[sib]),3))                
            row.append(round(nmax(self.results[sib]),3))                
            row.append(round(nvar(self.results[sib]),3))                

            # write the row
            csvfile_writer.writerow(row)
                
        # close the csv file
        csvfile_stream.close()
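For reference, the summary columns appended above reduce to the following one-liner (hypothetical timings; nmean, nmin, nmax and nvar are the NumPy aliases from the preamble):

# Hypothetical data illustrating the per-SIB summary row built above:
times = [0.12, 0.15, 0.11, 0.19]
row = ['sib0'] + times + [round(nmean(times), 3), round(nmin(times), 3),
                          round(nmax(times), 3), round(nvar(times), 3)]
# row -> ['sib0', 0.12, 0.15, 0.11, 0.19, 0.142, 0.11, 0.19, 0.001]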
Example #5
 def __gamma_ratio(self, x, y):
     
     module = nabs(nmax((x, y)))
 
     if module <= 100.0:
         return self.__gamma(x) / self.__gamma(y)
     else:
         return (power(2, x - y) * 
                  self.__gamma_ratio(x * 0.5, y * 0.5) * 
                  self.__gamma_ratio(x * 0.5 + 0.5, y * 0.5 + 0.5))
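The recursion is justified by Legendre's duplication formula; substituting it for both gamma terms shows why halving the arguments (and multiplying by a power of two) preserves the ratio while keeping intermediate values small enough not to overflow:

\Gamma(z) = \frac{2^{z-1}}{\sqrt{\pi}}\,\Gamma\!\left(\frac{z}{2}\right)\Gamma\!\left(\frac{z+1}{2}\right)
\quad\Longrightarrow\quad
\frac{\Gamma(x)}{\Gamma(y)} = 2^{\,x-y}\,\frac{\Gamma(x/2)}{\Gamma(y/2)}\cdot\frac{\Gamma((x+1)/2)}{\Gamma((y+1)/2)}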
Example #6
 def imexamine(self, src):
     try:
         data = self.fit.data(src, table=False)
         the_min = nmin(data)
         the_max = nmax(data)
         the_mea = nmea(data)
         the_std = nstd(data)
         the_med = nmed(data)
         return ([the_mea, the_med, the_std, the_min, the_max])
     except Exception as e:
         self.etc.log(e)
Example #7
def genSampling(pdf, nitn, tol):
    pdf[pdf > 1] = 1
    K = np.sum(pdf.flatten())

    minIntr = 1e99
    minIntrVec = zeros(pdf.shape)
    stat = np.zeros(nitn, )
    for n in np.arange(0, nitn):
        tmp = zeros(pdf.shape)
        while abs(np.sum(tmp.flatten()) - K) > tol:
            tmp = rand(*pdf.shape) < pdf

        TMP = ifft2(tmp / pdf)
        if nmax(nabs(TMP.flatten()[1:])) < minIntr:
            minIntr = nmax(nabs(TMP.flatten()[1:]))
            minIntrVec = tmp
        stat[n] = nmax(nabs(TMP.flatten()[1:]))

    actpctg = np.sum(minIntrVec.flatten()) / float(minIntrVec.size)
    mask = minIntrVec
    return mask, stat, actpctg
Example #8
File: Matplot.py  Project: studentmicky/gbd
def discrepancy_plot(data,
                     name,
                     report_p=True,
                     format='png',
                     suffix='-gof',
                     path='./',
                     fontmap={
                         1: 10,
                         2: 8,
                         3: 6,
                         4: 5,
                         5: 4
                     },
                     verbose=1):
    # Generate goodness-of-fit deviate scatter plot
    if verbose > 0:
        print('Plotting', name + suffix)

    # Generate new scatter plot
    figure()
    try:
        x, y = transpose(data)
    except ValueError:
        x, y = data
    scatter(x, y)

    # Plot x=y line
    lo = nmin(ravel(data))
    hi = nmax(ravel(data))
    datarange = hi - lo
    lo -= 0.1 * datarange
    hi += 0.1 * datarange
    pyplot((lo, hi), (lo, hi))

    # Plot options
    xlabel('Observed deviates', fontsize='x-small')
    ylabel('Simulated deviates', fontsize='x-small')

    if report_p:
        # Put p-value in legend
        count = sum(s > o for o, s in zip(x, y))
        text(lo + 0.1 * datarange,
             hi - 0.1 * datarange,
             'p=%.3f' % (count / len(x)),
             horizontalalignment='center',
             fontsize=10)

    # Save to file
    if not os.path.exists(path):
        os.mkdir(path)
    if not path.endswith('/'):
        path += '/'
    savefig("%s%s%s.%s" % (path, name, suffix, format))
Example #9
    def update(self, dict_in):
        """
        Expects a single value or array. If array, store the whole vector and stop.
        """
        if self.data == []:
            self.xshape = dict_in['x'].shape
            self.x = dict_in['x'].flatten()
            if self.peak == 0:
                self.peak = nmax(self.x)
            if self.bordercrop != 0:
                # self.slices=tuple([slice(self.bordercrop,-self.bordercrop)
                #                    for i in xrange(len(self.xshape))])
                self.crop_center_size = tuple(
                    [el - 2 * self.bordercrop for el in self.xshape])
                self.x = crop_center(dict_in['x'], self.crop_center_size)
                self.peak = nmax(self.x)
            if self.bytecompare:
                self.x = np.asarray(self.x / self.peak * 255.0, dtype='uint8')
                self.peak = nmax(self.x)
            if self.get_val('peak', True) > 0:
                self.peak = self.get_val('peak', True)
            self.x = self.x.flatten()

        if dict_in['x_n'].shape != self.xshape:
            x_n = crop_center(dict_in['x_n'], self.xshape).flatten()
        else:
            x_n = dict_in['x_n']
            if self.bordercrop != 0:
                x_n = crop_center(x_n, self.crop_center_size)
            if self.bytecompare:
                x_n = np.asarray(x_n, dtype='uint8')
            x_n = x_n.flatten()
        mse = mean((x_n - self.x)**2)
        if mse == 0:
            snr_db = np.inf
        else:
            snr_db = 10 * log10((self.peak**2) / mse)
        value = snr_db
        self.data.append(value)
        super(PSNR, self).update()
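Each appended value is thus the standard peak signal-to-noise ratio in decibels:

\mathrm{PSNR} = 10\,\log_{10}\frac{\mathrm{peak}^2}{\mathrm{MSE}},
\qquad
\mathrm{MSE} = \frac{1}{N}\sum_{i=1}^{N}\bigl(x^{(n)}_i - x_i\bigr)^2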
Example #10
 def fits_stat(self, src):
     self.etc.log("Getting Stats from {}".format(src))
     try:
         hdu = fts.open(src)
         image_data = hdu[0].data
         return ({
             'Min': nmin(image_data),
             'Max': nmax(image_data),
             'Mean': nmea(image_data),
             'Stdev': nstd(image_data)
         })
     except Exception as e:
         self.etc.log(e)
Example #11
def get_network_extents(net):
    '''
    For a given Emme Network, find the envelope (extents) of all of its elements.
    Includes link vertices as well as nodes.
    
    Args:
        -net: An Emme Network Object
    
    Returns:
        minx, miny, maxx, maxy tuple
    '''
    xs, ys = [], []
    for node in net.nodes():
        xs.append(node.x)
        ys.append(node.y)
    for link in net.links():
        for x, y in link.vertices:
            xs.append(x)
            ys.append(y)
    xa = array(xs)
    ya = array(ys)
    
    return nmin(xa) - 1.0, nmin(ya) - 1.0, nmax(xa) + 1.0, nmax(ya) + 1.0
Example #12
 def stats(self, file):
     """Returns statistics of a given fit file."""
     self.logger.info("Getting Stats from {}".format(file))
     try:
         hdu = fts.open(file)
         image_data = hdu[0].data
         return {
             'Min': nmin(image_data),
             'Max': nmax(image_data),
             'Median': nmed(image_data),
             'Mean': nmea(image_data),
             'Stdev': nstd(image_data)
         }
     except Exception as e:
         self.logger.error(e)
Example #13
File: Matplot.py  Project: roban/pymc
def discrepancy_plot(
    data, name="discrepancy", report_p=True, format="png", suffix="-gof", path="./", fontmap=None, verbose=1
):
    # Generate goodness-of-fit deviate scatter plot

    if verbose > 0:
        print_("Plotting", name + suffix)

    if fontmap is None:
        fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}

    # Generate new scatter plot
    figure()
    try:
        x, y = transpose(data)
    except ValueError:
        x, y = data
    scatter(x, y)

    # Plot x=y line
    lo = nmin(ravel(data))
    hi = nmax(ravel(data))
    datarange = hi - lo
    lo -= 0.1 * datarange
    hi += 0.1 * datarange
    pyplot((lo, hi), (lo, hi))

    # Plot options
    xlabel("Observed deviates", fontsize="x-small")
    ylabel("Simulated deviates", fontsize="x-small")

    if report_p:
        # Put p-value in legend
        count = sum(s > o for o, s in zip(x, y))
        text(
            lo + 0.1 * datarange,
            hi - 0.1 * datarange,
            "p=%.3f" % (count / len(x)),
            horizontalalignment="center",
            fontsize=10,
        )

    # Save to file
    if not os.path.exists(path):
        os.mkdir(path)
    if not path.endswith("/"):
        path += "/"
    savefig("%s%s%s.%s" % (path, name, suffix, format))
Example #14
    def csv_output(self):
        
        """This method is used to report the results of
        an update test to a csv file"""

        # determine the file name
        csv_filename = "update-%s-%sstep-%smax-%siter-%s-%s.csv" % (self.updatetype,
                                                                    self.step,
                                                                    self.limit,
                                                                    self.iterations,
                                                                    self.chart_type.lower(),
                                                                    self.testdatetime)

        # initialize the csv file
        csvfile_stream = open(csv_filename, "w")
        csvfile_writer = csv.writer(csvfile_stream, delimiter=',', quoting=csv.QUOTE_MINIMAL)

        # iterate over the SIBs
        for sib in self.results.keys():                    
                         
            # iterate over the possible block lengths
            for triple_length in sorted(self.results[sib].keys(), key=int):
            
                row = [sib]
    
                # add the length of the block to the row
                row.append(triple_length)

                # add all the times
                for value in self.results[sib][triple_length]:
                    row.append(value)

                # add the mean value of the times to the row
                row.append(round(nmean(self.results[sib][triple_length]),3))                
                row.append(round(nmin(self.results[sib][triple_length]),3))                
                row.append(round(nmax(self.results[sib][triple_length]),3))                
                row.append(round(nvar(self.results[sib][triple_length]),3))                

                # write the row
                csvfile_writer.writerow(row)

        # close the csv file
        csvfile_stream.close()
Example #15
File: example.py  Project: AtomAleks/PyProp
def CompareFortran(**args):
    conf = pyprop.Load("config_compare_fortran.ini")
    prop = pyprop.Problem(conf)
    prop.SetupStep()

    init = prop.psi.Copy()

    for t in prop.Advance(5):
        corr = abs(prop.psi.InnerProduct(init))**2
        print("Time = %f, initial state correlation = %f" % (t, corr))

    corr = abs(prop.psi.InnerProduct(init))**2
    t = prop.PropagatedTime
    print("Time = %f, initial state correlation = %f" % (t, corr))

    # Load fortran data and compare
    fdata = pylab.load("fortran_propagation.dat")
    print("Max difference pyprop/fortran: %e" %
          nmax(abs(prop.psi.GetData())**2 - fdata[1:]))

    return prop
Example #16
File: __init__.py  Project: wxgeo/geophar
    def _creer_cmap(self, seuils):
        zmax = nmax(self._Z)
        zmin = nmin(self._Z)
        delta = zmax - zmin
        # Map the values to [0, 1] with an affine transformation
        if delta:
            a = 1/delta
            b = -zmin/delta
        seuils = [0] + [a*z + b for z in seuils if zmin < z < zmax] + [1] # NB: strict <, not <=
        print(seuils)
        cdict = {'red': [], 'green': [], 'blue': []}
        def add_col(val, color1, color2):
            cdict['red'].append((val, color1[0], color2[0]))
            cdict['green'].append((val, color1[1], color2[1]))
            cdict['blue'].append((val, color1[2], color2[2]))

        n = len(self.couleurs)
        for i, seuil in enumerate(seuils):
            add_col(seuil, self.couleurs[(i - 1)%n], self.couleurs[i%n])
        return LinearSegmentedColormap('seuils', cdict, 256)
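As a standalone illustration of the segment-data format being built here (toy colours, not the class's couleurs): each (position, below, above) triple lets a channel jump at a position, and repeating that construction at every threshold is how the seuils become hard colour bands:

# Minimal sketch of matplotlib's cdict format with one hard break at 0.5:
from matplotlib.colors import LinearSegmentedColormap
cdict = {'red':   [(0.0, 0.0, 0.0), (0.5, 0.0, 1.0), (1.0, 1.0, 1.0)],
         'green': [(0.0, 0.2, 0.2), (0.5, 0.2, 0.0), (1.0, 0.0, 0.0)],
         'blue':  [(0.0, 1.0, 1.0), (0.5, 1.0, 0.0), (1.0, 0.0, 0.0)]}
two_bands = LinearSegmentedColormap('two_bands', cdict, 256)
# below 0.5 the map stays blue-ish; at 0.5 it snaps to red: a hard band edge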
Example #17
def genPDF(imSize, p, pctg, distType=2, radius=0, seed=0):
    minval = 0
    maxval = 1
    val = 0.5

    if len(imSize) == 1:
        imSize = [imSize[0], 1]  # promote a 1-element size to (N, 1)
    sx = imSize[0]
    sy = imSize[1]
    PCTG = np.floor(pctg * sx * sy)
    if not np.any(np.asarray(imSize) == 1):
        x, y = np.meshgrid(np.linspace(-1, 1, sy), np.linspace(-1, 1, sx))
        if distType == 1:
            r = np.fmax(nabs(x), nabs(y))
        else:
            r = sqrt(x**2 + y**2)
            r = r / nmax(nabs(r.flatten()))
    else:
        r = nabs(np.linspace(-1, 1, max(sx, sy)))

    idx = np.where(r < radius)
    pdf = (1 - r)**p
    pdf[idx] = 1
    if np.floor(sum(pdf.flatten())) > PCTG:
        raise ValueError('infeasible without undersampling dc, increase p')

    # begin bisection
    while 1:
        val = minval / 2.0 + maxval / 2.0
        pdf = (1 - r)**p + val
        pdf[pdf > 1] = 1
        pdf[idx] = 1
        N = np.floor(sum(pdf.flatten()))
        if N > PCTG:
            maxval = val
        if N < PCTG:
            minval = val
        if N == PCTG:
            break
    return pdf
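genPDF pairs with the genSampling routine shown earlier; a hedged end-to-end sketch, assuming both come from the usual Python port of Lustig's sparse-MRI demo and that zeros, rand and ifft2 (numpy / numpy.random / numpy.fft) are in scope alongside the alias preamble:

# Hypothetical usage: build a variable-density pdf, then draw a mask from it.
pdf = genPDF([64, 64], p=3, pctg=0.25)           # target sampling density
mask, stat, actpctg = genSampling(pdf, nitn=10, tol=1)
print("achieved sampling fraction: %.3f" % actpctg)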
Example #18
File: Matplot.py  Project: along1x/pymc
def discrepancy_plot(data, name, report_p=True, format='png', suffix='-gof', path='./', fontmap = {1:10, 2:8, 3:6, 4:5, 5:4}, verbose=1):
    # Generate goodness-of-fit deviate scatter plot
    if verbose>0:
        print('Plotting', name + suffix)

    # Generate new scatter plot
    figure()
    try:
        x, y = transpose(data)
    except ValueError:
        x, y = data
    scatter(x, y)

    # Plot x=y line
    lo = nmin(ravel(data))
    hi = nmax(ravel(data))
    datarange = hi-lo
    lo -= 0.1*datarange
    hi += 0.1*datarange
    pyplot((lo, hi), (lo, hi))

    # Plot options
    xlabel('Observed deviates', fontsize='x-small')
    ylabel('Simulated deviates', fontsize='x-small')

    if report_p:
        # Put p-value in legend
        count = sum(s>o for o,s in zip(x,y))
        text(lo+0.1*datarange, hi-0.1*datarange,
             'p=%.3f' % (count/len(x)), horizontalalignment='center',
             fontsize=10)

    # Save to file
    if not os.path.exists(path):
        os.mkdir(path)
    if not path.endswith('/'):
        path += '/'
    savefig("%s%s%s.%s" % (path, name, suffix, format))
Example #19
def powerspectrum(*args,**kw):
  """Calling Sequence:
    alpha, beta, power, [freq] = powerspectrum(time, data, [freq,] weights=None, timeit=True, ofac=4)

  Input:
    time       : Time array
    data       : Data array
    freq       : Frequency array with frequencies to be evaluated. If not present one will be generated

  Output:
    alpha      : alpha coefficient (see spectrum_core)
    beta       : beta coefficient (see spectrum_core)
    power      : Power Spectrum
    freq       : Generated frequency array (if not present in input)

  Keywords:
    weights    : Array with weights to be used. If this array is not present no weighting will be used
    timeit     : If true prints timing information
    ofac       : Oversampling parameter

  Description:
    Main Routine for evaluation of power-spectra
  """

  # ------ Starting time
  t0 = systemtime()
  
  # ------- Handle keywords
  weights = kw.get('weights',None)
  timeit  = kw.get('timeit',True)
  ofac    = kw.get('ofac',4.)

  # ------ Handle arguments
  if len(args) == 2:
    time = args[0]
    data = args[1]
  elif len(args) == 3:
    time = args[0]
    data = args[1]
    freq = args[2]
  else:
    raise InputError('Wrong number of inputs')

  time = array(time,dtype=float64).squeeze()
  data = array(data,dtype=float64).squeeze()

  # ------ Subtract zero frequency
  ddata = data - mean(data)

  # ------ Handle frequency array
  if len(args) == 2:
    dt    = median(diff(time))
    nyq   = 1./(2*dt)
    tdiff = nmax(time) - nmin(time)
    freq  = arange(0,nyq,1./(ofac*tdiff),dtype=float64)
  else:
    freq = array(freq,dtype=float64).squeeze()

  # ------ Handle weights
  if weights is not None:
    weights = array(weights,dtype=float64).squeeze()

  # ------ Calculate Power Spectrum
  alpha, beta, power = chunkeval(time,ddata,freq,weights=weights)

  # ------ Ending time
  if timeit:
    print('spectrum.py: powerspectrum finished in %3.1f seconds' % (systemtime() - t0))
  
  # ------ Return
  if len(args) == 2:
    return alpha, beta, power, freq
  else:
    return alpha, beta, power
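A possible call, assuming the module-level imports of the original spectrum.py (array, median, diff, arange, mean, and the chunkeval core) are present:

# Hypothetical usage on irregularly sampled data:
import numpy as np
t = np.sort(np.random.uniform(0.0, 30.0, 400))           # irregular sampling
d = np.sin(2 * np.pi * 1.7 * t) + 0.1 * np.random.randn(t.size)
alpha, beta, power, freq = powerspectrum(t, d, ofac=4.0)
print("peak at %.2f cycles per unit time" % freq[np.argmax(power)])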
Example #20
def PropagateWavePacket(**args):
    #Set up problem
    prop = SetupProblem(**args)
    conf = prop.Config

    #Setup traveling wavepacket initial state
    f = lambda x: conf.Wavepacket.function(conf.Wavepacket, x)
    bspl = prop.psi.GetRepresentation().GetRepresentation(0).GetBSplineObject()
    c = bspl.ExpandFunctionInBSplines(f)
    prop.psi.GetData()[:] = c
    prop.psi.Normalize()
    initialPsi = prop.psi.Copy()

    #Get x-grid
    subProp = prop.Propagator.SubPropagators[0]
    subProp.InverseTransform()
    grid = prop.psi.GetRepresentation().GetLocalGrid(0)
    subProp.ForwardTransform()

    #Setup equispaced x grid
    x_min = conf.BSplineRepresentation.xmin
    x_max = conf.BSplineRepresentation.xmax
    grid_eq = linspace(x_min, x_max, grid.size)
    x_spacing = grid_eq[1] - grid_eq[0]

    #Set up fft grid
    k_spacing = 1.0 / grid_eq.size
    k_max = pi / x_spacing
    k_min = -k_max
    k_spacing = (k_max - k_min) / grid_eq.size
    grid_fft = zeros(grid_eq.size, dtype=double)
    grid_fft[:grid_eq.size // 2 + 1] = r_[0.0:k_max:k_spacing]
    grid_fft[grid_eq.size // 2:] = r_[k_min:0.0:k_spacing]
    print("Momentum space resolution = %f a.u." % k_spacing)

    k0 = conf.Wavepacket.k0
    k0_trunk = 5 * k0
    trunkIdx = list(nwhere(abs(grid_fft) <= k0_trunk)[0])

    rcParams['interactive'] = True
    figure()
    p1 = subplot(211)
    p2 = subplot(212)
    p1.hold(False)

    psi_eq = zeros((grid.size), dtype=complex)

    for t in prop.Advance(40):
        print "t = %f, norm = %.15f, P = %.15f " % \
         ( t, prop.psi.GetNorm(), abs(prop.psi.InnerProduct(initialPsi))**2 )
        sys.stdout.flush()

        subProp.InverseTransform()
        p1.plot(grid, abs(prop.psi.GetData())**2)
        subProp.ForwardTransform()
        bspl.ConstructFunctionFromBSplineExpansion(prop.psi.GetData(), grid_eq,
                                                   psi_eq)
        psi_fft = (abs(fft.fft(psi_eq))**2)
        psi_fft_max = nmax(psi_fft)
        psi_fft /= psi_fft_max

        #Plot momentum space |psi|**2
        p2.hold(False)
        p2.semilogy(grid_fft[trunkIdx], psi_fft[trunkIdx] + 1e-21)
        p2.hold(True)
        p2.semilogy([0, 0], [1e-20, psi_fft_max], 'r-')
        p2.semilogy([-k0, -k0], [1e-20, psi_fft_max], 'g--')
        p2.semilogy([k0, k0], [1e-20, psi_fft_max], 'g--')

        #Set subplot axis
        p1.axis([grid[0], grid[-1], 0, 0.1])
        p2.axis([-k0_trunk, k0_trunk, 1e-20, psi_fft_max])
        show()

    hold(True)

    return prop
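As an aside, the hand-built grid_fft above matches (up to endpoint conventions) what NumPy computes directly; a sketch assuming the same grid_eq and x_spacing:

# fftfreq returns cycles per unit; multiplying by 2*pi gives angular k.
import numpy as np
grid_fft_alt = 2 * np.pi * np.fft.fftfreq(grid_eq.size, d=x_spacing)
# same ordering as above: [0 ... k_max) first, then [k_min ... 0)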
Example #21
    def plot( self, ax ):

        exec_time_arr = self.exec_time_arr
        n_int_arr = self.n_int_arr[0, :]
        real_memsize_arr = self.real_memsize_arr[0, :]

        rand_arr = arange( len( self.rand_list ) ) + 1
        width = 0.45

        if exec_time_arr.shape[0] == 1:
            shift = width / 2.0
            ax.bar( rand_arr - shift, exec_time_arr[0, :], width, color = 'lightgrey' )

        elif self.exec_time_arr.shape[0] == 2:
            max_exec_time = nmax( exec_time_arr )

            ax.set_ylabel( '$\mathrm{execution \, time \, [sec]}$', size = 20 )
            ax.set_xlabel( '$n_{\mathrm{rnd}}  \;-\; \mathrm{number \, of \, random \, parameters}$', size = 20 )

            ax.bar( rand_arr - width, exec_time_arr[0, :], width,
                    hatch = '/', color = 'white', label = 'C' ) # , color = 'lightgrey' )
            ax.bar( rand_arr, exec_time_arr[1, :], width,
                    color = 'lightgrey', label = 'numpy' )

            yscale = 1.25
            ax_xlim = rand_arr[-1] + 1
            ax_ylim = max_exec_time * yscale

            ax.set_xlim( 0, ax_xlim )
            ax.set_ylim( 0, ax_ylim )

            ax2 = ax.twinx()
            ydata = exec_time_arr[1, :] / exec_time_arr[0, :]
            ax2.plot( rand_arr, ydata, '-o', color = 'black',
                      linewidth = 1, label = 'numpy/C' )

            ax2.plot( [rand_arr[0] - 1, rand_arr[-1] + 1], [1, 1], '-' )
            ax2.set_ylabel( '$\mathrm{time}(  \mathsf{numpy}  ) / \mathrm{ time }(\mathsf{C}) \; [-]$', size = 20 )
            ax2_ylim = nmax( ydata ) * yscale
            ax2_xlim = rand_arr[-1] + 1
            ax2.set_ylim( 0, ax2_ylim )
            ax2.set_xlim( 0, ax2_xlim )

            ax.set_xticks( rand_arr )
            ax.set_xticklabels( rand_arr, size = 14 )
            xticks = [ '%.2g' % n_int for n_int in n_int_arr ]
            ax3 = ax.twiny()
            ax3.set_xlim( 0, rand_arr[-1] + 1 )
            ax3.set_xticks( rand_arr )
            ax3.set_xlabel( '$n_{\mathrm{int}}$', size = 20 )
            ax3.set_xticklabels( xticks, rotation = 30 )

            # set the tick label size of the lower X axis
            X_lower_tick = 14
            xt = ax.get_xticklabels()
            for t in xt:
                t.set_fontsize( X_lower_tick )

            # set the tick label size of the upper X axis
            X_upper_tick = 12
            xt = ax3.get_xticklabels()
            for t in xt:
                t.set_fontsize( X_upper_tick )

            # set the tick label size of the Y axes
            Y_tick = 14
            yt = ax2.get_yticklabels() + ax.get_yticklabels()
            for t in yt:
                t.set_fontsize( Y_tick )

            # set the legend position and font size
            leg_fontsize = 16
            leg = ax.legend( loc = ( 0.02, 0.83 ) )
            for t in leg.get_texts():
                t.set_fontsize( leg_fontsize )
            leg = ax2.legend( loc = ( 0.705, 0.90 ) )
            for t in leg.get_texts():
                t.set_fontsize( leg_fontsize )
Example #22
def windowfunction(*args,**kw):
  """Calling Sequence:
    power, [freq] = windowfunction(time, [freq,] weights=None, timeit=True, ofac=4, wcf=None)

  Input:
    time       : Time array
    freq       : Frequency array with frequencies to be evaluated. If not present one will be generated

  Output:
    power      : Power Spectrum
    freq       : Generated frequency array (if not present in input)

  Keywords:
    weights    : Array with weights to be used. If this array is not present no weighting will be used
    timeit     : If true prints timing information
    ofac       : Oversampling parameter
    wcf        : If given, this will be the frequency for the pseudo-window. If not given, a central frequency is used

  Description:
    The window function for a Lomb periodogram is not strictly defined. This function calculates a pseudo-window.
  """
  # ------ Starting time
  t0 = systemtime()

  # ------- Handle keywords
  weights = kw.get('weights',None)
  timeit  = kw.get('timeit',True)
  wcf     = kw.get('wcf',None) #Window center frequency
  ofac    = kw.get('ofac',4.)
  
  # ------ Handle arguments
  if len(args) == 1:
    time = args[0]
  elif len(args) == 2:
    time = args[0]
    freq = args[1]
  else:
    raise InputError('Wrong number of inputs')

  time = array(time,dtype=float64).squeeze()

  # ------ Handle frequency array
  if len(args) == 1:
    dt    = median(diff(time))
    nyq   = 1./(2*dt)
    tdiff = nmax(time) - nmin(time)
    freq  = arange(0.01*nyq,0.99*nyq,1./(ofac*tdiff),dtype=float64)
  else:
    freq = array(freq,dtype=float64).squeeze()
  
  # ------ Handle weights
  if weights is not None:
    weights = array(weights,dtype=float64).squeeze()

  # ------ Estimate Window Center Frequency
  if wcf is None:
    wcf = median(freq)

  # ------ Make data
  arg  = 2*pi*wcf*time
  #data = 0.5 * (np.sin(arg) + np.cos(arg))
  
  # ------ Calculate Power Spectrum (Only power output)
  power = 0.5*(chunkeval(time,sin(arg),freq,weights=weights)[-1] + chunkeval(time,cos(arg),freq,weights=weights)[-1])

  # ------ Ending time
  if timeit:
    print('spectrum.py: windowfunction finished in %3.1f seconds' % (systemtime() - t0))

  # ------ Return
  if len(args) == 1:
    return power, freq
  else:
    return power

Example #23
File: stests.py  Project: khharut/pyhistavg
def TestMV(itseries, ind, lobs):
    """
    Performs mean and varaince change statistical tests. Variance change test
    based on F-test(F-distribution), and mean test based on t-test (Student's
    distribution)

    Parameters
    ----------
    itseries: 1D array
        one dimensional array of time series

    ind: integer
        index in time series around which ARIMA(1,0,0) process performed

    lobs: integer
        maximal number of elements to left and right sides around ind index on
        which ARIMA(1,0,0) is performed

    Returns
    -------
    cpvalue: numeric
        0.1, 0.05 or 0, 0.1 if ind around ind mean and variance change
        significanly, 0.05 if only mean or variance changes significantly
        0 in case when none of them changes
    """
    cpvalue = 0
    tstart = max(ind - lobs + 1, 0)  # start of test interval around ind index
    tend = min(ind + lobs + 1,
               len(itseries))  # end of test interval around ind index
    if (len(itseries[tstart:(ind + 1)]) > 1) and (len(itseries[(ind + 1):tend])
                                                  > 1):
        # both the left and the right side around the ind index must contain
        # more than one element
        if (var(itseries[tstart:(ind + 1)], ddof=1)
                == 0) and (var(itseries[(ind + 1):tend], ddof=1) != 0):
            cpvalue = 0.1
        # if the variance is zero on the left side and non-zero on the right
        # side, then the variance definitely changes around the ind index
        if (var(itseries[tstart:(ind + 1)], ddof=1) != 0) and (var(
                itseries[(ind + 1):tend], ddof=1) == 0):
            cpvalue = 0.1
        if (var(itseries[tstart:(ind + 1)], ddof=1) *
                var(itseries[(ind + 1):tend], ddof=1)) > 0:
            intseries = array(itseries[tstart:tend])  # slicing test data
            n = len(intseries)
            mid_ind = (n // 2) - 1  # ind element position in sliced data
            all_means = emean(intseries, min_periods=1)
            all_vars = evar(intseries, min_periods=1)
            rev_all_means = emean(intseries[::-1], min_periods=1)
            rev_all_vars = evar(intseries[::-1], min_periods=1)
            test_lens = arange((mid_ind + 1), (n + 1))
            if (rev_all_vars[mid_ind] > 0) and (all_vars[mid_ind] > 0):
                z = all_vars[mid_ind] / rev_all_vars[mid_ind]
                rz = 1 / z
            else:
                z = inf
                rz = 0.0
            ## variance change F-test with reliability value 99.8% (0.1%-99.9%)
            if (z > f.ppf(1 - 0.001, mid_ind, mid_ind)) or (z < f.ppf(
                    0.001, mid_ind, mid_ind)):
                cpvalue = 0.05
            if (rz > f.ppf(1 - 0.001, mid_ind, mid_ind)) or (rz < f.ppf(
                    0.001, mid_ind, mid_ind)):
                cpvalue = 0.05
            ## calculation of t-test statistics
            Sx_y = sqrt(
                ((mid_ind * all_vars[mid_ind] + test_lens * all_vars[mid_ind:])
                 * (mid_ind + test_lens)) /
                ((mid_ind + test_lens - 2) * mid_ind * test_lens))
            t_jn = nabs((all_means[mid_ind] - all_means[mid_ind:]) / Sx_y)

            rSx_y = sqrt(((mid_ind * rev_all_vars[mid_ind] +
                           test_lens * rev_all_vars[mid_ind:]) *
                          (mid_ind + test_lens)) /
                         ((mid_ind + test_lens - 2) * mid_ind * test_lens))
            rt_jn = nabs(
                (rev_all_means[mid_ind] - rev_all_means[mid_ind:]) / rSx_y)

            t_stat = nmax((t_jn, rt_jn))
            dfree = n - 2
            # mean change t-test with reliability value 99.8% (0.1%-99.9%)
            if t_stat > t.ppf(1 - 0.001, dfree):
                cpvalue = cpvalue + 0.05
        if cpvalue > 0:
            # if cpvalue is positive, check whether the detected changepoint is
            # significant by calculating the sindic value for the interval
            sindic = abs(
                std(itseries[tstart:(ind - 1)], ddof=1) -
                std(itseries[(ind + 1):tend], ddof=1))
            sindic = sindic * mean(itseries[tstart:tend])
            if sindic is not None:
                if sindic <= 0.03:
                    cpvalue = 0  # if sindic is less than 0.03 then changepoint is not significant
    return cpvalue
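The Sx_y expression is, as the variable names suggest, the pooled standard error of a two-sample t-test between the n1 = mid_ind left-hand points and each growing right-hand window of n2 = test_lens points:

t = \frac{\lvert \bar{x}_1 - \bar{x}_2 \rvert}{\sqrt{\dfrac{\left(n_1 s_1^2 + n_2 s_2^2\right)\left(n_1 + n_2\right)}{\left(n_1 + n_2 - 2\right)\, n_1 n_2}}}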
Example #24
 def max(self):
     return nmax(self.prev)
Example #25
def plot(
    data, name, format='png', suffix='', path='./', common_scale=True, datarange=(None, None),
        new=True, last=True, rows=1, num=1, fontmap=None, verbose=1):
    """
    Generates summary plots for nodes of a given PyMC object.

    :Arguments:
        data: PyMC object, trace or array
            A trace from an MCMC sample or a PyMC object with one or more traces.

        name: string
            The name of the object.

        format (optional): string
            Graphic output format (defaults to png).

        suffix (optional): string
            Filename suffix.

        path (optional): string
            Specifies location for saving plots (defaults to local directory).

        common_scale (optional): bool
            Specifies whether plots of multivariate nodes should be on the same scale
            (defaults to True).

    """
    if fontmap is None:
        fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}

    # If there is only one data array, go ahead and plot it ...
    if ndim(data) == 1:

        if verbose > 0:
            print_('Plotting', name)

        # If new plot, generate new frame
        if new:

            figure(figsize=(10, 6))

        # Call trace
        trace(
            data,
            name,
            datarange=datarange,
            rows=rows * 2,
            columns=2,
            num=num + 3 * (num - 1),
            last=last,
            fontmap=fontmap)
        # Call autocorrelation
        autocorrelation(
            data,
            name,
            rows=rows * 2,
            columns=2,
            num=num + 3 * (
                num - 1) + 2,
            last=last,
            fontmap=fontmap)
        # Call histogram
        histogram(
            data,
            name,
            datarange=datarange,
            rows=rows,
            columns=2,
            num=num * 2,
            last=last,
            fontmap=fontmap)

        if last:
            if not os.path.exists(path):
                os.mkdir(path)
            if not path.endswith('/'):
                path += '/'
            savefig("%s%s%s.%s" % (path, name, suffix, format))

    else:
        # ... otherwise plot recursively
        tdata = swapaxes(data, 0, 1)

        datarange = (None, None)
        # Determine common range for plots
        if common_scale:
            datarange = (nmin(tdata), nmax(tdata))

        # How many rows?
        _rows = min(4, len(tdata))

        for i in range(len(tdata)):

            # New plot or adding to existing?
            _new = not i % _rows
            # Current subplot number
            _num = i % _rows + 1
            # Final subplot of current figure?
            _last = (_num == _rows) or (i == len(tdata) - 1)

            plot(
                tdata[i],
                name + '_' + str(i),
                format=format,
                path=path,
                common_scale=common_scale,
                datarange=datarange,
                suffix=suffix,
                new=_new,
                last=_last,
                rows=_rows,
                num=_num)
Example #26
def plot_from_dill(suffix=''):
    unmitigated = gr_runs[0][0]
    for ig, models in enumerate(gr_runs):
        axes = axis[ig, 0]
        for ic in range(len(models) - 1, 0, -1):
            m = models[ic]
            m.plot_cases(
                axes,
                label='{:.0f}% compliance'.format(
                    100 * m.setup['npi']['compliance']),
                colour=[0.5, 0.5, 0.5 + (0.5 * m.setup['npi']['compliance'])])
            percred[ig, ic] = 100.0 * m.peak_ratio(unmitigated)
        baseline = models[0]
        baseline.plot_cases(axes, label='Baseline', colour=[0, 0, 0])
        yup = 1.05 * unmitigated.peak_value
        # Draw intervention lines
        npi_start, npi_end = baseline.setup['npi']['start'], baseline.setup[
            'npi']['end']
        axes.plot([npi_start, npi_start], [0, yup], ls='--', c='k')
        axes.plot([npi_end, npi_end], [0, yup], ls='--', c='k')

        axes.set_xlim([0, baseline.setup['final_time']])
        axes.set_ylim([0, yup])
        axes.set_xlabel('Time (days)')
        axes.set_ylabel('Number of Cases')
        axes.title.set_text('Global Reduction {:.0f}%'.format(
            100 * baseline.setup['npi']['global_reduction']))
        axes = axis[ig, 1]
        for ic in range(len(models) - 1, 0, -1):
            m = models[ic]
            m.plot_person_days_of_isolation(
                axes,
                label='{:.0f}%'.format(100 * m.setup['npi']['compliance']),
                colour=[0.3, 0.3, 0.3 + (0.7 * m.setup['npi']['compliance'])])
            persdays[ig, ic] = m.persdays
        baseline.plot_person_days_of_isolation(axes,
                                               label='Baseline',
                                               colour=[0, 0, 0])
        axes.set_xlim([0, baseline.setup['final_time']])

        axes.set_ylim([0, 1.05 * models[-1].max_person_days_of_isolation])
        if (ig == 0):
            axes.legend(bbox_to_anchor=(1.04, 1),
                        loc='upper left',
                        title='Compliance')
        axes.set_xlabel('Time (days)')
        axes.set_ylabel('Person-Days Isolation')
        axes.title.set_text('Global Reduction {:.0f}%'.format(
            100 * baseline.global_reduction))
    fig.tight_layout()
    fig.savefig('./time_series{0}.pdf'.format(suffix))

    fig, axis = subplots(1, 2, figsize=(8, 3.75))
    comply_range = [m.setup['npi']['compliance'] for m in gr_runs[0]]
    for ig, models in enumerate(gr_runs):
        axis[0].plot(comply_range,
                     percred[ig, :],
                     label='{:.0f}%'.format(100 * models[0].global_reduction))
    axis[0].set_xlabel('Compliance with Isolation')
    axis[0].set_ylabel('Percentage of Baseline peak')
    axis[0].title.set_text('Individual isolation: Mitigation')
    axis[0].set_xlim([0, 1])
    axis[0].set_ylim([0, 100])
    axis[0].legend(title="Global Reduction")
    for ig, models in enumerate(gr_runs):
        axis[1].plot(comply_range, persdays[ig, :])
    axis[1].set_xlabel('Compliance with Isolation')
    axis[1].set_ylabel('Number of Person-Days Isolation')
    axis[1].title.set_text('Individual isolation: Cost')
    axis[1].set_xlim([0, 1])
    axis[1].set_ylim([0, ceil(nmax(persdays))])
    fig.tight_layout()
    fig.savefig('./mit_costs{0}.pdf'.format(suffix))

    fig, axis = subplots(2, 4, figsize=(10, 4))
    for n in range(1, unmitigated.setup.nmax + 1):
        axes = axis[(n - 1) // 4, (n - 1) % 4]
        for models in gr_runs:
            axes.plot(comply_range, [m.prav[n - 1] for m in models],
                      label='{:.0f}%'.format(100 * models[0].global_reduction))
            axes.set_xlim([0, 1.0])
            axes.set_ylim([0, 0.5])
            axes.set_xlabel('Compliance')
            axes.set_ylabel('Prob(Avoid)')
            axes.title.set_text('Household size {0:d}'.format(int(n)))
        if n == unmitigated.setup.nmax:  # add the legend once, on the last panel
            fig.legend(bbox_to_anchor=(1.04, 1),
                       loc="upper left",
                       title="Global Reduction")
    fig.tight_layout()
    fig.savefig('prob_avoid{0}.pdf'.format(suffix))
Example #27
 def max_person_days_of_isolation(self):
     return (self.spec['Npop'] / self.nbar) * nmax(self.pdi)
Example #28
 def peak_value(self):
     return (self.spec['Npop'] / self.nbar) * nmax(self.prev)
Example #29
 def peak_ratio(self, other):
     return nmax(self.prev) / nmax(other.prev)
Example #30
File: Matplot.py  Project: along1x/pymc
def summary_plot(pymc_obj, name='model', format='png',  suffix='-summary', path='./', alpha=0.05, quartiles=True, rhat=True, main=None, chain_spacing=0.05, vline_pos=0):
    """
    Model summary plot
    
    :Arguments:
        pymc_obj: PyMC object, trace or array
            A trace from an MCMC sample or a PyMC object with one or more traces.

        name (optional): string
            The name of the object.

        format (optional): string
            Graphic output format (defaults to png).

        suffix (optional): string
            Filename suffix.

        path (optional): string
            Specifies location for saving plots (defaults to local directory).
            
        alpha (optional): float
            Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
            
        rhat (optional): bool
            Flag for plotting Gelman-Rubin statistics. Requires 2 or more 
            chains (defaults to True).
            
        main (optional): string
            Title for main plot. Passing False results in titles being
            suppressed; passing None (default) results in default titles.
            
        chain_spacing (optional): float
            Plot spacing between chains (defaults to 0.05).
            
        vline_pos (optional): numeric
            Location of vertical reference line (defaults to 0).
    
    """
    
    if not gridspec:
        print('\nYour installation of matplotlib is not recent enough to support summary_plot; this function is disabled until matplotlib is updated.')
        return
    
    # Quantiles to be calculated
    quantiles = [100*alpha/2, 50, 100*(1-alpha/2)]
    if quartiles:
        quantiles = [100*alpha/2, 25, 50, 75, 100*(1-alpha/2)]

    # Range for x-axis
    plotrange = None
    
    # Number of chains
    chains = None
    
    # Gridspec
    gs = None
    
    # Subplots
    interval_plot = None
    rhat_plot = None
    
    try:
        # First try Model type
        vars = pymc_obj._variables_to_tally
        
    except AttributeError:
        
        # Assume an iterable
        vars = pymc_obj
    
    # Empty list for y-axis labels
    labels = []
    # Counter for current variable
    var = 1
    
    # Make sure there is something to print
    if all([v._plot==False for v in vars]):
        print('No variables to plot')
        return
    
    for variable in vars:

        # If plot flag is off, do not print
        if variable._plot==False:
            continue
            
        # Extract name
        varname = variable.__name__

        # Retrieve trace(s)
        i = 0
        traces = []
        while True:
           try:
               #traces.append(pymc_obj.trace(varname, chain=i)[:])
               traces.append(variable.trace(chain=i))
               i+=1
           except KeyError:
               break
               
        chains = len(traces)
        
        if gs is None:
            # Initialize plot
            if rhat and chains>1:
                gs = gridspec.GridSpec(1, 2, width_ratios=[3,1])

            else:
                
                gs = gridspec.GridSpec(1, 1)
                
            # Subplot for confidence intervals
            interval_plot = subplot(gs[0])
                
        # Get quantiles
        data = [calc_quantiles(d, quantiles) for d in traces]
        data = [[d[q] for q in quantiles] for d in data]
        
        # Ensure x-axis contains range of current interval
        if plotrange:
            plotrange = [min(plotrange[0], nmin(data)), max(plotrange[1], nmax(data))]
        else:
            plotrange = [nmin(data), nmax(data)]
        
        try:
            # First try missing-value stochastic
            value = variable.get_stoch_value()
        except AttributeError:
            # All other variable types
            value = variable.value

        # Number of elements in current variable
        k = size(value)
        
        # Append variable name(s) to list
        if k>1:
            names = var_str(varname, shape(value))
            labels += names
        else:
            labels.append('\n'.join(varname.split('_')))
            
        # Add spacing for each chain, if more than one
        e = [0] + [(chain_spacing * ((i + 2) // 2)) * (-1) ** i for i in range(chains - 1)]
        
        # Loop over chains
        for j,quants in enumerate(data):
            
            # Deal with multivariate nodes
            if k>1:

                for i,q in enumerate(transpose(quants)):
                    
                    # Y coordinate with jitter
                    y = -(var+i) + e[j]
                    
                    if quartiles:
                        # Plot median
                        pyplot(q[2], y, 'bo', markersize=4)
                        # Plot quartile interval
                        errorbar(x=(q[1],q[3]), y=(y,y), linewidth=2, color="blue")
                        
                    else:
                        # Plot median
                        pyplot(q[1], y, 'bo', markersize=4)

                    # Plot outer interval
                    errorbar(x=(q[0],q[-1]), y=(y,y), linewidth=1, color="blue")

            else:
                
                # Y coordinate with jitter
                y = -var + e[j]
                
                if quartiles:
                    # Plot median
                    pyplot(quants[2], y, 'bo', markersize=4)
                    # Plot quartile interval
                    errorbar(x=(quants[1],quants[3]), y=(y,y), linewidth=2, color="blue")
                else:
                    # Plot median
                    pyplot(quants[1], y, 'bo', markersize=4)
                
                # Plot outer interval
                errorbar(x=(quants[0],quants[-1]), y=(y,y), linewidth=1, color="blue")
            
        # Increment index
        var += k
        
    # Define range of y-axis
    ylim(-var+0.5, -0.5)
    
    datarange = plotrange[1] - plotrange[0]
    xlim(plotrange[0] - 0.05*datarange, plotrange[1] + 0.05*datarange)
    
    # Add variable labels
    ylabels = yticks([-(l+1) for l in range(len(labels))], labels)        
            
    # Add title
    if main is not False:
        plot_title = main or str(int((1-alpha)*100)) + "% Credible Intervals"
        title(plot_title)
    
    # Remove ticklines on y-axes
    for ticks in interval_plot.yaxis.get_major_ticks():
        ticks.tick1On = False
        ticks.tick2On = False
    
    for loc, spine in interval_plot.spines.items():
        if loc in ['bottom','top']:
            pass
            #spine.set_position(('outward',10)) # outward by 10 points
        elif loc in ['left','right']:
            spine.set_color('none') # don't draw spine
      
    # Reference line
    axvline(vline_pos, color='k', linestyle='--')  
        
    # Generate Gelman-Rubin plot
    if rhat and chains>1:

        from diagnostics import gelman_rubin
        
        # If there are multiple chains, calculate R-hat
        rhat_plot = subplot(gs[1])
        
        if main is not False:
            title("R-hat")
        
        # Set x range
        xlim(0.9,2.1)
        
        # X axis labels
        xticks((1.0,1.5,2.0), ("1", "1.5", "2+"))
        yticks([-(l+1) for l in range(len(labels))], "")
        
        # Calculate diagnostic
        try:
            R = gelman_rubin(pymc_obj)
        except ValueError:
            R = {}
            for variable in vars:
                R[variable.__name__] = gelman_rubin(variable)
        
        i = 1
        for variable in vars:
            
            if variable._plot==False:
                continue
            
            # Extract name
            varname = variable.__name__
            
            try:
                value = variable.get_stoch_value()
            except AttributeError:
                value = variable.value
                
            k = size(value)
            
            if k>1:
                pyplot([min(r, 2) for r in R[varname]], [-(j+i) for j in range(k)], 'bo', markersize=4)
            else:
                pyplot(min(R[varname], 2), -i, 'bo', markersize=4)
    
            i += k
            
        # Define range of y-axis
        ylim(-i+0.5, -0.5)
        
        # Remove ticklines on y-axes
        for ticks in rhat_plot.yaxis.get_major_ticks():
            ticks.tick1On = False
            ticks.tick2On = False
        
        for loc, spine in rhat_plot.spines.iteritems():
            if loc in ['bottom','top']:
                pass
                #spine.set_position(('outward',10)) # outward by 10 points
            elif loc in ['left','right']:
                spine.set_color('none') # don't draw spine
        
    savefig("%s%s%s.%s" % (path, name, suffix, format))                
Example #37
0
def summary_plot(
    pymc_obj, name='model', format='png', suffix='-summary', path='./',
    alpha=0.05, chain=None, quartiles=True, hpd=True, rhat=True, main=None,
    xlab=None, x_range=None, custom_labels=None, chain_spacing=0.05, vline_pos=0):
    """
    Model summary plot

    Generates a "forest plot" of 100*(1-alpha)% credible intervals for either the
    set of nodes in a given model, or a specified set of nodes.

    :Arguments:
        pymc_obj: PyMC object, trace or array
            A trace from an MCMC sample or a PyMC object with one or more traces.

        name (optional): string
            The name of the object.

        format (optional): string
            Graphic output format (defaults to png).

        suffix (optional): string
            Filename suffix.

        path (optional): string
            Specifies location for saving plots (defaults to local directory).

        alpha (optional): float
            Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
            
        chain (optional): int
            Where there are multiple chains, specify a particular chain to plot.
            If not specified (chain=None), all chains are plotted.

        quartiles (optional): bool
            Flag for plotting the interquartile range, in addition to the
            (1-alpha)*100% intervals (defaults to True).

        hpd (optional): bool
            Flag for plotting the highest probability density (HPD) interval
            instead of the central (1-alpha)*100% interval (defaults to True).

        rhat (optional): bool
            Flag for plotting Gelman-Rubin statistics. Requires 2 or more
            chains (defaults to True).

        main (optional): string
            Title for main plot. Passing False suppresses the titles;
            passing None (the default) results in default titles.

        xlab (optional): string
            Label for x-axis (defaults to no label).

        x_range (optional): list or tuple
            Range for x-axis. Defaults to matplotlib's best guess.

        custom_labels (optional): list
            User-defined labels for each node. If not provided, the node
            __name__ attributes are used.

        chain_spacing (optional): float
            Plot spacing between chains (defaults to 0.05).

        vline_pos (optional): numeric
            Location of vertical reference line (defaults to 0).

    """

    if not gridspec:
        print_(
            '\nYour installation of matplotlib is not recent enough to support summary_plot; this function is disabled until matplotlib is updated.')
        return

    # Quantiles to be calculated
    quantiles = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]
    if quartiles:
        quantiles = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]

    # Range for x-axis
    plotrange = None

    # Gridspec
    gs = None

    # Subplots
    interval_plot = None
    rhat_plot = None

    try:
        # First try Model type
        vars = pymc_obj._variables_to_tally

    except AttributeError:

        try:

            # Try a database object
            vars = pymc_obj._traces

        except AttributeError:

            if isinstance(pymc_obj, Variable):
                vars = [pymc_obj]
            else:
                # Assume an iterable
                vars = pymc_obj

    from .diagnostics import gelman_rubin

    # Calculate G-R diagnostics
    if rhat:
        try:
            R = {}
            for variable in vars:
                R[variable.__name__] = gelman_rubin(variable)
        except (ValueError, TypeError):
            print(
                'Could not calculate Gelman-Rubin statistics. Requires multiple chains of equal length.')
            rhat = False

    # Empty list for y-axis labels
    labels = []
    # Counter for current variable
    var = 1

    # Make sure there is something to print
    if all([v._plot == False for v in vars]):
        print_('No variables to plot')
        return

    for variable in vars:

        # If plot flag is off, do not print
        if variable._plot == False:
            continue

        # Extract name
        varname = variable.__name__

        # Retrieve trace(s)
        if chain is not None:
            chains = 1
            traces = [variable.trace(chain=chain)]
        else:
            chains = variable.trace.db.chains
            traces = [variable.trace(chain=i) for i in range(chains)]

        if gs is None:
            # Initialize plot
            if rhat and chains > 1:
                gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])

            else:

                gs = gridspec.GridSpec(1, 1)

            # Subplot for confidence intervals
            interval_plot = subplot(gs[0])

        # Get quantiles
        data = [calc_quantiles(d, quantiles) for d in traces]
        if hpd:
            # Substitute HPD interval
            for i, d in enumerate(traces):
                hpd_interval = calc_hpd(d, alpha)
                data[i][quantiles[0]] = hpd_interval[0]
                data[i][quantiles[-1]] = hpd_interval[1]

        data = [[d[q] for q in quantiles] for d in data]
        # Ensure x-axis contains range of current interval
        if plotrange:
            plotrange = [min(plotrange[0], nmin(data)),
                         max(plotrange[1], nmax(data))]
        else:
            plotrange = [nmin(data), nmax(data)]

        try:
            # First try missing-value stochastic
            value = variable.get_stoch_value()
        except AttributeError:
            # All other variable types
            value = variable.value

        # Number of elements in current variable
        k = size(value)
        
        # Append variable name(s) to list
        if k > 1:
            names = var_str(varname, shape(value)[int(shape(value)[0]==1):])
            labels += names
        else:
            labels.append(varname)
            # labels.append('\n'.join(varname.split('_')))

        # Add spacing for each chain, if more than one
        e = [0] + [(chain_spacing * ((i + 2) / 2)) * (-1) ** i
                   for i in range(chains - 1)]

        # Loop over chains
        for j, quants in enumerate(data):

            # Deal with multivariate nodes
            if k > 1:
                ravelled_quants = list(map(ravel, quants))
                
                for i, quant in enumerate(transpose(ravelled_quants)):

                    q = ravel(quant)
                    
                    # Y coordinate with jitter
                    y = -(var + i) + e[j]

                    if quartiles:
                        # Plot median
                        pyplot(q[2], y, 'bo', markersize=4)
                        # Plot quartile interval
                        errorbar(x=(q[1], q[3]), y=(y, y), linewidth=2, color="blue")

                    else:
                        # Plot median
                        pyplot(q[1], y, 'bo', markersize=4)

                    # Plot outer interval
                    errorbar(x=(q[0], q[-1]), y=(y, y), linewidth=1, color="blue")

            else:

                # Y coordinate with jitter
                y = -var + e[j]

                if quartiles:
                    # Plot median
                    pyplot(quants[2], y, 'bo', markersize=4)
                    # Plot quartile interval
                    errorbar(x=(quants[1], quants[3]), y=(y, y), linewidth=2, color="blue")
                else:
                    # Plot median
                    pyplot(quants[1], y, 'bo', markersize=4)

                # Plot outer interval
                errorbar(x=(quants[0], quants[-1]), y=(y, y), linewidth=1, color="blue")

        # Increment index
        var += k

    if custom_labels is not None:
        labels = custom_labels

    # Update margins
    left_margin = max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis
    ylim(-var + 0.5, -0.5)

    datarange = plotrange[1] - plotrange[0]
    xlim(plotrange[0] - 0.05 * datarange, plotrange[1] + 0.05 * datarange)

    # Add variable labels
    yticks([-(l + 1) for l in range(len(labels))], labels)

    # Add title
    if main is not False:
        plot_title = main or str(int((1 - alpha) * 100)) + "% Credible Intervals"
        title(plot_title)

    # Add x-axis label
    if xlab is not None:
        xlabel(xlab)

    # Constrain to specified range
    if x_range is not None:
        xlim(*x_range)

    # Remove ticklines on y-axes
    for ticks in interval_plot.yaxis.get_major_ticks():
        ticks.tick1On = False
        ticks.tick2On = False

    for loc, spine in six.iteritems(interval_plot.spines):
        if loc in ['bottom', 'top']:
            pass
            # spine.set_position(('outward',10)) # outward by 10 points
        elif loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    axvline(vline_pos, color='k', linestyle='--')

    # Generate Gelman-Rubin plot
    if rhat and chains > 1:

        # If there are multiple chains, calculate R-hat
        rhat_plot = subplot(gs[1])

        if main is not False:
            title("R-hat")

        # Set x range
        xlim(0.9, 2.1)

        # X axis labels
        xticks((1.0, 1.5, 2.0), ("1", "1.5", "2+"))
        yticks([-(l + 1) for l in range(len(labels))], "")

        i = 1
        for variable in vars:

            if variable._plot == False:
                continue

            # Extract name
            varname = variable.__name__

            try:
                value = variable.get_stoch_value()
            except AttributeError:
                value = variable.value

            k = size(value)

            if k > 1:
                pyplot([min(r, 2) for r in R[varname]],
                       [-(j + i) for j in range(k)], 'bo', markersize=4)
            else:
                pyplot(min(R[varname], 2), -i, 'bo', markersize=4)

            i += k

        # Define range of y-axis
        ylim(-i + 0.5, -0.5)

        # Remove ticklines on y-axes
        for ticks in rhat_plot.yaxis.get_major_ticks():
            ticks.tick1On = False
            ticks.tick2On = False

        for loc, spine in six.iteritems(rhat_plot.spines):
            if loc in ['bottom', 'top']:
                pass
                # spine.set_position(('outward',10)) # outward by 10 points
            elif loc in ['left', 'right']:
                spine.set_color('none')  # don't draw spine

    savefig("%s%s%s.%s" % (path, name, suffix, format))
Example #38
0
File: example.py Project: AtomAleks/PyProp
def PropagateWavePacket(**args):
	#Set up problem
	prop = SetupProblem(**args)
	conf = prop.Config

	#Setup traveling wavepacket initial state
	f = lambda x: conf.Wavepacket.function(conf.Wavepacket, x)
	bspl = prop.psi.GetRepresentation().GetRepresentation(0).GetBSplineObject()
	c = bspl.ExpandFunctionInBSplines(f)
	prop.psi.GetData()[:] = c
	prop.psi.Normalize()
	initialPsi = prop.psi.Copy()

	#Get x-grid
	subProp = prop.Propagator.SubPropagators[0]
	subProp.InverseTransform()
	grid = prop.psi.GetRepresentation().GetLocalGrid(0)
	subProp.ForwardTransform()

	#Setup equispaced x grid
	x_min = conf.BSplineRepresentation.xmin
	x_max = conf.BSplineRepresentation.xmax
	grid_eq = linspace(x_min, x_max, grid.size)
	x_spacing = grid_eq[1] - grid_eq[0]
	
	#Set up fft grid
	k_spacing = 1.0 / grid_eq.size
	k_max = pi / x_spacing
	k_min = -k_max
	k_spacing = (k_max - k_min) / grid_eq.size
	grid_fft = zeros(grid_eq.size, dtype=double)
	grid_fft[:grid_eq.size/2+1] = r_[0.0:k_max:k_spacing]
	grid_fft[grid_eq.size/2:] = r_[k_min:0.0:k_spacing]
	print "Momentum space resolution = %f a.u." % k_spacing
	
	k0 = conf.Wavepacket.k0
	k0_trunk = 5 * k0
	trunkIdx = list(nwhere(abs(grid_fft) <= k0_trunk)[0])

	rcParams['interactive'] = True
	figure()
	p1 = subplot(211)
	p2 = subplot(212)
	p1.hold(False)

	psi_eq = zeros((grid.size), dtype=complex)

	for t in prop.Advance(40):
		print "t = %f, norm = %.15f, P = %.15f " % \
			( t, prop.psi.GetNorm(), abs(prop.psi.InnerProduct(initialPsi))**2 )
		sys.stdout.flush()

		subProp.InverseTransform()
		p1.plot(grid, abs(prop.psi.GetData())**2)
		subProp.ForwardTransform()
		bspl.ConstructFunctionFromBSplineExpansion(prop.psi.GetData(), grid_eq, psi_eq)
		psi_fft = (abs(fft.fft(psi_eq))**2)
		psi_fft_max = nmax(psi_fft)
		psi_fft /= psi_fft_max

		#Plot momentum space |psi|**2
		p2.hold(False)
		p2.semilogy(grid_fft[trunkIdx], psi_fft[trunkIdx] + 1e-21)
		p2.hold(True)
		p2.semilogy([0,0], [1e-20, psi_fft_max], 'r-')
		p2.semilogy([-k0,-k0], [1e-20, psi_fft_max], 'g--')
		p2.semilogy([k0,k0], [1e-20, psi_fft_max], 'g--')

		#Set subplot axis
		p1.axis([grid[0],grid[-1],0,0.1])
		p2.axis([-k0_trunk, k0_trunk, 1e-20, psi_fft_max])
		show()

	hold(True)
		

	return prop
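As a side note, the hand-rolled momentum grid in the example above can be reproduced with numpy's fftfreq helper. This is a standalone sketch under made-up grid parameters (n and x_spacing are assumptions, not values from the PyProp config):

import numpy as np

n = 512          # number of spatial grid points (assumption)
x_spacing = 0.1  # spatial step in a.u. (assumption)

# Angular wavenumbers in FFT output order: 0..k_max followed by k_min..-dk,
# matching the grid_fft array constructed manually above.
grid_fft = 2 * np.pi * np.fft.fftfreq(n, d=x_spacing)
k_spacing = grid_fft[1] - grid_fft[0]
print("Momentum space resolution = %f a.u." % k_spacing)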
Example #39
0
def discrepancy_plot(
    data, name='discrepancy', report_p=True, format='png', suffix='-gof', path='./',
        fontmap=None, verbose=1):
    '''
    Generate goodness-of-fit deviate scatter plot.
    
    :Arguments:
        data: list
            List (or list of lists for vector-valued variables) of discrepancy values, output
            from the `pymc.diagnostics.discrepancy` function.

        name: string
            The name of the plot.
            
        report_p: bool
            Flag for annotating the p-value to the plot.

        format (optional): string
            Graphic output format (defaults to png).

        suffix (optional): string
            Filename suffix (defaults to "-gof").

        path (optional): string
            Specifies location for saving plots (defaults to local directory).

        fontmap (optional): dict
            Font map for plot.
    
    '''

    if verbose > 0:
        print_('Plotting', name + suffix)

    if fontmap is None:
        fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4}

    # Generate new scatter plot
    figure()
    try:
        x, y = transpose(data)
    except ValueError:
        x, y = data
    scatter(x, y)

    # Plot x=y line
    lo = nmin(ravel(data))
    hi = nmax(ravel(data))
    datarange = hi - lo
    lo -= 0.1 * datarange
    hi += 0.1 * datarange
    pyplot((lo, hi), (lo, hi))

    # Plot options
    xlabel('Observed deviates', fontsize='x-small')
    ylabel('Simulated deviates', fontsize='x-small')

    if report_p:
        # Put p-value in legend
        count = sum(s > o for o, s in zip(x, y))
        text(lo + 0.1 * datarange, hi - 0.1 * datarange,
             'p=%.3f' % (1. * count / len(x)), horizontalalignment='center',
             fontsize=10)

    # Save to file
    if not os.path.exists(path):
        os.mkdir(path)
    if not path.endswith('/'):
        path += '/'
    savefig("%s%s%s.%s" % (path, name, suffix, format))
Example #40
0
    def observe(self, dict_in):
        """
        Loads observation model parameters into a dictionary,
        applies the forward model, and provides an initial solution.

        Args:
            dict_in (dict): Dictionary which will be overwritten with
                all of the observation model parameters, the forward model
                observation 'y', and the initial estimate 'x_0'.
        """
        warnings.simplefilter("ignore", np.ComplexWarning)
        #########################################
        #fetch observation model parameters here#
        #########################################

        if (self.str_type[:11] == 'convolution'
                or self.str_type == 'compressed_sensing'):
            wrf = self.get_val('wienerfactor', True)
            str_domain = self.get_val('domain', False)
            noise_pars = defaultdict(int)  #build a dict to generate the noise
            noise_pars['seed'] = self.get_val('seed', True)
            noise_pars['variance'] = self.get_val('noisevariance', True)
            noise_pars['distribution'] = self.get_val('noisedistribution',
                                                      False)
            noise_pars['mean'] = self.get_val('noisemean', True)
            noise_pars['interval'] = self.get_val('noiseinterval',
                                                  True)  #uniform
            noise_pars['size'] = dict_in['x'].shape
            dict_in['noisevariance'] = noise_pars['variance']

            if self.str_type == 'compressed_sensing':
                noise_pars['complex_noise'] = 1
            if dict_in['noisevariance'] > 0:
                dict_in['n'] = noise_gen(noise_pars)
            else:
                dict_in['n'] = 0

        elif self.str_type == 'classification':
            #partition the classification dataset into an 'observed' training set
            #and an unobserved evaluation/test set, and generate features
            dict_in['x_train'] = {}
            dict_in['x_test'] = {}
            dict_in['y_label'] = {}
            dict_in['x_feature'] = {}
            dict_in['n_training_samples'] = 0
            dict_in['n_testing_samples'] = 0
            shuffle = self.get_val('shuffle', True)
            if shuffle:
                shuffleseed = self.get_val('shuffleseed', True)
            training_proportion = self.get_val('trainingproportion', True)
            classes = dict_in['x'].keys()
            #partition and generate numeric class labels
            for _class_index, _class in enumerate(classes):
                class_size = len(dict_in['x'][_class])
                training_size = int(training_proportion * class_size)
                dict_in['n_training_samples'] += training_size
                dict_in['n_testing_samples'] += class_size - training_size
                if shuffle:
                    np.random.seed(shuffleseed)
                    indices = np.random.permutation(class_size)
                else:
                    indices = np.array(range(class_size), dtype='uint16')
                dict_in['x_train'][_class] = indices[:training_size]
                dict_in['x_test'][_class] = indices[training_size:]
                dict_in['y_label'][_class] = _class_index
        else:
            raise ValueError('unsupported observation model')
        ################################################
        #compute the forward model and initial estimate#
        ################################################
        if self.str_type == 'convolution':
            H = self.Phi

            H.set_output_fourier(False)
            dict_in['Hx'] = H * dict_in['x']
            dict_in['y'] = dict_in['Hx'] + dict_in['n']
            #regularized Wiener filtering in Fourier domain
            H.set_output_fourier(True)
            dict_in['x_0'] = real(
                ifftn(~H * dict_in['y'] /
                      (H.get_spectrum_sq() + wrf * noise_pars['variance'])))
            # dict_in['x_0'] = real(ifftn(~H * dict_in['y'])) %testing only
            H.set_output_fourier(False)
            #compute bsnr
            self.compute_bsnr(dict_in, noise_pars)
        elif self.str_type == 'convolution_downsample':
            Phi = self.Phi
            #this order is important in the config file
            D = Phi.ls_ops[1]
            H = Phi.ls_ops[0]
            H.set_output_fourier(False)
            if self.get_val('spatialblur', True):
                dict_in['Phix'] = D * convolve(dict_in['x'], H.kernel, 'same')
                dict_in['Hxpn'] = convolve(dict_in['x'], H.kernel,
                                           'same') + dict_in['n']
            else:
                dict_in['Phix'] = Phi * dict_in['x']
                dict_in['Hxpn'] = H * dict_in['x'] + dict_in['n']
            dict_in['Hx'] = dict_in['Phix']
            #the version of y without downsampling
            dict_in['DHxpn'] = np.zeros((D * dict_in['Hxpn']).shape)
            if dict_in['n'].__class__.__name__ == 'ndarray':
                dict_in['n'] = D * dict_in['n']
            dict_in['y'] = dict_in['Hx'] + dict_in['n']
            DH = fftn(Phi * nd_impulse(dict_in['x'].shape))
            DHt = conj(DH)
            Hty = fftn(D * (~Phi * dict_in['y']))
            HtDtDH = np.real(DHt * DH)
            # dict_in['x_0'] = ~D*real(ifftn(Hty /
            #                                (HtDtDH +
            #                                 wrf * noise_pars['variance'])))
            dict_in['x_0'] = ~D * dict_in['y']
            #optional interpolation
            xdim = dict_in['x'].ndim
            xshp = dict_in['x'].shape
            if self.get_val('interpinitialsolution', True):
                if xdim == 2:
                    if self.get_val('useimresize', True):
                        interp_vals = imresize(
                            dict_in['y'],
                            tuple(D.ds_factor *
                                  np.asarray(dict_in['y'].shape)),
                            interp='bicubic')
                    else:
                        grids = np.mgrid[[
                            slice(0, xshp[j]) for j in xrange(xdim)
                        ]]
                        grids = tuple(
                            [grids[i] for i in xrange(grids.shape[0])])
                        sampled_coords = np.mgrid[[
                            slice(D.offset[j], xshp[j], D.ds_factor[j])
                            for j in xrange(xdim)
                        ]]
                        values = dict_in['x_0'][[
                            coord.flatten() for coord in sampled_coords
                        ]]
                        points = np.vstack([
                            sampled_coords[i, Ellipsis].flatten()
                            for i in xrange(sampled_coords.shape[0])
                        ]).transpose()  #pts to interp
                        interp_vals = griddata(points,
                                               values,
                                               grids,
                                               method='cubic',
                                               fill_value=0.0)
                else:
                    #we're not using blank values; different interpolation scheme
                    values = dict_in['y']
                    dsfactors = np.asarray(
                        [int(D.ds_factor[j]) for j in xrange(values.ndim)])
                    valshpcorrect = (
                        np.asarray(values.shape) -
                        np.asarray(xshp, dtype='uint16') / dsfactors)
                    valshpcorrect = valshpcorrect / np.asarray(dsfactors,
                                                               dtype='float32')
                    interp_coords = iprod(*[
                        np.arange(0, values.shape[j] - valshpcorrect[j], 1.0 /
                                  D.ds_factor[j]) for j in xrange(values.ndim)
                    ])
                    interp_coords = np.array([el for el in interp_coords
                                              ]).transpose()
                    interp_vals = map_coordinates(values,
                                                  interp_coords,
                                                  order=3,
                                                  mode='nearest').reshape(xshp)
                    # interp_vals = map_coordinates(values,interp_coords,order=3,mode='nearest')
                    # cut off the edges
                    # if xdim == 2:
                    # interp_vals = interp_vals[0:xshp[0],0:xshp[1]]
                    # else:
                    interp_vals = interp_vals[0:xshp[0], 0:xshp[1], 0:xshp[2]]
                dict_in['x_0'] = interp_vals
            elif self.get_val('inputinitialsoln', False) != '':
                init_soln_inputsec = Input(
                    self.ps_parameters, self.get_val('inputinitialsoln',
                                                     False))
                dict_in['x_0'] = init_soln_inputsec.read({}, True)
            self.compute_bsnr(dict_in, noise_pars)

        elif self.str_type == 'convolution_poisson':
            dict_in['mp'] = self.get_val('maximumphotonspervoxel', True)
            dict_in['b'] = self.get_val('background', True)
            H = self.Phi
            if str_domain == 'fourier':
                H.set_output_fourier(False)  #return spatial domain object
                orig_shape = dict_in['x'].shape
                Hspec = np.zeros(orig_shape)
                dict_in['r'] = H * dict_in['x']
                k = dict_in['mp'] / nmax(dict_in['r'])
                dict_in['r'] = k * dict_in['r']
                #normalize the output image to have the specified maximum
                #photon count, scaling the input image by the same factor
                dict_in['x'] = k * dict_in['x']
                dict_in['x'] = crop_center(
                    dict_in['x'], dict_in['r'].shape).astype('float32')
                #the spatial domain measurements, before photon counts
                dict_in['fb'] = dict_in['r'] + dict_in['b']
                #lambda of the poisson distn
                noise_pars['ary_mean'] = dict_in['fb']
                #specifying the poisson distn
                noise_distn2 = self.get_val('noisedistribution2', False)
                noise_pars['distribution'] = noise_distn2
                #generating quantized (uint16) poisson measurements
                # dict_in['y'] = (noise_gen(noise_pars)+dict_in['n']).astype('uint16').astype('int32')
                dict_in['y'] = noise_gen(noise_pars) + crop_center(
                    dict_in['n'], dict_in['fb'].shape)
                dict_in['y'][dict_in['y'] < 0] = 0
            elif str_domain == 'evaluation':  #we are given the observation, which is stored in 'x'
                dict_in['y'] = dict_in.pop('x')
            else:
                raise Exception('domain not supported: ' + str_domain)
            dict_in['x_0'] = ((~H) * (dict_in['y'])).astype(dtype='float32')
            dict_in['y_padded'] = pad_center(dict_in['y'],
                                             dict_in['x_0'].shape)

        elif self.str_type == 'compressed_sensing':
            Fu = self.Phi
            dict_in['Hx'] = Fu * dict_in['x']
            dict_in['y'] = dict_in['Hx'] + dict_in['n']
            dict_in['x_0'] = (~Fu) * dict_in['y']
            dict_in['theta_0'] = angle(dict_in['x_0'])
            dict_in['theta_0'] = su.phase_unwrap(dict_in['theta_0'],
                                                 dict_in['dict_global_lims'],
                                                 dict_in['ls_local_lim_secs'])
            dict_in['magnitude_0'] = nabs(dict_in['x_0'])
            if self.get_val('maskinitialsoln', True):
                dict_in['theta_0'] *= dict_in['mask']
                dict_in['magnitude_0'] *= dict_in['mask']
            dict_in['x_0'] = dict_in['magnitude_0'] * exp(
                1j * dict_in['theta_0'])
            self.compute_bsnr(dict_in, noise_pars)
        #store the wavelet domain version of the ground truth
        if np.iscomplexobj(dict_in['x']):
            dict_in['w'] = [
                self.W * dict_in['x'].real, self.W * dict_in['x'].imag
            ]
        else:
            dict_in['w'] = [self.W * dict_in['x']]
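For reference, the regularized Wiener initialization computed for 'x_0' in the 'convolution' branch above can be written as a standalone sketch. Every name below is introduced for illustration only (a hypothetical helper, not part of the class):

import numpy as np
from numpy.fft import fftn, ifftn

def wiener_initial_estimate(y, kernel_spectrum, noise_variance, wrf=1.0):
    """Regularized Wiener deconvolution of a blurred, noisy observation y.

    kernel_spectrum is the FFT of the blur kernel (same shape as y), and
    wrf plays the role of the 'wienerfactor' parameter read in observe().
    """
    # Adjoint of the blur operator applied to y, in the Fourier domain
    numerator = np.conj(kernel_spectrum) * fftn(y)
    # Squared kernel spectrum plus the noise-dependent regularizer
    denominator = np.abs(kernel_spectrum) ** 2 + wrf * noise_variance
    return np.real(ifftn(numerator / denominator))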