class ModelDatasetMethodsTest(ModelFittingTestCase):
    """Tests for ``Model.fit_dataset``, ``evaluate_dataset`` and ``predict_dataset`` on MNIST.

    MNIST is downloaded once per test class into a temporary directory. The
    training set is split 50 000 / 10 000 into train/valid subsets with a fixed
    generator seed so the split is reproducible.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.temp_dir_obj = TemporaryDirectory()
        try:
            cls.train_dataset = MNIST(cls.temp_dir_obj.name, train=True, download=True, transform=ToTensor())
            cls.test_dataset = MNIST(cls.temp_dir_obj.name, train=False, download=True, transform=ToTensor())
            cls.train_sub_dataset, cls.valid_sub_dataset = random_split(
                cls.train_dataset, [50_000, 10_000], generator=torch.Generator().manual_seed(42))
        except BaseException:
            # tearDownClass is not called when setUpClass raises (e.g. a failed
            # download), so release the temporary directory before propagating.
            cls.temp_dir_obj.cleanup()
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            cls.temp_dir_obj.cleanup()
        finally:
            super().tearDownClass()

    def setUp(self):
        super().setUp()
        torch.manual_seed(42)
        self.pytorch_network = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        self.batch_metrics = ['accuracy']
        self.batch_metrics_names = ['acc']
        self.batch_metrics_values = [ANY]
        self.epoch_metrics = ['f1']
        self.epoch_metrics_names = ['fscore_micro']
        self.epoch_metrics_values = [ANY]

        self.model = Model(self.pytorch_network,
                           'sgd',
                           'cross_entropy',
                           batch_metrics=self.batch_metrics,
                           epoch_metrics=self.epoch_metrics)

    def assertStdoutContains(self, values):
        """Assert that every item of ``values`` appears in the captured stdout."""
        output = self.test_out.getvalue().strip()
        for value in values:
            self.assertIn(value, output)

    def test_fitting_mnist(self):
        logs = self.model.fit_dataset(
            self.train_sub_dataset,
            self.valid_sub_dataset,
            epochs=ModelTest.epochs,
            steps_per_epoch=ModelTest.steps_per_epoch,
            validation_steps=ModelTest.steps_per_epoch,
            callbacks=[self.mock_callback])
        params = {
            'epochs': ModelTest.epochs,
            'steps': ModelTest.steps_per_epoch,
            'valid_steps': ModelTest.steps_per_epoch
        }
        self._test_callbacks_train(params, logs, valid_steps=ModelTest.steps_per_epoch)

    def test_fitting_mnist_without_valid(self):
        logs = self.model.fit_dataset(
            self.train_dataset,
            epochs=ModelTest.epochs,
            steps_per_epoch=ModelTest.steps_per_epoch,
            validation_steps=ModelTest.steps_per_epoch,
            callbacks=[self.mock_callback])
        params = {
            'epochs': ModelTest.epochs,
            'steps': ModelTest.steps_per_epoch,
            'valid_steps': ModelTest.steps_per_epoch
        }
        self._test_callbacks_train(params, logs, has_valid=False)

    def test_evaluate_dataset(self):
        num_steps = 10
        loss, metrics, pred_y = self.model.evaluate_dataset(
            self.test_dataset, batch_size=ModelTest.batch_size, steps=num_steps, return_pred=True)
        self.assertEqual(type(loss), float)
        self.assertEqual(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), self.batch_metrics_values + self.epoch_metrics_values)
        self.assertEqual(type(pred_y), np.ndarray)
        self.assertEqual(pred_y.shape, (num_steps * ModelTest.batch_size, 10))

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_dataset_with_progress_bar_coloring(self):
        # The default output is colored, so the ANSI codes asserted below are only
        # produced when colorama is available (same guard as the other color tests).
        num_steps = 10
        self._capture_output()

        self.model.evaluate_dataset(self.test_dataset, batch_size=ModelTest.batch_size, steps=num_steps)

        self.assertStdoutContains(["%", "[32m", "[35m", "[36m", "[94m", "\u2588"])

    def test_evaluate_dataset_with_callback(self):
        num_steps = 10
        self.model.evaluate_dataset(self.test_dataset,
                                    batch_size=ModelTest.batch_size,
                                    steps=num_steps,
                                    callbacks=[self.mock_callback])

        # The callbacks must see the number of evaluation steps actually run,
        # not the (coincidentally equal) number of training epochs.
        params = {'steps': num_steps}
        self._test_callbacks_test(params)

    def test_evaluate_dataset_with_return_dict(self):
        num_steps = 10
        logs = self.model.evaluate_dataset(self.test_dataset,
                                           batch_size=ModelTest.batch_size,
                                           steps=num_steps,
                                           return_dict_format=True)

        self._test_return_dict_logs(logs)

    def test_evaluate_dataset_with_ground_truth(self):
        num_steps = 10
        loss, metrics, pred_y, true_y = self.model.evaluate_dataset(self.test_dataset,
                                                                    batch_size=ModelTest.batch_size,
                                                                    steps=num_steps,
                                                                    return_pred=True,
                                                                    return_ground_truth=True)
        self.assertEqual(type(loss), float)
        self.assertEqual(type(metrics), np.ndarray)
        self.assertEqual(metrics.tolist(), self.batch_metrics_values + self.epoch_metrics_values)
        self.assertEqual(type(pred_y), np.ndarray)
        self.assertEqual(type(true_y), np.ndarray)
        self.assertEqual(pred_y.shape, (num_steps * ModelTest.batch_size, 10))
        self.assertEqual(true_y.shape, (num_steps * ModelTest.batch_size, ))

    def test_predict_dataset(self):

        class PredictDataset(Dataset):
            """Wrap a dataset so that each item is only the first element of the wrapped item."""

            def __init__(self, dataset):
                super().__init__()
                self.dataset = dataset

            def __getitem__(self, index):
                return self.dataset[index][0]

            def __len__(self):
                return len(self.dataset)

        num_steps = 10
        pred_y = self.model.predict_dataset(PredictDataset(self.test_dataset),
                                            batch_size=ModelTest.batch_size,
                                            steps=num_steps)
        self.assertEqual(type(pred_y), np.ndarray)
        self.assertEqual(pred_y.shape, (num_steps * ModelTest.batch_size, 10))
class ModelFittingTestCaseProgress(ModelFittingTestCase):  # pylint: disable=too-many-public-methods
    """Tests for the console progress output printed by ``Model``.

    Covers coloring, the optional progress bar and the templating of the
    per-step and final updates printed by ``fit_generator``, ``evaluate``,
    ``evaluate_generator`` and ``predict_dataset``. Output is captured by
    ``self._capture_output()`` in ``setUp`` and read back from ``self.test_out``.
    """

    num_steps = 5
    # Elapsed-time suffix of the final update, e.g. "1d02h03m04.56s", "03m04.56s" or "4.56s".
    TIME_REGEX = r"((([0-9]+d)?[0-9]{1,2}h)?[0-9]{1,2}m)?[0-9]{1,2}\.[0-9]{2}s"

    # ANSI SGR fragments (the leading ESC byte is omitted) expected with the default coloring.
    _DEFAULT_COLOR_CODES = ("[32m", "[35m", "[36m", "[94m")
    # Subset of the default codes that the prediction tests check.
    _PREDICT_COLOR_CODES = ("[32m", "[35m", "[36m")
    # Character used to draw the filled part of the progress bar.
    _BAR_CHAR = "\u2588"

    def setUp(self):
        super().setUp()
        batch_size = ModelFittingTestCase.batch_size
        self.train_generator = some_data_tensor_generator(batch_size)
        self.valid_generator = some_data_tensor_generator(batch_size)
        self.test_generator = some_data_tensor_generator(batch_size)
        torch.manual_seed(42)
        self.pytorch_network = nn.Linear(1, 1)
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.pytorch_network.parameters(), lr=1e-3)
        self.batch_metrics = [
            some_batch_metric_1,
            ('custom_name', some_batch_metric_2),
            repeat_batch_metric,
            repeat_batch_metric,
        ]
        self.batch_metrics_names = [
            'some_batch_metric_1',
            'custom_name',
            'repeat_batch_metric1',
            'repeat_batch_metric2',
        ]
        self.batch_metrics_values = [
            some_metric_1_value,
            some_metric_2_value,
            repeat_batch_metric_value,
            repeat_batch_metric_value,
        ]
        self.epoch_metrics = [SomeConstantEpochMetric()]
        self.epoch_metrics_names = ['some_constant_epoch_metric']
        self.epoch_metrics_values = [some_constant_epoch_metric_value]

        self.model = Model(
            self.pytorch_network,
            self.optimizer,
            self.loss_function,
            batch_metrics=self.batch_metrics,
            epoch_metrics=self.epoch_metrics,
        )

        self._capture_output()

    def assertStdoutContains(self, values):
        """Assert that every item of ``values`` appears in the captured stdout."""
        output = self.test_out.getvalue().strip()
        for value in values:
            self.assertIn(value, output)

    def assertStdoutNotContains(self, values):
        """Assert that no item of ``values`` appears in the captured stdout."""
        output = self.test_out.getvalue().strip()
        for value in values:
            self.assertNotIn(value, output)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _uniform_coloring(color_name):
        """Return a fresh coloring option dict giving ``color_name`` to every component."""
        return {
            "text_color": color_name,
            "ratio_color": color_name,
            "metric_value_color": color_name,
            "time_color": color_name,
            "progress_bar_color": color_name,
        }

    def _fit_default_generators(self, **kwargs):
        """Call ``fit_generator`` on the train/valid generators from ``setUp``.

        The standard epochs/steps/callbacks are used; any keyword argument given
        here is forwarded to ``fit_generator`` and overrides those defaults.
        """
        options = dict(
            epochs=ModelFittingTestCase.epochs,
            steps_per_epoch=ModelFittingTestCase.steps_per_epoch,
            validation_steps=ModelFittingTestCase.steps_per_epoch,
            callbacks=[self.mock_callback],
        )
        options.update(kwargs)
        return self.model.fit_generator(self.train_generator, self.valid_generator, **options)

    @staticmethod
    def _random_column():
        """Return a random tensor of shape ``(evaluate_dataset_len, 1)``."""
        return torch.rand(ModelFittingTestCase.evaluate_dataset_len, 1)

    def _evaluate_random_tensors(self, **kwargs):
        """Call ``Model.evaluate`` on fresh random ``x``/``y`` tensors, forwarding ``kwargs``."""
        x = self._random_column()
        y = self._random_column()
        return self.model.evaluate(x, y, batch_size=ModelFittingTestCase.batch_size, **kwargs)

    def _evaluate_test_generator(self, progress_options):
        """Call ``evaluate_generator`` verbosely on ``self.test_generator`` for ``num_steps`` steps."""
        return self.model.evaluate_generator(
            self.test_generator,
            steps=ModelFittingTestCaseProgress.num_steps,
            callbacks=[self.mock_callback],
            verbose=True,
            progress_options=progress_options,
        )

    def _progress_updates(self):
        """Return the captured stdout split into its carriage-return separated updates."""
        return self.test_out.getvalue().strip().split("\r")

    def _assert_step_updates(self, step_updates, *, epoch=None, text_color="", progress_bar=True):
        """Assert that the i-th update (1-based) displays step ``i`` out of ``num_steps``.

        Args:
            step_updates: Per-step updates, in display order.
            epoch: When given, each update must also show ``Epoch: <epoch>/1``
                (the complete-display fitting tests run a single epoch).
            text_color: Regex fragment expected right before ``Step:``
                (e.g. ``r"\\[37m"``); empty when coloring is disabled.
            progress_bar: When true, each update must show the completion
                percentage followed by at least one progress bar character.

        Only the presence of the bar character is checked, not its exact width.
        The previous template used an unescaped ``|{}|``; that turned the pattern
        into an alternation, so any update containing ``ETA:`` matched.
        """
        num_steps = ModelFittingTestCaseProgress.num_steps
        for step, step_update in enumerate(step_updates, start=1):
            pattern = ""
            if epoch is not None:
                pattern += rf"Epoch:.*{epoch}/1.*"
            pattern += rf"{text_color}Step:.*{step}/{num_steps}.*"
            if progress_bar:
                pattern += f"{step / num_steps * 100:.2f}%.*{self._BAR_CHAR}.*"
            pattern += "ETA:"
            self.assertRegex(step_update, pattern)

    # ---------------------------------------------------------- fit_generator

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_default_coloring(self):
        _ = self._fit_default_generators()

        self.assertStdoutContains(self._DEFAULT_COLOR_CODES)

    def test_fitting_with_progress_bar_show_epoch(self):
        _ = self._fit_default_generators()

        self.assertStdoutContains(["Epoch", "1/5", "2/5"])

    def test_fitting_with_progress_bar_show_steps(self):
        _ = self._fit_default_generators()

        self.assertStdoutContains(["steps", f"{ModelFittingTestCase.steps_per_epoch}"])

    def test_fitting_with_progress_bar_show_train_val_final_steps(self):
        _ = self._fit_default_generators()

        self.assertStdoutContains(["Val steps", "Train steps"])

    def test_fitting_with_no_progress_bar_dont_show_epoch(self):
        _ = self._fit_default_generators(verbose=False)

        self.assertStdoutNotContains(["Epoch", "1/5", "2/5"])

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_user_coloring(self):
        _ = self._fit_default_generators(progress_options=dict(coloring=self._uniform_coloring("BLACK")))

        self.assertStdoutContains(["[30m"])

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_user_partial_coloring(self):
        _ = self._fit_default_generators(progress_options=dict(coloring={
            "text_color": 'BLACK',
            "ratio_color": "BLACK"
        }))

        self.assertStdoutContains(["[30m", "[32m", "[35m", "[94m"])

    def test_fitting_with_user_coloring_invalid(self):
        with self.assertRaises(KeyError):
            _ = self._fit_default_generators(progress_options=dict(coloring={"invalid_name": 'A COLOR'}))

    def test_fitting_with_no_coloring(self):
        _ = self._fit_default_generators(progress_options=dict(coloring=False))

        self.assertStdoutNotContains(self._DEFAULT_COLOR_CODES)

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_progress_bar_default_color(self):
        _ = self._fit_default_generators(progress_options=dict(coloring=True, progress_bar=True))

        self.assertStdoutContains(["%", *self._DEFAULT_COLOR_CODES, self._BAR_CHAR])

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_with_progress_bar_user_color(self):
        _ = self._fit_default_generators(
            progress_options=dict(coloring=self._uniform_coloring("BLACK"), progress_bar=True))

        self.assertStdoutContains(["%", "[30m", self._BAR_CHAR])

    def test_fitting_with_progress_bar_no_color(self):
        _ = self._fit_default_generators(progress_options=dict(coloring=False, progress_bar=True))

        self.assertStdoutContains(["%", self._BAR_CHAR])
        self.assertStdoutNotContains(self._DEFAULT_COLOR_CODES)

    def test_fitting_with_no_progress_bar(self):
        _ = self._fit_default_generators(progress_options=dict(coloring=False, progress_bar=False))

        self.assertStdoutNotContains(["%", self._BAR_CHAR])
        self.assertStdoutNotContains(self._DEFAULT_COLOR_CODES)

    def test_progress_bar_with_step_is_none(self):
        train_generator = SomeDataGeneratorUsingStopIteration(ModelFittingTestCase.batch_size, 10)
        valid_generator = SomeDataGeneratorUsingStopIteration(ModelFittingTestCase.batch_size, 10)

        _ = self.model.fit_generator(
            train_generator,
            valid_generator,
            epochs=ModelFittingTestCase.epochs,
            progress_options=dict(coloring=False, progress_bar=True),
        )

        self.assertStdoutContains(["s/step"])
        self.assertStdoutNotContains([*self._DEFAULT_COLOR_CODES, self._BAR_CHAR, "%"])

    # --------------------------------------------------------------- evaluate

    def test_evaluate_without_progress_output(self):
        _, _ = self._evaluate_random_tensors(verbose=False)

        self.assertStdoutNotContains(self._DEFAULT_COLOR_CODES)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_default_coloring(self):
        _, _ = self._evaluate_random_tensors()

        self.assertStdoutContains(self._DEFAULT_COLOR_CODES)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_user_coloring(self):
        _, _ = self._evaluate_random_tensors(progress_options=dict(coloring=self._uniform_coloring("BLACK")))

        self.assertStdoutContains(["[30m"])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_user_partial_coloring(self):
        _, _ = self._evaluate_random_tensors(progress_options=dict(coloring={
            "text_color": 'BLACK',
            "ratio_color": "BLACK"
        }))

        self.assertStdoutContains(["[30m", "[32m", "[35m", "[94m"])

    def test_evaluate_with_user_coloring_invalid(self):
        with self.assertRaises(KeyError):
            _, _ = self._evaluate_random_tensors(progress_options=dict(coloring={"invalid_name": 'A COLOR'}))

    def test_evaluate_with_no_coloring(self):
        _, _ = self._evaluate_random_tensors(progress_options=dict(coloring=False))

        self.assertStdoutNotContains(self._DEFAULT_COLOR_CODES)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_progress_bar_default_coloring(self):
        _, _ = self._evaluate_random_tensors(progress_options=dict(coloring=True, progress_bar=True))

        self.assertStdoutContains(["%", *self._DEFAULT_COLOR_CODES, self._BAR_CHAR])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_progress_bar_user_coloring(self):
        _, _ = self._evaluate_random_tensors(
            progress_options=dict(coloring=self._uniform_coloring("BLACK"), progress_bar=True))

        self.assertStdoutContains(["%", "[30m", self._BAR_CHAR])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_with_progress_bar_user_no_color(self):
        _, _ = self._evaluate_random_tensors(progress_options=dict(coloring=False, progress_bar=True))

        self.assertStdoutContains(["%", self._BAR_CHAR])
        self.assertStdoutNotContains(self._DEFAULT_COLOR_CODES)

    def test_evaluate_with_no_progress_bar(self):
        _, _ = self._evaluate_random_tensors(progress_options=dict(coloring=False, progress_bar=False))

        self.assertStdoutNotContains(["%", self._BAR_CHAR])
        self.assertStdoutNotContains(self._DEFAULT_COLOR_CODES)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_data_loader_with_progress_bar_coloring(self):
        x = self._random_column()
        y = self._random_column()
        generator = DataLoader(TensorDataset(x, y), ModelFittingTestCase.batch_size)

        _, _ = self.model.evaluate_generator(generator, verbose=True)

        self.assertStdoutContains(["%", *self._DEFAULT_COLOR_CODES, self._BAR_CHAR])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_generator_with_progress_bar_coloring(self):
        generator = some_data_tensor_generator(ModelFittingTestCase.batch_size)

        _, _ = self.model.evaluate_generator(generator, steps=ModelFittingTestCaseProgress.num_steps, verbose=True)

        self.assertStdoutContains(["%", *self._DEFAULT_COLOR_CODES, self._BAR_CHAR])

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_generator_with_callback_and_progress_bar_coloring(self):
        generator = some_data_tensor_generator(ModelFittingTestCase.batch_size)

        _, _ = self.model.evaluate_generator(generator,
                                             steps=ModelFittingTestCaseProgress.num_steps,
                                             callbacks=[self.mock_callback],
                                             verbose=True)

        self.assertStdoutContains(["%", *self._DEFAULT_COLOR_CODES, self._BAR_CHAR])

    # ---------------------------------------------------- complete display

    @skipIf(color is None, "Unable to import colorama")
    def test_fitting_complete_display_test_with_progress_bar_coloring(self):
        num_steps = ModelFittingTestCaseProgress.num_steps
        # The same color for every component keeps the expected escape codes simple.
        # progress_bar=True: this test checks the progress-bar template (it was
        # previously False, which only went unnoticed because the regex was vacuous).
        _ = self._fit_default_generators(
            epochs=1,
            steps_per_epoch=num_steps,
            validation_steps=num_steps,
            progress_options=dict(coloring=self._uniform_coloring("WHITE"), progress_bar=True),
        )

        updates = self._progress_updates()
        # The train step updates come first, then the valid step updates, then the summary.
        self._assert_step_updates(updates[:num_steps], epoch=1, text_color=r"\[37m")
        self._assert_step_updates(updates[num_steps:-1], epoch=1, text_color=r"\[37m")
        # The final update uses a different template.
        last_print_regex = r".*\[37mTrain steps:.*5.*Val steps:.*5.*" + ModelFittingTestCaseProgress.TIME_REGEX
        self.assertRegex(updates[-1], last_print_regex)

    def test_fitting_complete_display_test_with_progress_bar_no_coloring(self):
        num_steps = ModelFittingTestCaseProgress.num_steps
        _ = self._fit_default_generators(
            epochs=1,
            steps_per_epoch=num_steps,
            validation_steps=num_steps,
            progress_options=dict(coloring=False, progress_bar=True),
        )

        updates = self._progress_updates()
        self._assert_step_updates(updates[:num_steps], epoch=1)
        self._assert_step_updates(updates[num_steps:-1], epoch=1)
        last_print_regex = r".*Train steps:.*5.*Val steps:.*5.*" + ModelFittingTestCaseProgress.TIME_REGEX
        self.assertRegex(updates[-1], last_print_regex)

    def test_fitting_complete_display_test_with_no_progress_bar_no_coloring(self):
        num_steps = ModelFittingTestCaseProgress.num_steps
        _ = self._fit_default_generators(
            epochs=1,
            steps_per_epoch=num_steps,
            validation_steps=num_steps,
            progress_options=dict(coloring=False, progress_bar=False),
        )

        updates = self._progress_updates()
        self._assert_step_updates(updates[:num_steps], epoch=1, progress_bar=False)
        self._assert_step_updates(updates[num_steps:-1], epoch=1, progress_bar=False)
        last_print_regex = r".*Train steps:.*5.*Val steps:.*5.*" + ModelFittingTestCaseProgress.TIME_REGEX
        self.assertRegex(updates[-1], last_print_regex)

    @skipIf(color is None, "Unable to import colorama")
    def test_evaluate_complete_display_test_with_progress_bar_coloring(self):
        # The same color for every component keeps the expected escape codes simple.
        _, _ = self._evaluate_test_generator(dict(coloring=self._uniform_coloring("WHITE"), progress_bar=True))

        updates = self._progress_updates()
        self._assert_step_updates(updates[:-1], text_color=r"\[37m")
        last_print_regex = r".*\[37mTest steps:.*5.*" + ModelFittingTestCaseProgress.TIME_REGEX
        self.assertRegex(updates[-1], last_print_regex)

    def test_evaluate_complete_display_test_with_progress_bar_no_coloring(self):
        _, _ = self._evaluate_test_generator(dict(coloring=False, progress_bar=True))

        updates = self._progress_updates()
        self._assert_step_updates(updates[:-1])
        last_print_regex = r".*Test steps:.*5.*" + ModelFittingTestCaseProgress.TIME_REGEX
        self.assertRegex(updates[-1], last_print_regex)

    def test_evaluate_complete_display_test_with_no_progress_bar_no_coloring(self):
        _, _ = self._evaluate_test_generator(dict(coloring=False, progress_bar=False))

        updates = self._progress_updates()
        self._assert_step_updates(updates[:-1], progress_bar=False)
        last_print_regex = r".*Test steps:.*5.*" + ModelFittingTestCaseProgress.TIME_REGEX
        self.assertRegex(updates[-1], last_print_regex)

    # -------------------------------------------------------- predict_dataset

    @skipIf(color is None, "Unable to import colorama")
    def test_predict_dataset_with_default_coloring(self):
        x = self._random_column()
        self.model.predict_dataset(x)

        self.assertStdoutContains(self._PREDICT_COLOR_CODES)

    @skipIf(color is None, "Unable to import colorama")
    def test_predict_dataset_with_user_coloring(self):
        x = self._random_column()
        self.model.predict_dataset(x,
                                   progress_options=dict(coloring=self._uniform_coloring("BLACK"),
                                                         progress_bar=True))

        self.assertStdoutContains(["[30m"])

    def test_predict_dataset_with_user_coloring_invalid(self):
        x = self._random_column()
        with self.assertRaises(KeyError):
            self.model.predict_dataset(
                x,
                batch_size=ModelFittingTestCase.batch_size,
                progress_options=dict(coloring={"invalid_name": 'A COLOR'}),
            )

    def test_predict_dataset_with_no_coloring(self):
        x = self._random_column()
        self.model.predict_dataset(x,
                                   batch_size=ModelFittingTestCase.batch_size,
                                   progress_options=dict(coloring=False))

        self.assertStdoutNotContains(self._PREDICT_COLOR_CODES)

    @skipIf(color is None, "Unable to import colorama")
    def test_predict_dataset_with_progress_bar_default_coloring(self):
        x = self._random_column()
        self.model.predict_dataset(x,
                                   batch_size=ModelFittingTestCase.batch_size,
                                   progress_options=dict(coloring=True, progress_bar=True))

        self.assertStdoutContains(["%", *self._PREDICT_COLOR_CODES, self._BAR_CHAR])

    @skipIf(color is None, "Unable to import colorama")
    def test_predict_dataset_with_progress_bar_user_coloring(self):
        x = self._random_column()
        self.model.predict_dataset(x,
                                   batch_size=ModelFittingTestCase.batch_size,
                                   progress_options=dict(coloring=self._uniform_coloring("BLACK"),
                                                         progress_bar=True))

        self.assertStdoutContains(["%", "[30m", self._BAR_CHAR])

    @skipIf(color is None, "Unable to import colorama")
    def test_predict_dataset_with_progress_bar_user_no_color(self):
        x = self._random_column()
        self.model.predict_dataset(x,
                                   batch_size=ModelFittingTestCase.batch_size,
                                   progress_options=dict(coloring=False, progress_bar=True))

        self.assertStdoutContains(["%", self._BAR_CHAR])
        self.assertStdoutNotContains(self._PREDICT_COLOR_CODES)

    def test_predict_dataset_with_no_progress_bar(self):
        x = self._random_column()
        self.model.predict_dataset(x,
                                   batch_size=ModelFittingTestCase.batch_size,
                                   progress_options=dict(coloring=False, progress_bar=False))

        self.assertStdoutNotContains(["%", self._BAR_CHAR])
        self.assertStdoutNotContains(self._PREDICT_COLOR_CODES)

    @skipIf(color is None, "Unable to import colorama")
    def test_predict_dataset_complete_display_predict_with_progress_bar_coloring(self):
        x = self._random_column()
        # The same color for every component keeps the expected escape codes simple.
        self.model.predict_dataset(x,
                                   verbose=True,
                                   progress_options=dict(coloring=self._uniform_coloring("WHITE"),
                                                         progress_bar=True))

        updates = self._progress_updates()
        # Only the final update, whose template differs from the step updates, is checked.
        last_print_regex = r".*\[37mPrediction steps:.*" + ModelFittingTestCaseProgress.TIME_REGEX
        self.assertRegex(updates[-1], last_print_regex)