예제 #1
0
 def test_split(self):
     """Each delimited item in the pre/post strings yields one expression."""

     def check_pair(transforms):
         # Both the pre and the post expression lists should hold two entries.
         self.assertEqual(len(transforms.pre), 2)
         self.assertEqual(len(transforms.post), 2)

     # Explicit ',' delimiter.
     check_pair(MonitorTransforms('a,c', '1,2', delim=','))
     # No delim argument: ';'-separated input is expected to split as well.
     check_pair(MonitorTransforms('a;c', 'exp(x);2.234'))
예제 #2
0
 def test_post(self):
     """Check post-expression evaluation and the shape of the transformed output."""
     # Two rows shaped (2, 4, 1): row 0 is [0, 1, 2, 3], row 1 doubled to [0, 2, 4, 6].
     state = numpy.tile(numpy.r_[:4], (2, 1)).reshape((2, -1, 1))
     state[1] *= 2
     # check expr eval correct: 'mon' and an empty post item both leave
     # their row's values unchanged (last values 3 and 6).
     mt = MonitorTransforms('0;0', 'mon;')
     _, out = mt.apply_post((0.0, state))
     self.assertEqual(3, out.flat[3])
     self.assertEqual(6, out.flat[7])
     # 'mon**2-1' on the last value of row 1 gives 6**2 - 1 == 35.
     mt = MonitorTransforms('0;0', 'mon;mon**2-1')
     _, out = mt.apply_post((0.0, state))
     self.assertEqual(3, out.flat[3])
     self.assertEqual(35, out.flat[7])
     # check correct shape: one output row per post expression. Sweep the
     # whole range that used to be sampled at random (5..9) so any failure
     # is reproducible and every count is covered on every run.
     for n_expr in range(5, 10):
         with self.subTest(n_expr=n_expr):
             state = numpy.tile(numpy.r_[:4], (n_expr, 1)).reshape((n_expr, -1, 1))
             post_expr = ';'.join(str(i) for i in range(n_expr))
             mt = MonitorTransforms('0', post_expr)
             _, out = mt.apply_post((0.0, state))
             self.assertEqual(n_expr, out.shape[0])
예제 #3
0
 def test_pre(self):
     """Check pre-expression evaluation and the shape of the transformed state."""
     # A single row holding [0, 1, 2, 3], shaped (1, 4, 1).
     state = numpy.r_[:4].reshape((1, -1, 1))
     # check expr correctly evaluated: squaring the last value 3 gives 9
     mt = MonitorTransforms('x0**2', '')
     out = mt.apply_pre(state)
     self.assertEqual(out[0, -1, 0], 9)
     # check correct shape: one output row per pre expression. Sweep the
     # whole range that used to be sampled at random (5..9) so any failure
     # is reproducible and every count is covered on every run.
     for n_expr in range(5, 10):
         with self.subTest(n_expr=n_expr):
             pre_expr = ';'.join(str(i) for i in range(n_expr))
             mt = MonitorTransforms(pre_expr, '')
             out = mt.apply_pre(state)
             self.assertEqual(n_expr, out.shape[0])
예제 #4
0
 def _fail_noop_pre(self):
     """Build transforms whose pre and post strings contain only empty items.

     The name has no 'test' prefix, so unittest's default loader does not run
     it directly. Going by its name, callers expect this construction to fail
     — TODO confirm against its caller.
     """
     empty_items = ';;'
     MonitorTransforms(empty_items, empty_items)
예제 #5
0
 def test_noop_post(self):
     """Empty post items are kept: three ';'-separated items give three entries."""
     post_exprs = '2.34*(pre+1.5);;'
     transforms = MonitorTransforms('a;b;c', post_exprs)
     self.assertEqual(len(transforms.post), 3)
예제 #6
0
 def _syntax_fail(self, pre, post):
     """Construct MonitorTransforms from the given pre/post expression strings.

     Going by its name, this is a helper for checks that expect construction
     to raise; any exception from the constructor propagates to the caller.
     """
     ctor_args = (pre, post)
     MonitorTransforms(*ctor_args)
예제 #7
0
 def test_post_1(self):
     """A single post expression gives one post entry, even with two pre expressions."""
     transforms = MonitorTransforms('a;b', 'c')
     n_post = len(transforms.post)
     self.assertEqual(1, n_post)
예제 #8
0
 def test_pre_1(self):
     """A single pre expression gives one pre entry, even with two post expressions."""
     transforms = MonitorTransforms('a', 'b;c')
     n_pre = len(transforms.pre)
     self.assertEqual(1, n_pre)
예제 #9
0
 def test_post_default_expands(self):
     """An empty post string expands to one post entry per pre expression."""
     transforms = MonitorTransforms('V;W', '')
     n_post = len(transforms.post)
     self.assertEqual(2, n_post)
 def test_syntax(self):
     """Valid expressions construct cleanly; an invalid one raises SyntaxError."""
     # A well-formed expression containing a function call must not raise;
     # the instance itself is not needed, so it is not bound to a name.
     MonitorTransforms('a+b/c*f(a,b)', '23')
     # self._syntax_fail builds transforms from an invalid (assignment)
     # expression and is expected to raise SyntaxError.
     self.assertRaises(SyntaxError, self._syntax_fail)
 def _syntax_fail(self):
     """Construct transforms from an assignment, which is not an expression.

     test_syntax calls this via assertRaises(SyntaxError, ...).
     """
     invalid_pre = 'a=3'
     MonitorTransforms(invalid_pre, '23.234')
 def _shape_fail(self):
     """Construct ','-delimited transforms with 3 pre but 2 post expressions.

     The name has no 'test' prefix, so unittest's default loader does not run
     it directly. Going by its name, the count mismatch is expected to be
     rejected — TODO confirm against its caller.
     """
     pre_exprs, post_exprs = '1,2,3', '2,3'
     MonitorTransforms(pre_exprs, post_exprs, delim=',')