コード例 #1
0
def test_nsi_params():
    """Unit tests for subclasses of `NSIParams`.

    Checks that `StdNSIParams` rejects invalid coupling values, that setting
    the three diagonal couplings to the same value leaves `eps_matrix` equal
    to zero, and that `VacuumLikeNSIParams.eps_ee` equals `eps_scale - 1`.

    Raises
    ------
    AssertionError
        If an invalid assignment is silently accepted, or if the vacuum-like
        relation between `eps_ee` and `eps_scale` does not hold.
    ValueError
        If the coupling matrix is not zero for equal diagonal couplings.

    """
    # TODO: these have to be extended
    rand = np.random.RandomState(0)
    std_nsi = StdNSIParams()

    # Each negative test below must fail loudly when the expected exception is
    # NOT raised; otherwise the check would pass vacuously.
    try:
        # cannot accept a sequence
        std_nsi.eps_ee = [rand.rand()]
    except TypeError:
        pass
    else:
        raise AssertionError(
            "Assigning a sequence to eps_ee should raise TypeError")

    try:
        # must be real
        std_nsi.eps_ee = rand.rand() * 1.j
    except TypeError:
        pass
    else:
        raise AssertionError(
            "Assigning a complex value to eps_ee should raise TypeError")

    try:
        # unphysical negative magnitude for nonzero phase
        std_nsi.eps_mutau = ((rand.rand() - 1.0), 0.1)
    except ValueError:
        pass
    else:
        raise AssertionError(
            "Assigning a negative magnitude with nonzero phase to eps_mutau"
            " should raise ValueError")

    std_nsi.eps_ee = 0.5
    std_nsi.eps_mumu = 0.5
    std_nsi.eps_tautau = 0.5
    if not np.allclose(
        std_nsi.eps_matrix, np.zeros((3, 3), dtype=CTYPE), **ALLCLOSE_KW
    ):
        raise ValueError("NSI coupling matrix should be identically zero!")

    vac_like_nsi = VacuumLikeNSIParams()
    vac_like_nsi.eps_scale = rand.rand() * 10.
    assert recursiveEquality(vac_like_nsi.eps_ee, vac_like_nsi.eps_scale - 1.0)
コード例 #2
0
ファイル: param.py プロジェクト: terliuk/pisa
    def validate_binning(self):
        """Check that the configured binning is usable.

        Raises
        ------
        ValueError
            If the input binning dimensions are not exactly 'true_coszen' and
            'true_energy'; if `coszen_flipback` is None; or if
            `coszen_flipback` is enabled while the coszen output binning is
            not linear or its step does not fit an integer number of times
            into a range of length 1.
        AssertionError
            If input and output binning do not have the same basename
            dimensions.

        """
        # Right now this can only deal with 2D energy / coszenith binning
        # Code can probably be generalised, but right now is not
        if set(self.input_binning.names) != {'true_coszen', 'true_energy'}:
            # Wrap in a 1-tuple so that a tuple-valued `names` is formatted as
            # a single value instead of being unpacked by the `%` operator.
            raise ValueError(
                "Input binning must be 2D true energy / coszenith binning. "
                "Got %s." % (self.input_binning.names,))

        assert set(self.input_binning.basename_binning.names) == \
               set(self.output_binning.basename_binning.names), \
               "input and output binning must both be 2D in energy / coszenith!"

        if self.coszen_flipback is None:
            raise ValueError(
                "coszen_flipback should be set to True or False since"
                " coszen is in your binning.")

        if self.coszen_flipback:
            coszen_output_binning = self.output_binning.basename_binning[
                'coszen']

            if not coszen_output_binning.is_lin:
                raise ValueError(
                    "coszen_flipback is set to True but zenith output"
                    " binning is not linear - incompatible settings!")
            coszen_step_out = (coszen_output_binning.range.magnitude /
                               coszen_output_binning.size)

            # Round (not truncate) to the nearest integer: floating point
            # error can put the reciprocal marginally below an integer (e.g.
            # 2.9999999999999996), which truncation would turn into a
            # spurious mismatch.
            steps_per_unit = 1 / coszen_step_out
            if not recursiveEquality(int(round(steps_per_unit)),
                                     steps_per_unit):
                raise ValueError(
                    "coszen_flipback requires an integer number of"
                    " coszen output binning steps to fit into a range"
                    " of integer length.")
コード例 #3
0
def test_histogram():
    """Unit tests for `histogram` function.

    The output of `histogram` must agree with what numpy.histogramdd yields
    for the same sample, bin edges and weights, both for weighted sums and
    for per-bin averages. The comparison is repeated while growing the
    binning from one to three dimensions.
    """
    rand = np.random.RandomState(seed=0)
    n_evts = 10000
    weights = rand.rand(n_evts).astype(FTYPE)

    binning = []
    sample = []
    # Each pass adds one more dimension, then re-checks the full histogram
    for dim_idx, n_bins in enumerate([2, 3, 4]):
        new_dim = OneDimBinning(
            name=f'dim{dim_idx}',
            num_bins=n_bins,
            is_lin=True,
            domain=[0, n_bins],
        )
        binning.append(new_dim)
        sample.append(rand.rand(n_evts).astype(FTYPE) * n_bins)

        bin_edges = [dim.edge_magnitudes for dim in binning]

        # Weighted sums per bin
        test = histogram(
            sample, weights, MultiDimBinning(binning), averaged=False
        )
        ref = np.histogramdd(
            sample=sample, bins=bin_edges, weights=weights
        )[0]
        ref = ref.astype(FTYPE).ravel()
        assert recursiveEquality(test, ref), f'\ntest:\n{test}\n\nref:\n{ref}'

        # Weighted averages per bin: weighted sum divided by event count
        test_avg = histogram(
            sample, weights, MultiDimBinning(binning), averaged=True
        )
        ref_counts = np.histogramdd(
            sample=sample, bins=bin_edges, weights=None
        )[0]
        ref_counts = ref_counts.astype(FTYPE).ravel()
        ref_avg = (ref / ref_counts).astype(FTYPE)
        assert recursiveEquality(test_avg, ref_avg), \
                f'\ntest_avg:\n{test_avg}\n\nref_avg:\n{ref_avg}'

    logging.info('<< PASS : test_histogram >>')
コード例 #4
0
ファイル: make_toy_events.py プロジェクト: terliuk/pisa
def populate_pid(mc_events,
                 param_source,
                 cut_val=0,
                 random_state=None,
                 dist='discrete',
                 **dist_kwargs):
    """Construct a 'pid' field within the `mc_events` object.

    For every flavint in `mc_events.flavints`, the track and cascade
    probabilities are evaluated at the events' 'reco_energy' values and a PID
    value is drawn per event.

    Parameters
    ----------
    mc_events : pisa.core.Events
        Events to classify. Modified in place: a 'pid' field is written for
        each flavint.
    param_source
        Passed to `load_pid_energy_param` to obtain, per flavint group, the
        'track' and 'cascade' probability functions of reco energy.
    cut_val : numeric
        PID value separating cascades (below) from tracks (above).
    random_state
        Passed through `get_random_state`; the resulting state drives all
        random draws in this function.
    dist : string, one of 'discrete' or 'normal'
        'discrete' assigns `cut_val + 1` (track) or `cut_val - 1` (cascade);
        'normal' draws from a normal distribution shifted such that its CDF at
        `cut_val` equals the cascade probability.
    **dist_kwargs
        Forwarded to `norm.ppf` and `norm(...)` when `dist` is 'normal'
        (`norm` is presumably `scipy.stats.norm` - confirm against the
        file's imports).

    Returns
    -------
    mc_events
        The same object that was passed in.

    Raises
    ------
    AssertionError
        If `dist` is not an allowed value, or if the track and cascade
        probabilities do not sum to 1.
    ValueError
        If no PID parameterization is found for a flavint.

    """
    random_state = get_random_state(random_state)
    logging.info('  Classifying events as tracks or cascades')

    dist_allowed = ('discrete', 'normal')
    assert dist in dist_allowed

    pid_param = load_pid_energy_param(param_source)

    for flavint in mc_events.flavints:
        pid_funcs = None
        # `.items()` works on both Python 2 and 3 (`.iteritems()` is Python 2
        # only). No early exit: if several groups contain `flavint`, the last
        # one found is used, as before.
        for flavintgroup, funcs in pid_param.items():
            if flavint in flavintgroup:
                pid_funcs = funcs
        if pid_funcs is None:
            raise ValueError('Could not find pid param for %s' % flavint)

        reco_energies = mc_events[flavint]['reco_energy']
        track_pid_probs = pid_funcs['track'](reco_energies)
        cascade_pid_probs = pid_funcs['cascade'](reco_energies)
        assert np.all(np.isclose(track_pid_probs + cascade_pid_probs, 1))
        if dist == 'discrete':
            logging.debug('  Drawing discrete PID values')
            rands = random_state.uniform(size=len(reco_energies))
            pid_vals = np.where(rands <= track_pid_probs, cut_val + 1,
                                cut_val - 1)
        elif dist == 'normal':
            logging.debug('  Drawing normally distributed PID values')
            # cascades fall below `cut_val`, tracks above
            locs_shifted = cut_val - norm.ppf(cascade_pid_probs, **dist_kwargs)
            rv = norm(loc=locs_shifted, **dist_kwargs)
            assert recursiveEquality(rv.cdf(cut_val), cascade_pid_probs)
            # size is important in the following, as otherwise all samples are
            # 100% correlated. `random_state` is passed explicitly so that the
            # draw is reproducible instead of using the global numpy state.
            pid_vals = rv.rvs(size=len(reco_energies),
                              random_state=random_state)
        mc_events[flavint]['pid'] = pid_vals.astype(FTYPE)

    return mc_events
コード例 #5
0
ファイル: param.py プロジェクト: terliuk/pisa
    def extend_binning_for_coszen(self, ext_low=-3., ext_high=+3.):
        """Check whether `coszen_flipback` can be applied to the stage's
        'reco_coszen' output binning and, if so, return an extended binning.

        The extended edges span [`ext_low`, `ext_high`] with the same step as
        the output binning (its range divided by its number of bins).

        Parameters
        ----------
        ext_low : float
            Lower end of the extended coszen range.
        ext_high : float
            Upper end of the extended coszen range; must exceed `ext_low`.

        Returns
        -------
        cz_edges_out : array
            Extended coszen bin edges.
        flipback_mask : array of bool
            True for extended bins whose center lies below -1 or above +1.
        keep : array of int
            Indices of extended bins whose center lies strictly between the
            first and last original output bin edges.

        Raises
        ------
        AssertionError
            If `ext_high <= ext_low`, if -1 or +1 is not among the extended
            edges, or if an original edge is missing from the extended edges.

        """
        logging.trace("Preparing binning for flipback of reco kernel at"
                      " coszen boundaries of physical range.")

        cz_edges_out = self.output_binning['reco_coszen'].bin_edges.magnitude
        coszen_range = self.output_binning['reco_coszen'].range.magnitude
        n_cz_out = self.output_binning['reco_coszen'].size
        coszen_step = coszen_range / n_cz_out
        # we need to check for possible contributions from (-3, -1) and
        # (1, 3) in coszen
        assert ext_high > ext_low, "ext_high must be greater than ext_low"
        ext_range = ext_high - ext_low
        # Round (not truncate) the bin count: floating point error can put the
        # quotient marginally below an integer, and truncation would then
        # silently drop one bin and shift every edge.
        n_ext_bins = int(round(ext_range / coszen_step))
        extended = np.linspace(ext_low, ext_high, n_ext_bins + 1)

        # We cannot flipback if we don't have -1 & +1 as (part of extended)
        # bin edges. This could happen if 1 is a multiple of the output bin
        # size, but the original edges themselves are not a multiple of that
        # size.
        for bound in (-1., +1.):
            assert any(recursiveEquality(bound, e) for e in extended), \
                "coszen=%s is not among the extended bin edges" % bound

        # Perform one final check: original edges subset of extended ones?
        for coszen in cz_edges_out:
            assert any(recursiveEquality(coszen, e) for e in extended), \
                "output bin edge %s is not among the extended bin edges" \
                % coszen

        # Binning seems fine - we can proceed
        ext_cent = (extended[1:] + extended[:-1]) / 2.
        flipback_mask = ((ext_cent < -1.) | (ext_cent > +1.))
        keep = np.where((ext_cent > cz_edges_out[0])
                        & (ext_cent < cz_edges_out[-1]))[0]
        cz_edges_out = extended
        logging.trace("  -> temporary coszen bin edges:\n%s" % cz_edges_out)

        return cz_edges_out, flipback_mask, keep
コード例 #6
0
ファイル: param.py プロジェクト: terliuk/pisa
 def check_reco_dist_consistency(self, dist_list):
     """Verify that the relative weights of the resolution functions are
     normalised.

     The 'fraction' entries of all dicts in `dist_list` are summed and the
     total is compared against 1. Returns True on success; raises ValueError
     otherwise.
     """
     logging.trace(
         " Verifying correct normalisation of resolution function.")
     # Accumulate in place, starting from zeros shaped like the first entry
     total = np.zeros_like(dist_list[0]['fraction'])
     for entry in dist_list:
         total += entry['fraction']
     if recursiveEquality(total, np.ones_like(total)):
         return True
     raise ValueError("Total normalisation of resolution function is off"
                      " (fractions do not add up to 1).")
コード例 #7
0
ファイル: jsons.py プロジェクト: terliuk/pisa
def test_NumpyEncoderDecoder():
    """Round-trip tests for NumpyEncoder and NumpyDecoder.

    A numpy array containing infinities and NaN is written with `to_json` and
    read back with `from_json`, first on its own and then as a dict value.
    """
    from shutil import rmtree
    import sys
    from pisa.utils.comparisons import recursiveEquality

    nda1 = np.array([-np.inf, np.nan, np.inf, -1, 0, 1])
    temp_dir = tempfile.mkdtemp()
    try:
        # Bare array: write every file first, then read each one back
        array_files = [
            os.path.join(temp_dir, 'nda1.json'),
            os.path.join(temp_dir, 'nda1.json.bz2'),
        ]
        for path in array_files:
            to_json(nda1, path)
        for fn in array_files:
            nda2 = from_json(fn)
            assert np.allclose(nda2, nda1, rtol=1e-12, atol=0, equal_nan=True), \
                    'nda1=\n%s\nnda2=\n%s\nsee file: %s' %(nda1, nda2, fn)

        # Same array wrapped in a dict
        d1 = {'nda1': nda1}
        dict_files = [
            os.path.join(temp_dir, 'd1.json'),
            os.path.join(temp_dir, 'd1.json.bz2'),
            os.path.join(temp_dir, 'd1.json.xor'),
        ]
        for path in dict_files:
            to_json(d1, path)
        for fn in dict_files:
            d2 = from_json(fn)
            assert recursiveEquality(d2, d1), \
                    'd1=\n%s\nd2=\n%s\nsee file: %s' %(d1, d2, fn)
    finally:
        rmtree(temp_dir)

    sys.stdout.write('<< PASS : test_NumpyEncoderDecoder >>\n')
コード例 #8
0
def test_CrossSections(outdir=None):
    """Unit tests for CrossSections class.

    Parameters
    ----------
    outdir : string or None
        Directory in which plots are saved. If None, a temporary directory is
        created and removed again when the test finishes; a directory passed
        in by the caller is left in place.

    Raises
    ------
    ValueError
        If the averaged cross section ratio of (numu_cc + numu_nc) to itself
        is not 1.

    """
    from shutil import rmtree
    from tempfile import mkdtemp

    remove_dir = False
    if outdir is None:
        remove_dir = True
        outdir = mkdtemp()

    try:
        # "Standard" location of cross sections file in PISA; retrieve 2.6.4 for
        # testing purposes
        pisa_xs_file = 'cross_sections/cross_sections.json'
        xs = CrossSections(ver='genie_2.6.4', xsec=pisa_xs_file)

        # Location of the root file to use (not included in PISA at the moment)
        root_xs_file = find_resource(os.path.join(
            'cross_sections', 'genie_xsec_H2O.root'
        ))

        # Load the cross sections from ROOT if the file is available.
        # NOTE(review): the comparison against the cross sections stored in
        # PISA (`xs_from_root.allclose(xs, rtol=1e-7)`) is disabled in the
        # original; only successful loading is exercised here.
        if os.path.isfile(root_xs_file):
            xs_from_root = CrossSections.new_from_root(root_xs_file,
                                                       ver='genie_2.6.4')
            logging.info('Found and loaded ROOT source cross sections file %s',
                         root_xs_file)

        # Check XS ratio for numu_cc to numu_cc + numu_nc (user must inspect)
        kg0 = NuFlavIntGroup('numu_cc')
        kg1 = NuFlavIntGroup('numu_nc')
        logging.info(
            r'\int_1^80 xs(numu_cc) E^{-1} dE = %e',
            xs.get_xs_ratio_integral(kg0, None, e_range=[1, 80], gamma=1)
        )
        logging.info(
            '(int E^{-gamma} * (sigma_numu_cc)/int(sigma_(numu_cc+numu_nc)) dE)'
            ' / (int E^{-gamma} dE) = %e',
            xs.get_xs_ratio_integral(kg0, kg0+kg1, e_range=[1, 80], gamma=1,
                                     average=True)
        )
        # Check that XS ratio for numu_cc+numu_nc to the same is 1.0
        int_val = xs.get_xs_ratio_integral(kg0+kg1, kg0+kg1, e_range=[1, 80],
                                           gamma=1, average=True)
        if not recursiveEquality(int_val, 1):
            raise ValueError('Integral of nc + cc should be 1.0; get %e'
                             ' instead.' % int_val)

        # Plot all cross sections stored in PISA xs file
        try:
            alldata = from_file(pisa_xs_file)
            xs_versions = alldata.keys()
            for ver in xs_versions:
                xs = CrossSections(ver=ver, xsec=pisa_xs_file)
                xs.plot(save=os.path.join(
                    outdir, 'pisa_' + ver + '_nuxCCNC_H2O_cross_sections.pdf'
                ))
        except ImportError as exc:
            # Plotting is best-effort: a missing plotting backend must not
            # fail the test.
            logging.debug('Could not plot; possible that matplotlib not'
                          ' installed. ImportError: %s', exc)

    finally:
        if remove_dir:
            rmtree(outdir)
コード例 #9
0
ファイル: gaussians.py プロジェクト: terliuk/pisa
def test_gaussians():
    """Test `gaussians` function.

    Every implementation listed in GAUS_IMPLEMENTATIONS is compared, with and
    without weights, against a reference built from `stats.norm.pdf`. Run
    times are collected and the unweighted ones are reported at the end.
    """
    n_gaus = [1, 10, 100, 1000, 10000]
    n_eval = int(1e4)

    x = np.linspace(-20, 20, n_eval)
    np.random.seed(0)
    # One (means, widths, weights) triple per number of gaussians
    mu_sigma_weight_sets = []
    for n in n_gaus:
        mu_sigma_weight_sets.append((
            np.linspace(-50, 50, n),
            np.linspace(0.5, 100, n),
            np.random.rand(n),
        ))

    timings = OrderedDict((impl, []) for impl in GAUS_IMPLEMENTATIONS)

    for mus, sigmas, weights in mu_sigma_weight_sets:
        if not isinstance(mus, Iterable):
            mus, sigmas, weights = [mus], [sigmas], [weights]

        # Reference results: plain mean and weighted mean of the PDFs
        unweighted_pdfs = [
            stats.norm.pdf(x, loc=m, scale=s) for m, s in zip(mus, sigmas)
        ]
        ref_unw = np.sum(unweighted_pdfs, axis=0) / len(mus)
        weighted_pdfs = [
            stats.norm.pdf(x, loc=m, scale=s) * w
            for m, s, w in zip(mus, sigmas, weights)
        ]
        ref_w = np.sum(weighted_pdfs, axis=0) / np.sum(weights)

        for impl in GAUS_IMPLEMENTATIONS:
            start = time()
            test_unw = gaussians(
                x, mu=mus, sigma=sigmas, weights=None, implementation=impl
            )
            dt_unw = time() - start

            start = time()
            test_w = gaussians(
                x, mu=mus, sigma=sigmas, weights=weights, implementation=impl
            )
            dt_w = time() - start

            # Stored in milliseconds as (unweighted, weighted)
            timings[impl].append((
                np.round(dt_unw * 1000, decimals=3),
                np.round(dt_w * 1000, decimals=3),
            ))

            err_msgs = []
            if not recursiveEquality(test_unw, ref_unw):
                max_dev = np.max(np.abs((test_unw / ref_unw - 1)))
                err_msgs.append(
                    'BAD RESULT (unweighted), n_gaus=%d, implementation='
                    '"%s", max. abs. fract. diff.: %s' %
                    (len(mus), impl, max_dev))
            if not recursiveEquality(test_w, ref_w):
                max_dev = np.max(np.abs((test_w / ref_w - 1)))
                err_msgs.append(
                    'BAD RESULT (weighted), n_gaus=%d, implementation="%s"'
                    ', max. abs. fract. diff.: %s' %
                    (len(mus), impl, max_dev))
            if not err_msgs:
                continue
            for err_msg in err_msgs:
                logging.error(err_msg)
            raise ValueError('\n'.join(err_msgs))

    tprofile.debug(
        'gaussians() timings (unweighted) (Note:OMP_NUM_THREADS=%d; evaluated'
        ' at %.0e points)', OMP_NUM_THREADS, n_eval)
    tprofile.debug(' ' * 30 + 'Number of gaussians'.center(59))
    column_titles = '  '.join([format(t, '10d') for t in n_gaus])
    tprofile.debug('         %15s       %s', 'impl.', column_titles)
    column_rules = '  '.join(['-' * 10 for t in n_gaus])
    tprofile.debug('         %15s       %s', '-' * 15, column_rules)
    for impl in GAUS_IMPLEMENTATIONS:
        # only report timings for unweighted case
        row = '  '.join([format(t[0], '10.3f') for t in timings[impl]])
        tprofile.debug('Timings, %15s (ms): %s', impl, row)
    logging.info('<< PASS : test_gaussians >>')
コード例 #10
0
ファイル: events.py プロジェクト: terliuk/pisa
 def data_eq(self, other):
     """Test whether the data for this object matches that of `other`.

     The comparison is delegated to `recursiveEquality(self, other)`, whose
     result is returned unchanged.
     """
     return recursiveEquality(self, other)
コード例 #11
0
ファイル: events.py プロジェクト: terliuk/pisa
 def meta_eq(self, other):
     """Test whether the metadata for this object matches that of `other`.

     Compares the `metadata` attributes of both objects via
     `recursiveEquality` and returns its result unchanged; `other` must
     therefore provide a `metadata` attribute.
     """
     return recursiveEquality(self.metadata, other.metadata)
コード例 #12
0
ファイル: hdf.py プロジェクト: thehrh/pisa-1
def test_hdf():
    """Unit tests for hdf module.

    Writes a nested `OrderedDict` of arrays, byte strings and scalars with
    `to_hdf` and reads it back with `from_hdf`, once without and once with
    file attributes, comparing the loaded contents against the originals.
    """
    from shutil import rmtree
    from tempfile import mkdtemp

    data = OrderedDict([
        ('top', OrderedDict([
            ('secondlvl1', OrderedDict([
                ('thirdlvl11', np.linspace(1, 100, 10000).astype(np.float64)),
                ('thirdlvl12', b"this is a string"),
                ('thirdlvl13', b"this is another string"),
                ('thirdlvl14', 1),
                ('thirdlvl15', 1.1),
                ('thirdlvl16', np.float32(1.1)),
                ('thirdlvl17', np.float64(1.1)),
                ('thirdlvl18', np.int8(1)),
                ('thirdlvl19', np.int16(1)),
                ('thirdlvl110', np.int32(1)),
                ('thirdlvl111', np.int64(1)),
                ('thirdlvl112', np.uint8(1)),
                ('thirdlvl113', np.uint16(1)),
                ('thirdlvl114', np.uint32(1)),
                ('thirdlvl115', np.uint64(1)),
            ])),
            ('secondlvl2', OrderedDict([
                ('thirdlvl21', np.linspace(1, 100, 10000).astype(np.float32)),
                ('thirdlvl22', b"this is a string"),
                ('thirdlvl23', b"this is another string"),
            ])),
            ('secondlvl3', OrderedDict([
                # `np.int` was merely an alias of the builtin `int` and has
                # been removed from NumPy; the builtin gives the same dtype.
                ('thirdlvl31', np.array(range(1000)).astype(int)),
                ('thirdlvl32', b"this is a string"),
            ])),
            ('secondlvl4', OrderedDict([
                ('thirdlvl41', np.linspace(1, 100, 10000)),
                ('thirdlvl42', b"this is a string"),
            ])),
            ('secondlvl5', OrderedDict([
                ('thirdlvl51', np.linspace(1, 100, 10000)),
                ('thirdlvl52', b"this is a string"),
            ])),
            ('secondlvl6', OrderedDict([
                ('thirdlvl61', np.linspace(100, 1000, 10000)),
                ('thirdlvl62', b"this is a string"),
            ])),
        ]))
    ])

    temp_dir = mkdtemp()
    try:
        fpath = os.path.join(temp_dir, 'to_hdf_noattrs.hdf5')
        to_hdf(data, fpath, overwrite=True, warn=False)
        loaded_data1 = from_hdf(fpath)
        assert data.keys() == loaded_data1.keys()
        assert recursiveEquality(data, loaded_data1), \
                str(data) + "\n" + str(loaded_data1)

        attrs = OrderedDict([
            ('float', 9.98237),
            ('float32', np.float32(1.)),
            ('float64', np.float64(1.)),
            ('pi', np.float64(np.pi)),

            ('string', "string attribute!"),

            ('int', 1),
            ('int8', np.int8(1)),
            ('int16', np.int16(1)),
            ('int32', np.int32(1)),
            ('int64', np.int64(1)),

            ('uint8', np.uint8(1)),
            ('uint16', np.uint16(1)),
            ('uint32', np.uint32(1)),
            ('uint64', np.uint64(1)),

            ('bool', True),
            # `np.bool8` was an alias of `np.bool_` and has been removed from
            # NumPy; the key is kept so the stored attribute set is unchanged.
            ('bool8', np.bool_(True)),
            ('bool_', np.bool_(True)),
        ])

        attr_type_checkers = {
            "float": lambda x: isinstance(x, float),
            "float32": lambda x: x.dtype == np.float32,
            "float64": lambda x: x.dtype == np.float64,
            "pi": lambda x: x.dtype == np.float64,

            "string": lambda x: isinstance(x, string_types),

            "int": lambda x: isinstance(x, int),
            "int8": lambda x: x.dtype == np.int8,
            "int16": lambda x: x.dtype == np.int16,
            "int32": lambda x: x.dtype == np.int32,
            "int64": lambda x: x.dtype == np.int64,

            "uint8": lambda x: x.dtype == np.uint8,
            "uint16": lambda x: x.dtype == np.uint16,
            "uint32": lambda x: x.dtype == np.uint32,
            "uint64": lambda x: x.dtype == np.uint64,

            "bool": lambda x: isinstance(x, bool),
            "bool8": lambda x: x.dtype == np.bool_,
            "bool_": lambda x: x.dtype == np.bool_,
        }

        fpath = os.path.join(temp_dir, 'to_hdf_withattrs.hdf5')
        to_hdf(data, fpath, attrs=attrs, overwrite=True, warn=False)
        loaded_data2 = from_hdf(fpath)
        loaded_attrs = loaded_data2.attrs
        assert data.keys() == loaded_data2.keys()
        assert attrs.keys() == loaded_attrs.keys(), \
                '\n' + str(attrs.keys()) + '\n' + str(loaded_attrs.keys())
        assert recursiveEquality(data, loaded_data2)
        assert recursiveEquality(attrs, loaded_attrs)

        # NOTE(review): the checker is applied to the ORIGINAL value `val`,
        # while the failure message reports the type of the LOADED value.
        # Presumably `loaded_attrs[key]` was meant to be checked - confirm
        # that `from_hdf` restores these types before changing it.
        for key, val in attrs.items():
            tgt_type_checker = attr_type_checkers[key]
            assert tgt_type_checker(val), \
                    "key '%s': val '%s' is type '%s'" % \
                    (key, val, type(loaded_attrs[key]))
    finally:
        rmtree(temp_dir)

    logging.info('<< PASS : test_hdf >>')
コード例 #13
0
ファイル: transform.py プロジェクト: terliuk/pisa
 def __eq__(self, other):
     """Compare with `other` by hashable state.

     Anything that is not a BinnedTensorTransform compares unequal;
     otherwise the result of `recursiveEquality` on the two `hashable_state`
     attributes is returned.
     """
     return (
         isinstance(other, BinnedTensorTransform)
         and recursiveEquality(self.hashable_state, other.hashable_state)
     )
コード例 #14
0
def test_to_json_from_json():
    """Unit tests for writing various types of objects to and reading from JSON
    files (including bz2-compressed and xor-scrambled files).

    For every floating point and integer type found in numpy's scalar type
    dictionary (plus the builtin `float` and `int`), an array and a scalar of
    that type are written and read back, both directly and as values in a
    dictionary.
    """
    # pylint: disable=unused-variable
    from shutil import rmtree
    import sys
    from pisa.utils.comparisons import recursiveEquality

    proto_float_array = np.array([-np.inf, np.nan, np.inf, -1.1, 0.0, 1.1],
                                 dtype=np.float64)
    proto_int_array = np.array([-2, -1, 0, 1, 2], dtype=np.int64)
    proto_str_array = np.array(['a', 'ab', 'abc', '', ' '], dtype=str)

    # `np.typeDict` was an alias of `np.sctypeDict` and has been removed from
    # NumPy; `np.sctypeDict` provides the same mapping.
    floating_types = [float] + sorted(
        {t for t in np.sctypeDict.values() if issubclass(t, np.floating)},
        key=str,
    )
    integer_types = [int] + sorted(
        {t for t in np.sctypeDict.values() if issubclass(t, np.integer)},
        key=str,
    )

    test_info = [
        dict(
            proto_array=proto_float_array,
            dtypes=floating_types,
        ),
        dict(
            proto_array=proto_int_array,
            dtypes=integer_types,
        ),
        # TODO: strings currently do not work
        #dict(
        #    proto_array=proto_str_array,
        #    dtypes=[str, np.str0, np.str_, np.string_],
        #),
    ]

    # Keys are derived from the numpy dtype name, so types mapping to the same
    # dtype (e.g. `float` and `np.float64`) share an entry; the later one wins.
    test_data = OrderedDict()
    for info in test_info:
        proto_array = info['proto_array']
        for dtype in info['dtypes']:
            typed_array = proto_array.astype(dtype)
            s_dtype = str(np.dtype(dtype))
            test_data["array_" + s_dtype] = typed_array
            test_data["scalar_" + s_dtype] = dtype(typed_array[0])

    temp_dir = tempfile.mkdtemp()
    try:
        for name, obj in test_data.items():
            # Test that the object can be written / read directly
            base_fname = os.path.join(temp_dir, name + '.json')
            for ext in ['', '.bz2', '.xor']:
                fname = base_fname + ext
                to_json(obj, fname)
                loaded_data = from_json(fname)
                if obj.dtype in floating_types:
                    assert np.allclose(
                        loaded_data, obj, rtol=1e-12, atol=0, equal_nan=True
                    ), '{}=\n{}\nloaded=\n{}\nsee file: {}'.format(
                        name, obj, loaded_data, fname)
                else:
                    assert np.all(loaded_data == obj), \
                        '{}=\n{}\nloaded_nda=\n{}\nsee file: {}'.format(
                            name, obj, loaded_data, fname
                        )

            # Test that the same object can be written / read as a value in a
            # dictionary
            orig = OrderedDict([(name, obj), (name + "x", obj)])
            base_fname = os.path.join(temp_dir, 'd.{}.json'.format(name))
            for ext in ['', '.bz2', '.xor']:
                fname = base_fname + ext
                to_json(orig, fname)
                loaded = from_json(fname)
                assert recursiveEquality(loaded, orig), \
                    'orig=\n{}\nloaded=\n{}\nsee file: {}'.format(
                        orig, loaded, fname
                    )
    finally:
        rmtree(temp_dir)

    logging.info('<< PASS : test_to_json_from_json >>')
コード例 #15
0
ファイル: prior.py プロジェクト: terliuk/pisa
 def __eq__(self, other):
     """Compare with `other` by state.

     Objects that are not instances of this object's class compare unequal;
     otherwise the result of `recursiveEquality` on the two `state`
     attributes is returned.
     """
     return (
         isinstance(other, self.__class__)
         and recursiveEquality(self.state, other.state)
     )
コード例 #16
0
def test_hdf():
    """Round-trip tests for the hdf module.

    A nested `OrderedDict` of arrays and strings is written with `to_hdf` and
    read back with `from_hdf`, first without and then with file attributes;
    loaded contents (and attribute types) must match the originals.
    """
    from shutil import rmtree
    from tempfile import mkdtemp

    # Six sibling groups, each holding one array and one string
    top = OrderedDict()
    top['secondlvl1'] = OrderedDict([
        ('thirdlvl11', np.linspace(1, 100, 10000)),
        ('thirdlvl12', "this is a string"),
    ])
    top['secondlvl2'] = OrderedDict([
        ('thirdlvl21', np.linspace(1, 100, 10000)),
        ('thirdlvl22', "this is a string"),
    ])
    top['secondlvl3'] = OrderedDict([
        ('thirdlvl31', np.linspace(1, 100, 10000)),
        ('thirdlvl32', "this is a string"),
    ])
    top['secondlvl4'] = OrderedDict([
        ('thirdlvl41', np.linspace(1, 100, 10000)),
        ('thirdlvl42', "this is a string"),
    ])
    top['secondlvl5'] = OrderedDict([
        ('thirdlvl51', np.linspace(1, 100, 10000)),
        ('thirdlvl52', "this is a string"),
    ])
    top['secondlvl6'] = OrderedDict([
        ('thirdlvl61', np.linspace(100, 1000, 10000)),
        ('thirdlvl62', "this is a string"),
    ])
    data = OrderedDict([('top', top)])

    temp_dir = mkdtemp()
    try:
        # Round trip without attributes
        fpath = os.path.join(temp_dir, 'to_hdf_noattrs.hdf5')
        to_hdf(data, fpath, overwrite=True, warn=False)
        loaded_data1 = from_hdf(fpath)
        assert data.keys() == loaded_data1.keys()
        assert recursiveEquality(data, loaded_data1)

        # Round trip with attributes
        attrs = OrderedDict()
        attrs['float1'] = 9.98237
        attrs['float2'] = 1.
        attrs['pi'] = np.pi
        attrs['string'] = "string attribute!"
        attrs['int'] = 1

        fpath = os.path.join(temp_dir, 'to_hdf_withattrs.hdf5')
        to_hdf(data, fpath, attrs=attrs, overwrite=True, warn=False)
        loaded_data2, loaded_attrs = from_hdf(fpath, return_attrs=True)
        assert data.keys() == loaded_data2.keys()
        assert attrs.keys() == loaded_attrs.keys(), \
                '\n' + str(attrs.keys()) + '\n' + str(loaded_attrs.keys())
        assert recursiveEquality(data, loaded_data2)
        assert recursiveEquality(attrs, loaded_attrs)

        # Every loaded attribute must be an instance of the original's type
        for k, v in attrs.items():
            tgt_type = type(v)
            loaded_val = loaded_attrs[k]
            assert isinstance(loaded_val, tgt_type), \
                    "key %s: val '%s' is type '%s' but should be '%s'" % \
                    (k, v, type(loaded_val), tgt_type)
    finally:
        rmtree(temp_dir)

    logging.info('<< PASS : test_hdf >>')