def test_invalid_num_lines_leads_to_disabling_progress_bar(num_lines, caplog):
    """Iterating with the supplied ``num_lines`` must log exactly one WARNING."""
    progress_bar = ProgressBar(TEST_RANGE, num_lines=num_lines)
    for _item in progress_bar:
        pass

    captured = caplog.records
    assert len(captured) == 1
    only_record = captured[0]
    assert only_record.levelno == logging.WARNING
def test_invalid_total_leads_to_disabling_progress_bar(total, caplog):
    """Iterating with the supplied ``total`` must log exactly one WARNING."""
    progress_bar = ProgressBar(TEST_RANGE, total=total)
    for _item in progress_bar:
        pass

    captured = caplog.records
    assert len(captured) == 1
    only_record = captured[0]
    assert only_record.levelno == logging.WARNING
def test_can_print_by_default__with_enumerate_and_total(caplog):
    """An ``enumerate`` wrapper plus an explicit float ``total`` logs three bar lines."""
    expected_log = [
        ('nncf', 20, ' █████ | 1 / 3'),
        ('nncf', 20, ' ██████████ | 2 / 3'),
        ('nncf', 20, ' ████████████████ | 3 / 3'),
    ]

    progress_bar = ProgressBar(enumerate(TEST_RANGE), total=3.0)
    for _item in progress_bar:
        pass

    assert caplog.record_tuples == expected_log
def test_can_print_by_default(caplog):
    """Iterating ``TEST_RANGE`` with default settings logs the three expected records."""
    expected_log = [
        ('nncf', 20, ' █████ | 1 / 3'),
        ('nncf', 20, ' ██████████ | 2 / 3'),
        ('nncf', 20, ' ████████████████ | 3 / 3'),
    ]

    progress_bar = ProgressBar(TEST_RANGE)
    for _item in progress_bar:
        pass

    assert caplog.record_tuples == expected_log
def test_can_iterate_with_warning_for_iterable_without_len(caplog):
    """An ``enumerate`` object (which has no ``len``) iterates and logs one WARNING."""
    unsized_iterable = enumerate(TEST_RANGE)
    for _item in ProgressBar(unsized_iterable):
        pass

    captured = caplog.records
    assert len(captured) == 1
    only_record = captured[0]
    assert only_record.levelno == logging.WARNING
def test_can_print_collections_bigger_than_num_lines(caplog):
    """With 11 items and ``num_lines=3`` the bar is logged at items 5, 10 and 11."""
    expected_log = [
        ('nncf', 20, ' ███████ | 5 / 11'),
        ('nncf', 20, ' ██████████████ | 10 / 11'),
        ('nncf', 20, ' ████████████████ | 11 / 11'),
    ]

    progress_bar = ProgressBar(range(11), num_lines=3)
    for _item in progress_bar:
        pass

    assert caplog.record_tuples == expected_log
def test_can_print_with_another_logger(caplog):
    """Records are emitted under the name of the logger passed via ``logger=``."""
    custom_logger = logging.getLogger("test")
    custom_logger.setLevel(logging.INFO)

    expected_log = [
        ('test', 20, ' ████████ | 1 / 2'),
        ('test', 20, ' ████████████████ | 2 / 2'),
    ]

    progress_bar = ProgressBar(enumerate(TEST_RANGE), logger=custom_logger, total=2)
    for _item in progress_bar:
        pass

    assert caplog.record_tuples == expected_log
def test_can_print_collections_less_than_num_lines(caplog):
    """With ``num_lines=4`` and a ``desc`` prefix, ``TEST_RANGE`` yields three records."""
    expected_log = [
        ('nncf', 20, 'desc █████ | 1 / 3'),
        ('nncf', 20, 'desc ██████████ | 2 / 3'),
        ('nncf', 20, 'desc ████████████████ | 3 / 3'),
    ]

    progress_bar = ProgressBar(TEST_RANGE, desc='desc', num_lines=4)
    for _item in progress_bar:
        pass

    assert caplog.record_tuples == expected_log
def _run_model_inference(self, data_loader, num_init_steps, device):
    """Feed batches from ``data_loader`` through ``self._infer_batch``.

    :param data_loader: iterable of loaded items that also exposes
        ``get_inputs(loaded_item)`` to turn an item into inference arguments.
    :param num_init_steps: maximum number of batches to process, or ``None``
        to consume the whole loader; also passed to the progress bar as ``total``.
    :param device: forwarded unchanged to ``self._infer_batch``.
    """
    tracked_batches = ProgressBar(
        enumerate(data_loader),
        total=num_init_steps,
        desc=self.progressbar_description,
    )
    for step, batch in tracked_batches:
        budget_exhausted = num_init_steps is not None and step >= num_init_steps
        if budget_exhausted:
            break
        self._infer_batch(data_loader.get_inputs(batch), device)
def _run_model_inference(self, data_loader, num_init_steps, device):
    """Run a warm-up pass with overridden momentum, then the main inference pass.

    Steps, in order:
      1. Save each targeted module's ``momentum`` into
         ``self.original_momenta_values``.
      2. Set that ``momentum`` to ``self.momentum_bn_forget``.
      3. Infer up to ``self.num_bn_forget_steps`` batches (all if ``None``).
      4. Restore the saved ``momentum`` values.
      5. Infer up to ``num_init_steps`` batches (all if ``None``) under a
         progress bar.

    :param data_loader: iterable of loaded items that also exposes
        ``get_inputs(loaded_item)``; it is iterated twice (once per pass).
    :param num_init_steps: batch limit for the second pass, or ``None`` for
        no limit; also passed to the progress bar as ``total``.
    :param device: forwarded unchanged to ``self._infer_batch``.
    """
    num_bn_forget_steps = self.num_bn_forget_steps

    # Helpers applied to modules through self._apply_to_batchnorms
    # (presumably restricts them to batch-norm layers -- confirm in that helper).
    def set_bn_momentum(module, momentum_value):
        module.momentum = momentum_value

    def save_original_bn_momenta(module):
        self.original_momenta_values[module] = module.momentum

    def restore_original_bn_momenta(module):
        module.momentum = self.original_momenta_values[module]

    # NOTE(review): both loops are assumed to run inside this context
    # (presumably it keeps the model in training mode) -- confirm the extent.
    with training_mode_switcher(self.model, is_training=True):
        # Remember the current momenta before overriding them.
        self.model.apply(
            self._apply_to_batchnorms(save_original_bn_momenta))
        self.model.apply(
            self._apply_to_batchnorms(
                partial(set_bn_momentum, momentum_value=self.momentum_bn_forget)))

        # Warm-up pass: runs with the overridden momentum, without a progress bar.
        for i, loaded_item in enumerate(data_loader):
            if num_bn_forget_steps is not None and i >= num_bn_forget_steps:
                break
            args_kwargs_tuple = data_loader.get_inputs(loaded_item)
            self._infer_batch(args_kwargs_tuple, device)

        # Put the saved momenta back before the main pass.
        self.model.apply(
            self._apply_to_batchnorms(restore_original_bn_momenta))

        # Main pass: restarts iteration over data_loader from the beginning.
        for i, loaded_item in ProgressBar(
                enumerate(data_loader),
                total=num_init_steps,
                desc=self.progressbar_description):
            if num_init_steps is not None and i >= num_init_steps:
                break
            args_kwargs_tuple = data_loader.get_inputs(loaded_item)
            self._infer_batch(args_kwargs_tuple, device)
def test_type_error_happens_for_iteration_none(caplog):
    """Wrapping ``None`` and iterating over the result must raise ``TypeError``."""
    with pytest.raises(TypeError):
        # Construction stays inside the context: the error may come from either step.
        progress_bar = ProgressBar(None)
        for _item in progress_bar:
            pass
def test_can_iterate_over_empty_iterable(caplog):
    """Iterating over an empty list must not produce any log records."""
    progress_bar = ProgressBar([])
    for _item in progress_bar:
        pass

    expected_log = []
    assert caplog.record_tuples == expected_log
def test_can_limit_number_of_iterations(caplog):
    """With ``total=2`` only the records for 1 / 2 and 2 / 2 are expected."""
    expected_log = [
        ('nncf', 20, ' ████████ | 1 / 2'),
        ('nncf', 20, ' ████████████████ | 2 / 2'),
    ]

    progress_bar = ProgressBar(TEST_RANGE, total=2)
    for _item in progress_bar:
        pass

    assert caplog.record_tuples == expected_log