def setUp(self):
    """Create the reference data and every YAML stream fixture used by a test.

    Attributes set on ``self``:
      * ``_test_x``, ``_test_a``, ``_test_b``, ``_test_y``: ``linear_model``
        evaluated on ``np.arange(10)`` with ``a=2.0`` and ``b=-1.0``;
      * ``_model_function``: a ``HistModelFunction`` wrapping ``linear_model``;
      * ``_roundtrip_stringstream`` plus the reader and writer that both use it;
      * streams and readers for the plain, with-formatter, missing-keyword
        and extra-keyword YAML fixtures.
    """
    x_values = np.arange(10)
    slope, intercept = 2.0, -1.0
    self._test_x, self._test_a, self._test_b = x_values, slope, intercept
    self._test_y = self.linear_model(x_values, slope, intercept)

    model_function = HistModelFunction(self.linear_model)
    self._model_function = model_function

    # One empty in-memory stream is shared by the round-trip writer and reader.
    roundtrip_stream = IOStreamHandle(StringIO())
    plain_stream = IOStreamHandle(StringIO(TEST_MODEL_FUNCTION_HIST))
    formatter_stream = IOStreamHandle(
        StringIO(TEST_MODEL_FUNCTION_HIST_WITH_FORMATTER))
    self._roundtrip_stringstream = roundtrip_stream
    self._testfile_stringstream = plain_stream
    self._testfile_stringstream_with_formatter = formatter_stream

    self._roundtrip_streamreader = ModelFunctionYamlReader(roundtrip_stream)
    self._roundtrip_streamwriter = ModelFunctionYamlWriter(
        model_function, roundtrip_stream)
    self._testfile_streamreader = ModelFunctionYamlReader(plain_stream)
    self._testfile_streamreader_with_formatter = ModelFunctionYamlReader(
        formatter_stream)

    # Deliberately malformed fixtures.
    missing_keyword_stream = IOStreamHandle(
        StringIO(TEST_MODEL_FUNCTION_HIST_MISSING_KEYWORD))
    extra_keyword_stream = IOStreamHandle(
        StringIO(TEST_MODEL_FUNCTION_HIST_EXTRA_KEYWORD))
    self._testfile_stringstream_missing_keyword = missing_keyword_stream
    self._testfile_stringstream_extra_keyword = extra_keyword_stream
    self._testfile_streamreader_missing_keyword = ModelFunctionYamlReader(
        missing_keyword_stream)
    self._testfile_streamreader_extra_keyword = ModelFunctionYamlReader(
        extra_keyword_stream)
class TestHistModelFunctionYamlRepresenter(unittest.TestCase):
    """YAML (de)serialization tests for ``HistModelFunction``.

    Covers writing to an in-memory stream, reading from predefined YAML
    fixtures (plain, with a formatter, with a missing keyword, with an extra
    keyword) and a full write/read round trip.
    """

    @staticmethod
    def linear_model(x, a, b):
        """Reference model ``a * x + b`` used to build and check fixtures."""
        return a * x + b

    def setUp(self):
        """Create the reference data and all reader/writer fixtures."""
        # Reference data: the linear model evaluated on x = 0, 1, ..., 9.
        self._test_x = np.arange(10)
        self._test_a = 2.0
        self._test_b = -1.0
        self._test_y = self.linear_model(self._test_x, self._test_a,
                                         self._test_b)

        self._model_function = HistModelFunction(self.linear_model)

        # The round-trip reader and writer share one empty in-memory stream.
        self._roundtrip_stringstream = IOStreamHandle(StringIO())
        self._testfile_stringstream = IOStreamHandle(
            StringIO(TEST_MODEL_FUNCTION_HIST))
        self._testfile_stringstream_with_formatter = IOStreamHandle(
            StringIO(TEST_MODEL_FUNCTION_HIST_WITH_FORMATTER))

        self._roundtrip_streamreader = ModelFunctionYamlReader(
            self._roundtrip_stringstream)
        self._roundtrip_streamwriter = ModelFunctionYamlWriter(
            self._model_function, self._roundtrip_stringstream)
        self._testfile_streamreader = ModelFunctionYamlReader(
            self._testfile_stringstream)
        self._testfile_streamreader_with_formatter = ModelFunctionYamlReader(
            self._testfile_stringstream_with_formatter)

        # Malformed fixtures: reading either of these must raise.
        self._testfile_stringstream_missing_keyword = IOStreamHandle(
            StringIO(TEST_MODEL_FUNCTION_HIST_MISSING_KEYWORD))
        self._testfile_stringstream_extra_keyword = IOStreamHandle(
            StringIO(TEST_MODEL_FUNCTION_HIST_EXTRA_KEYWORD))
        self._testfile_streamreader_missing_keyword = ModelFunctionYamlReader(
            self._testfile_stringstream_missing_keyword)
        self._testfile_streamreader_extra_keyword = ModelFunctionYamlReader(
            self._testfile_stringstream_extra_keyword)

    def _assert_reproduces_test_y(self, model_function):
        """Assert that ``model_function.func`` reproduces ``self._test_y``.

        Evaluates ``func`` at the reference x and parameters and compares it
        with ``np.allclose`` using its default tolerances.
        """
        self.assertTrue(
            np.allclose(
                model_function.func(self._test_x, self._test_a, self._test_b),
                self._test_y))

    def test_write_to_roundtrip_stringstream(self):
        # Smoke test: writing must not raise. The written content is checked
        # by test_round_trip_with_stringstream.
        self._roundtrip_streamwriter.write()

    def test_read_from_testfile_stream(self):
        _read_model_function = self._testfile_streamreader.read()
        self.assertIsInstance(_read_model_function, HistModelFunction)
        self._assert_reproduces_test_y(_read_model_function)

    def test_read_from_testfile_stream_missing_keyword(self):
        with self.assertRaises(YamlReaderException):
            self._testfile_streamreader_missing_keyword.read()

    def test_read_from_testfile_stream_extra_keyword(self):
        with self.assertRaises(YamlReaderException):
            self._testfile_streamreader_extra_keyword.read()

    def test_read_from_testfile_stream_with_formatter(self):
        _read_model_function = (
            self._testfile_streamreader_with_formatter.read())
        self.assertIsInstance(_read_model_function, HistModelFunction)
        self._assert_reproduces_test_y(_read_model_function)

        _read_formatter = _read_model_function.formatter
        self.assertIsInstance(_read_formatter, ModelFunctionFormatter)
        self.assertEqual(_read_formatter.name, 'linear_model')
        self.assertEqual(_read_formatter.latex_name, 'linear model')

        # Expected (name, latex_name) of each argument formatter, in order,
        # as defined by the TEST_MODEL_FUNCTION_HIST_WITH_FORMATTER fixture.
        _expected_arg_formatters = [
            ('x', r'{r}'),
            ('alpha', r'{\alpha}'),
            ('beta', r'{\beta}'),
        ]
        _read_arg_formatters = _read_formatter.arg_formatters
        for _index, (_name, _latex_name) in enumerate(_expected_arg_formatters):
            _label = 'arg_formatters[%d]' % _index
            self.assertEqual(_read_arg_formatters[_index].name, _name,
                             msg=_label)
            self.assertEqual(_read_arg_formatters[_index].latex_name,
                             _latex_name, msg=_label)

        self.assertEqual(_read_formatter.expression_format_string,
                         '{0} * {x} + {1}')
        self.assertEqual(_read_formatter.latex_expression_format_string,
                         '{0}{x} + {1}')

    def test_round_trip_with_stringstream(self):
        self._roundtrip_streamwriter.write()
        # Rewind so the reader sees what the writer just produced.
        self._roundtrip_stringstream.seek(0)
        _read_model_function = self._roundtrip_streamreader.read()
        self.assertIsInstance(_read_model_function, HistModelFunction)
        self._assert_reproduces_test_y(_read_model_function)

        _given_formatter = self._model_function.formatter
        _read_formatter = _read_model_function.formatter
        self.assertIsInstance(_read_formatter, ModelFunctionFormatter)
        self.assertEqual(_read_formatter.name, _given_formatter.name)
        self.assertEqual(_read_formatter.latex_name,
                         _given_formatter.latex_name)

        _given_arg_formatters = _given_formatter.arg_formatters
        _read_arg_formatters = _read_formatter.arg_formatters
        # Every argument formatter must survive the round trip, so compare
        # the complete lists rather than a fixed number of leading entries.
        self.assertEqual(len(_read_arg_formatters),
                         len(_given_arg_formatters))
        for _index, (_read_arg, _given_arg) in enumerate(
                zip(_read_arg_formatters, _given_arg_formatters)):
            _label = 'arg_formatters[%d]' % _index
            self.assertEqual(_read_arg.name, _given_arg.name, msg=_label)
            self.assertEqual(_read_arg.latex_name, _given_arg.latex_name,
                             msg=_label)

        self.assertEqual(_read_formatter.expression_format_string,
                         _given_formatter.expression_format_string)
        self.assertEqual(_read_formatter.latex_expression_format_string,
                         _given_formatter.latex_expression_format_string)