def test_invalid_num_lines_leads_to_disabling_progress_bar(num_lines, _nncf_caplog):
    """A bad ``num_lines`` value disables the bar and emits exactly one WARNING."""
    for _ in ProgressBar(TEST_RANGE, num_lines=num_lines):
        pass

    records = _nncf_caplog.records
    assert len(records) == 1
    assert records[0].levelno == logging.WARNING
def test_invalid_total_leads_to_disabling_progress_bar(total, _nncf_caplog):
    """A bad ``total`` value disables the bar and emits exactly one WARNING."""
    for _ in ProgressBar(TEST_RANGE, total=total):
        pass

    records = _nncf_caplog.records
    assert len(records) == 1
    assert records[0].levelno == logging.WARNING
def test_can_iterate_with_warning_for_iterable_without_len(_nncf_caplog):
    """An iterable with no ``len()`` (e.g. ``enumerate``) still iterates, warning once."""
    for _ in ProgressBar(enumerate(TEST_RANGE)):
        pass

    records = _nncf_caplog.records
    assert len(records) == 1
    assert records[0].levelno == logging.WARNING
def test_can_limit_number_of_iterations(_nncf_caplog):
    """``total`` caps the progress output at the given number of steps."""
    for _ in ProgressBar(TEST_RANGE, total=2):
        pass

    expected = [
        ('nncf', 20, ' ████████ | 1 / 2'),
        ('nncf', 20, ' ████████████████ | 2 / 2'),
    ]
    assert _nncf_caplog.record_tuples == expected
def test_can_print_by_default__with_enumerate_and_total(_nncf_caplog):
    """A float ``total`` (3.0) with an ``enumerate``-wrapped iterable prints integer steps."""
    for _ in ProgressBar(enumerate(TEST_RANGE), total=3.0):
        pass

    expected = [
        ('nncf', 20, ' █████ | 1 / 3'),
        ('nncf', 20, ' ██████████ | 2 / 3'),
        ('nncf', 20, ' ████████████████ | 3 / 3'),
    ]
    assert _nncf_caplog.record_tuples == expected
def test_can_print_by_default(_nncf_caplog):
    """With no extra arguments, one INFO line is logged per iteration."""
    for _ in ProgressBar(TEST_RANGE):
        pass

    expected = [
        ('nncf', 20, ' █████ | 1 / 3'),
        ('nncf', 20, ' ██████████ | 2 / 3'),
        ('nncf', 20, ' ████████████████ | 3 / 3'),
    ]
    assert _nncf_caplog.record_tuples == expected
def test_can_print_collections_bigger_than_num_lines(_nncf_caplog):
    """When the collection exceeds ``num_lines``, output is subsampled to that many lines."""
    for _ in ProgressBar(range(11), num_lines=3):
        pass

    expected = [
        ('nncf', 20, ' ███████ | 5 / 11'),
        ('nncf', 20, ' ██████████████ | 10 / 11'),
        ('nncf', 20, ' ████████████████ | 11 / 11'),
    ]
    assert _nncf_caplog.record_tuples == expected
def test_can_print_collections_less_than_num_lines(_nncf_caplog):
    """A collection shorter than ``num_lines`` prints every step, prefixed with ``desc``."""
    desc = 'desc'
    for _ in ProgressBar(TEST_RANGE, desc=desc, num_lines=4):
        pass

    expected = [
        ('nncf', 20, 'desc █████ | 1 / 3'),
        ('nncf', 20, 'desc ██████████ | 2 / 3'),
        ('nncf', 20, 'desc ████████████████ | 3 / 3'),
    ]
    assert _nncf_caplog.record_tuples == expected
def _run_model_inference(self, data_loader, num_init_steps, device):
    """Feed batches from *data_loader* through the model, showing progress.

    Stops after ``num_init_steps`` batches when that limit is not ``None``;
    otherwise consumes the whole loader.
    """
    progress = ProgressBar(
        enumerate(data_loader),
        total=num_init_steps,
        desc=self.progressbar_description,
    )
    for step, batch in progress:
        # Honor the step budget; a None budget means "iterate everything".
        if num_init_steps is not None and step >= num_init_steps:
            break
        args_kwargs_tuple = data_loader.get_inputs(batch)
        self._infer_batch(args_kwargs_tuple, device)
def test_can_print_with_another_logger(_nncf_caplog):
    """Progress lines go to a caller-supplied logger instead of the default 'nncf' one."""
    name = "test"
    custom_logger = logging.getLogger(name)
    custom_logger.setLevel(logging.INFO)

    for _ in ProgressBar(enumerate(TEST_RANGE), logger=custom_logger, total=2):
        pass

    expected = [
        ('test', 20, ' ████████ | 1 / 2'),
        ('test', 20, ' ████████████████ | 2 / 2'),
    ]
    assert _nncf_caplog.record_tuples == expected
def run(self, model: tf.keras.Model) -> None:
    """
    Runs the batch-norm statistics adaptation algorithm.

    :param model: A model for which the algorithm will be applied.
    :raises ValueError: if a target device was configured — the TF backend
        always uses the model's initial device.
    """
    # Device switching is unsupported on the TF backend; fail loudly.
    if self._device is not None:
        raise ValueError(
            'TF implementation of batchnorm adaptation algorithm '
            'does not support switch of devices. Model initial device '
            'is used by default for batchnorm adaptation.')

    # Put BN layers into training mode for the duration of the forward passes
    # so their running statistics get updated.
    with BNTrainingStateSwitcher(model):
        bounded_loader = islice(self._data_loader, self._num_bn_adaptation_steps)
        progress = ProgressBar(
            bounded_loader,
            total=self._num_bn_adaptation_steps,
            desc='BatchNorm statistics adaptation')
        for inputs, _ in progress:
            model(inputs, training=True)
def run(self, model: tf.keras.Model) -> None:
    """Collect tensor statistics over ``self.dataset`` and initialize quantizer ranges.

    Registers statistic-collection hooks on the model's FakeQuantize and
    NNCFWrapper layers, runs ``self.num_steps`` inference batches to gather
    statistics, applies the resulting min/max ranges, then removes the hooks.

    :param model: The model whose quantization ranges are initialized.
    """
    layer_statistics = []  # (layer, collector) pairs for FakeQuantize layers
    op_statistics = []     # (layer, op_name, op, collector) tuples for wrapped ops
    handles = []           # hook handles to remove once collection is done
    for layer in model.layers:
        if isinstance(layer, FakeQuantize):
            self._register_layer_statistics(layer, layer_statistics, handles)
        elif isinstance(layer, NNCFWrapper):
            self._register_op_statistics(layer, op_statistics, handles)

    # Drive inference (no training) so the registered collectors observe activations.
    for (x, _) in ProgressBar(islice(self.dataset, self.num_steps),
                              total=self.num_steps,
                              desc='Collecting tensor statistics/data'):
        model(x, training=False)

    # Apply collected ranges to FakeQuantize layers and enable them.
    for layer, collector in layer_statistics:
        target_stat = collector.get_statistics()
        minmax_stats = tf_convert_stat_to_min_max_tensor_stat(target_stat)
        layer.apply_range_initialization(tf.squeeze(minmax_stats.min_values),
                                         tf.squeeze(minmax_stats.max_values))
        layer.enabled = True

    # Apply collected ranges to per-op quantizers; squeeze only when the stats
    # are not already rank-1 (per-channel values must keep their vector shape).
    for layer, op_name, op, collector in op_statistics:
        weights = layer.get_operation_weights(op_name)
        target_stat = collector.get_statistics()
        minmax_stats = tf_convert_stat_to_min_max_tensor_stat(target_stat)
        min_values = minmax_stats.min_values
        if len(min_values.shape) != 1:
            min_values = tf.squeeze(min_values)
        max_values = minmax_stats.max_values
        if len(max_values.shape) != 1:
            max_values = tf.squeeze(max_values)
        op.apply_range_initialization(weights, min_values, max_values)
        op.enabled = True

    # Detach all statistic-collection hooks so later inference is unaffected.
    for handle in handles:
        handle.remove()

    # One extra forward pass after hook removal — presumably to rebuild/trace
    # the model with quantizers enabled; TODO(review): confirm intent.
    for x, _ in self.dataset:
        model(x, training=False)
        break
def test_type_error_happens_for_iteration_none(_nncf_caplog):
    """Iterating a ProgressBar over ``None`` raises TypeError."""
    with pytest.raises(TypeError):
        for _ in ProgressBar(None):
            pass
def test_can_iterate_over_empty_iterable(caplog):
    """An empty iterable produces no log output at all.

    NOTE(review): this test uses pytest's builtin ``caplog`` while sibling
    tests use ``_nncf_caplog`` — confirm the difference is intentional.
    """
    for _ in ProgressBar([]):
        pass
    assert caplog.record_tuples == []