示例#1
0
 def test_multiple_data_stats(self):
     """Run one DataStats transform while a second DataStats instance also exists.

     Two instances are created on purpose; the second one is never called.
     NOTE(review): presumably this guards against duplicated logger output
     when several DataStats objects coexist -- confirm against DataStats.
     There is no assertion: the test passes as long as construction and the
     call do not raise. Whatever was written to stdout is echoed afterwards.
     """
     with patch("sys.stdout", new=StringIO()) as out:
         input_data = np.array([[0, 1], [1, 2]])
         transform = DataStats()
         _ = DataStats()  # second instance, intentionally unused
         _ = transform(input_data)
     # Echo only after the patch has been undone: printing inside the
     # ``with`` block wrote back into ``out`` itself, so the captured text
     # was never actually displayed.
     print(out.getvalue().strip())
示例#2
0
 def test_file(self, input_data, expected_print):
     """Check that DataStats writes its report through a handler on its named logger.

     A ``FileHandler`` is attached to the ``"DataStats"`` logger before the
     transform runs. Afterwards every handler on that logger is closed and
     detached, and the log file contents must equal ``expected_print``. The
     comparison is skipped when ``sys.platform == "win32"``.
     """
     with tempfile.TemporaryDirectory() as tempdir:
         filename = os.path.join(tempdir, "test_data_stats.log")
         handler = logging.FileHandler(filename, mode="w")
         handler.setLevel(logging.INFO)
         name = "DataStats"
         logger = logging.getLogger(name)
         logger.addHandler(handler)
         input_param = {
             "prefix": "test data",
             "data_type": True,
             "data_shape": True,
             "value_range": True,
             "data_value": True,
             "additional_info": np.mean,
             "name": name,
         }
         try:
             transform = DataStats(**input_param)
             _ = transform(input_data)
         finally:
             # getLogger(name) returns a process-wide logger, so always
             # close and detach its handlers -- even if the transform raised.
             # Otherwise the open log file leaks into later tests and can
             # block removal of the temporary directory (notably on Windows).
             for h in logger.handlers[:]:
                 h.close()
                 logger.removeHandler(h)
         with open(filename) as f:
             content = f.read()
         if sys.platform != "win32":
             self.assertEqual(content, expected_print)
示例#3
0
 def test_file(self, input_data, expected_print):
     """Check DataStats output written through a caller-supplied logger handler.

     A ``FileHandler`` is passed to DataStats via ``logger_handler``. After
     the transform runs, the handler is detached from the transform's logger
     and closed, and the log file contents must equal ``expected_print``.
     """
     with tempfile.TemporaryDirectory() as tempdir:
         filename = os.path.join(tempdir, "test_data_stats.log")
         handler = logging.FileHandler(filename, mode="w")
         input_param = {
             "prefix": "test data",
             "data_shape": True,
             "value_range": True,
             "data_value": True,
             "additional_info": lambda x: np.mean(x),
             "logger_handler": handler,
         }
         transform = None
         try:
             transform = DataStats(**input_param)
             _ = transform(input_data)
         finally:
             # Detach first, then use Handler.close(): unlike closing only
             # ``handler.stream``, it flushes, releases the file and also
             # unregisters the handler from logging's internal handler list.
             # The finally guarantees cleanup even if the transform raised.
             if transform is not None:
                 transform._logger.removeHandler(handler)
             handler.close()
         with open(filename, "r") as f:
             content = f.read()
         self.assertEqual(content, expected_print)
示例#4
0
 def test_value(self, input_param, input_data, expected_print):
     """Apply a DataStats built from ``input_param`` and compare its recorded ``output``.

     The transform's return value is discarded; only the ``output`` attribute
     it holds after being called on ``input_data`` is checked against
     ``expected_print``.
     """
     stats = DataStats(**input_param)
     stats(input_data)
     recorded = stats.output
     self.assertEqual(recorded, expected_print)
 def test_value(self, input_param, input_data, expected_print):
     transform = DataStats(**input_param)
     _ = transform(input_data)