Esempio n. 1
0
 def test_uniform(self):
     """ BufferDict.uniform

     Registers a uniform distribution 'f' on [0, 1], stores values fw
     under 'f(s)' (scalar) and 'f(a)' (array), and checks that the derived
     keys 's' and 'a' expose the expected values w. The expected w
     coincide numerically with the standard-normal CDF at fw
     (Phi(0) = 0.5, Phi(-1) = 0.15865525393145707).

     The distribution 'f' is registered globally, so it is unregistered
     in a finally clause: a failing assertion must not leak 'f' into
     other tests that register it again.
     """
     b = BufferDict()
     BufferDict.uniform('f', 0., 1.)
     try:
         # Compare to 6 decimals so round-off in the transform is ignored.
         fmt = '{:.6f}'
         phi_m1 = 0.15865525393145707
         for fw, w in [(0, 0.5), (-1., phi_m1), (1, 1 - phi_m1)]:
             # Scalar entry.
             b['f(s)'] = fw
             self.assertEqual(fmt.format(b['s']), fmt.format(w))
             # Array entry: the expected result is w in every element.
             b['f(a)'] = 4 * [fw]
             self.assertEqual(str(b['a']), str(np.array(4 * [w])))
     finally:
         BufferDict.del_distribution('f')
Esempio n. 2
0
    def test_extension_mapping(self):
        """ BufferDict extension and mapping properties

        Builds a BufferDict with plain keys ('a', 'b') and extension keys
        ('log(c)', 'sqrt(d)', 'erfinv(e)', and 'f(w)' for a uniform
        distribution 'f' registered here). It then checks:

        - derived keys ('c', 'd', 'e') equal the transforms of their stored
          counterparts;
        - a copy made with BufferDict(p) matches p entry-by-entry and
          buffer-by-buffer;
        - key-lookup helpers, iteration order and value tracking behave as
          expected;
        - the helpers trim_redundant_keys, nonredundant_keys, _stripkey and
          add_parameter_parentheses work on this data.

        The distribution 'f' is registered globally, so it is unregistered
        in a finally clause: a failing assertion must not leak 'f' into
        other tests that register it again.
        """
        p = BufferDict()
        p['a'] = 1.
        p['b'] = [2., 3.]
        p['log(c)'] = 0.
        p['sqrt(d)'] = [5., 6.]
        p['erfinv(e)'] = [[33.]]
        # Registers distribution 'f'; only .mean of the returned object is
        # used below. Undone by del_distribution in the finally clause.
        fvar = BufferDict.uniform('f', 2., 3.)
        try:
            p['f(w)'] = fvar.mean
            newp = BufferDict(p)
            for _ in range(2):
                # newp must agree with p key-by-key, derived keys must be
                # the transforms of their stored keys, and buffers match.
                for k in p:
                    assert np.all(p[k] == newp[k])
                assert newp['c'] == np.exp(newp['log(c)'])
                assert np.all(newp['d'] == np.square(newp['sqrt(d)']))
                assert np.all(newp['e'] == gv.erf(newp['erfinv(e)']))
                assert np.all(p.buf == newp.buf)
                # Overwrite every buffer entry except the last one (the
                # 'f(w)' entry), then give newp an independent copy.
                p.buf[:-1] = [10., 20., 30., 1., 2., 3., 4.]
                newp.buf = np.array(p.buf.tolist())

            # Each requested key maps to the stored key it is read from.
            self.assertEqual(
                gv.get_dictkeys(p, ['c', 'a', 'log(c)', 'e', 'd', 'w', 'f(w)']),
                ['log(c)', 'a', 'log(c)', 'erfinv(e)', 'sqrt(d)', 'f(w)', 'f(w)'],
                )
            self.assertEqual(
                [gv.dictkey(p, k) for k in ['c', 'a', 'log(c)', 'e', 'd']],
                ['log(c)', 'a', 'log(c)', 'erfinv(e)', 'sqrt(d)'],
                )
            for k in [
                'a', 'b', 'c', 'd', 'e', 'log(c)', 'sqrt(d)', 'erfinv(e)',
                ]:
                self.assertTrue(gv.BufferDict.has_dictkey(p, k), msg=k)
            # No stored key backs these transformed names.
            self.assertFalse(gv.BufferDict.has_dictkey(p, 'log(a)'))
            self.assertFalse(gv.BufferDict.has_dictkey(p, 'sqrt(b)'))
            # Iteration yields the stored keys in the order they were set.
            self.assertEqual(
                list(p), ['a', 'b', 'log(c)', 'sqrt(d)', 'erfinv(e)', 'f(w)'],
                )
            np.testing.assert_equal(
                list(p.values()),
                [10.0, [20., 30.], 1.0, [2., 3.], [[4.]], 0.],
                )
            self.assertEqual(p.get('c'), p['c'])

            # Tracking: a derived value follows writes to its stored key
            # and is unaffected by writes to unrelated keys.
            self.assertAlmostEqual(p['c'], np.exp(1))
            self.assertAlmostEqual(p['log(c)'], 1.)
            p['log(c)'] = 2.
            self.assertAlmostEqual(p['c'], np.exp(2))
            self.assertAlmostEqual(p['log(c)'], 2.)
            p['a'] = 12.
            self.assertAlmostEqual(p['c'], np.exp(2))
            self.assertAlmostEqual(p['log(c)'], 2.)
            self.assertEqual(
                list(p),
                ['a', 'b', 'log(c)', 'sqrt(d)', 'erfinv(e)', 'f(w)'],
                )

            # The rest is not so important.
            # trim_redundant_keys: result has no 'c'/'d' keys and the same
            # buffer contents as newp.
            oldp = trim_redundant_keys(newp)
            assert 'c' not in oldp
            assert 'd' not in oldp
            assert np.all(oldp.buf == newp.buf)

            # nonredundant_keys keeps exactly the keys stored in p.
            self.assertEqual(set(nonredundant_keys(newp.keys())), set(p.keys()))

            # _stripkey splits an extension key into the base key and the
            # function mapping the stored value to the base key's value.
            for ks, f, k in [
                ('aa', np.exp, 'log(aa)'),
                ('aa', np.square, 'sqrt(aa)'),
                ]:
                self.assertEqual((ks, f), gv._bufferdict._stripkey(k))

            def assert_same_entries(q):
                # q and p must have the same keys with equal values.
                for k in p:
                    assert k in q
                    assert np.all(p[k] == q[k])
                for k in q:
                    assert k in p

            # add_parameter_parentheses turns 'logc' / 'sqrtd' into
            # 'log(c)' / 'sqrt(d)'; applying it a second time must leave
            # the result unchanged.
            pvar = BufferDict()
            pvar['a'] = p['a']
            pvar['b'] = p['b']
            pvar['logc'] = p['log(c)']
            pvar['sqrtd'] = p['sqrt(d)']
            pvar['erfinv(e)'] = p['erfinv(e)']
            pvar['f(w)'] = p['f(w)']
            pvar = add_parameter_parentheses(pvar)
            assert_same_entries(pvar)
            pvar = add_parameter_parentheses(pvar)
            assert_same_entries(pvar)

            # Nested parentheses: 'log(c(23))' yields derived key 'c(23)',
            # readable via [] even though ``in`` does not report it.
            pvar['log(c(23))'] = 1.2
            pvar = BufferDict(pvar)
            assert 'c(23)' not in pvar
            assert 'log(c(23))' in pvar
            self.assertAlmostEqual(gv.exp(pvar['log(c(23))']), pvar['c(23)'])
        finally:
            BufferDict.del_distribution('f')