def evaluate_expression(expression_string, grids, symbol):
  """Evaluates a univariate expression on a set of grid points.

  The expression is simplified with sympy.simplify() and then evaluated with
  evaluators.numpy_array_eval(). If that raises a SyntaxError, the original
  unsimplified string is evaluated instead.

  Args:
    expression_string: String. The univariate expression, for example
        'x * x + 1 / x'.
    grids: Numpy array with shape [num_grid_points], the points to evaluate
        expression.
    symbol: String. Symbol of variable in expression.

  Returns:
    Numpy array with shape [num_grid_points]. When the evaluation yields a
    single value (e.g. a constant expression), it is multiplied by
    np.ones_like(grids) so there is one entry per grid point.
  """

  def _evaluate(text):
    # A fresh bindings dict is built for every evaluation attempt.
    return evaluators.numpy_array_eval(text, arguments={symbol: grids})

  try:
    values = _evaluate(str(sympy.simplify(expression_string)))
  except SyntaxError as error:
    # NOTE(leeley): In some rare cases, after sympy.simplify(),
    # expression_string will contain symbols which can not be parsed,
    # for example 'zoo'. If this occurs, evaluate expression without
    # simplification.
    logging.warning(error)
    logging.warning(
        'SyntaxError occurs after sympy.simplify(), '
        'evaluate %s directly without simplification.', expression_string)
    values = _evaluate(expression_string)

  is_single_value = np.asarray(values).size == 1
  if is_single_value:
    values = values * np.ones_like(grids)
  return values
def test_number(self):
  """Numeric literals evaluate to their values, including negatives."""
  # Integers are compared exactly.
  for literal, value in (('42', 42), ('-42', -42)):
    self.assertEqual(evaluators.numpy_array_eval(literal), value)
  # Floats are compared approximately.
  for literal, value in (('4.2', 4.2), ('-4.2', -4.2)):
    self.assertAlmostEqual(evaluators.numpy_array_eval(literal), value)
def test_malformed_string(self, string):
  """A malformed input string raises SyntaxError mentioning 'Malformed string'.

  `string` is presumably supplied by a parameterization decorator that is
  not visible here.
  """
  self.assertRaisesRegex(
      SyntaxError, 'Malformed string', evaluators.numpy_array_eval, string)
def test_unknown_callable(self):
  """Calling a function absent from `callables` raises SyntaxError."""
  # Only 'bar' is registered, so the call to 'foo' must be rejected.
  self.assertRaisesRegex(
      SyntaxError, 'Unknown callable: \'foo\'',
      evaluators.numpy_array_eval, 'foo( 4 )', callables={'bar': np.sin})
def test_unknown_argument(self):
  """Referencing a variable absent from `arguments` raises SyntaxError."""
  # Only 'y' is bound, so the bare name 'x' must be rejected.
  self.assertRaisesRegex(
      SyntaxError, 'Unknown argument: \'x\'',
      evaluators.numpy_array_eval, 'x', arguments={'y': 2})
def test_callables_not_dict(self):
  """Passing `callables` as a non-dict (here a list) raises ValueError."""
  self.assertRaisesRegex(
      ValueError, 'Input callables expected to be a dict',
      evaluators.numpy_array_eval, 'sqrt( 4 )', callables=[np.sqrt])
def test_arguments_not_dict(self):
  """Passing `arguments` as a non-dict (here a list) raises ValueError."""
  self.assertRaisesRegex(
      ValueError, 'Input arguments expected to be a dict',
      evaluators.numpy_array_eval, 'x', arguments=[42])
def test_one_zero_element_in_divisor_array_expression(self):
  """A zero entry in an array divisor is expected to evaluate to 0.

  The remaining entries are ordinary element-wise quotients.
  """
  numerator = 1.
  divisor = np.array([2., 0., 0.5])
  quotient = evaluators.numpy_array_eval(
      'x / y', arguments={'x': numerator, 'y': divisor})
  np.testing.assert_allclose(quotient, np.array([0.5, 0., 2.]))
def test_callables(self, string, arguments, expected):
  """Evaluating `string` with `arguments` yields values close to `expected`.

  Inputs are presumably supplied by a parameterization decorator that is
  not visible here.
  """
  actual = evaluators.numpy_array_eval(string, arguments=arguments)
  np.testing.assert_allclose(actual, expected)
def test_numpy_array_and_number_binary_operators(self, string, expected):
  """Binary operators combine an array operand `x` with a scalar `y`.

  `string` and `expected` are presumably supplied by a parameterization
  decorator that is not visible here.
  """
  actual = evaluators.numpy_array_eval(
      string, arguments={'x': np.array([1., 2., 3.]), 'y': 10.})
  np.testing.assert_allclose(actual, expected)
def test_number_binary_operators_power(self, arguments, expected):
  """The power operator 'x ** y' evaluates to approximately `expected`.

  `arguments` and `expected` are presumably supplied by a parameterization
  decorator that is not visible here.
  """
  power = evaluators.numpy_array_eval('x ** y', arguments=arguments)
  self.assertAlmostEqual(power, expected)
def test_number_binary_operators(self, string, expected):
  """Binary operators on scalar operands x=1. and y=2. give `expected`.

  `string` and `expected` are presumably supplied by a parameterization
  decorator that is not visible here.
  """
  actual = evaluators.numpy_array_eval(string, arguments={'x': 1., 'y': 2.})
  self.assertAlmostEqual(actual, expected)
def test_unary_operator(self, arguments, expected):
  """Unary negation '-x' evaluates to values close to `expected`.

  `arguments` and `expected` are presumably supplied by a parameterization
  decorator that is not visible here.
  """
  negated = evaluators.numpy_array_eval('-x', arguments=arguments)
  np.testing.assert_allclose(negated, expected)