Пример #1
0
def test_invalid_num_lines_leads_to_disabling_progress_bar(num_lines, _nncf_caplog):
    """Check that an invalid ``num_lines`` logs exactly one WARNING record.

    ``num_lines`` is presumably supplied by a pytest parametrization that is
    not visible in this snippet.
    """
    progress = ProgressBar(TEST_RANGE, num_lines=num_lines)
    for _item in progress:
        continue

    captured = _nncf_caplog.records
    assert len(captured) == 1
    (only_record,) = captured
    assert only_record.levelno == logging.WARNING
Пример #2
0
def test_invalid_total_leads_to_disabling_progress_bar(total, _nncf_caplog):
    """Check that an invalid ``total`` logs exactly one WARNING record.

    ``total`` is presumably supplied by a pytest parametrization that is not
    visible in this snippet.
    """
    progress = ProgressBar(TEST_RANGE, total=total)
    for _item in progress:
        continue

    captured = _nncf_caplog.records
    assert len(captured) == 1
    (only_record,) = captured
    assert only_record.levelno == logging.WARNING
Пример #3
0
def test_can_iterate_with_warning_for_iterable_without_len(_nncf_caplog):
    """Check iteration over an ``enumerate`` object, which has no ``len``.

    The loop must complete and exactly one WARNING record must be logged.
    """
    unsized = enumerate(TEST_RANGE)
    for _pair in ProgressBar(unsized):
        continue

    captured = _nncf_caplog.records
    assert len(captured) == 1
    (only_record,) = captured
    assert only_record.levelno == logging.WARNING
Пример #4
0
def test_can_limit_number_of_iterations(_nncf_caplog):
    """Check that ``total=2`` yields exactly two INFO progress lines, ending at ``2 / 2``."""
    expected_lines = [
        ('nncf', 20, ' ████████          | 1 / 2'),
        ('nncf', 20, ' ████████████████  | 2 / 2'),
    ]

    progress = ProgressBar(TEST_RANGE, total=2)
    for _item in progress:
        continue

    assert _nncf_caplog.record_tuples == expected_lines
Пример #5
0
def test_can_print_by_default__with_enumerate_and_total(_nncf_caplog):
    """Check an ``enumerate`` iterable given an explicit float ``total`` (3.0).

    Every step must be logged as an integer count: ``1 / 3`` through ``3 / 3``.
    """
    expected_lines = [
        ('nncf', 20, ' █████             | 1 / 3'),
        ('nncf', 20, ' ██████████        | 2 / 3'),
        ('nncf', 20, ' ████████████████  | 3 / 3'),
    ]

    progress = ProgressBar(enumerate(TEST_RANGE), total=3.0)
    for _pair in progress:
        continue

    assert _nncf_caplog.record_tuples == expected_lines
Пример #6
0
def test_can_print_by_default(_nncf_caplog):
    """Check that default settings log one INFO progress line per element of ``TEST_RANGE``."""
    expected_lines = [
        ('nncf', 20, ' █████             | 1 / 3'),
        ('nncf', 20, ' ██████████        | 2 / 3'),
        ('nncf', 20, ' ████████████████  | 3 / 3'),
    ]

    progress = ProgressBar(TEST_RANGE)
    for _item in progress:
        continue

    assert _nncf_caplog.record_tuples == expected_lines
Пример #7
0
def test_can_print_collections_bigger_than_num_lines(_nncf_caplog):
    """Check ``num_lines=3`` over 11 items.

    Only three progress lines must be logged, at steps 5, 10 and 11.
    """
    expected_lines = [
        ('nncf', 20, ' ███████           | 5 / 11'),
        ('nncf', 20, ' ██████████████    | 10 / 11'),
        ('nncf', 20, ' ████████████████  | 11 / 11'),
    ]

    progress = ProgressBar(range(11), num_lines=3)
    for _item in progress:
        continue

    assert _nncf_caplog.record_tuples == expected_lines
Пример #8
0
def test_can_print_collections_less_than_num_lines(_nncf_caplog):
    """Check a collection shorter than ``num_lines``.

    Every step must be logged, and each line must be prefixed by ``desc``.
    """
    expected_lines = [
        ('nncf', 20, 'desc █████             | 1 / 3'),
        ('nncf', 20, 'desc ██████████        | 2 / 3'),
        ('nncf', 20, 'desc ████████████████  | 3 / 3'),
    ]

    progress = ProgressBar(TEST_RANGE, desc='desc', num_lines=4)
    for _item in progress:
        continue

    assert _nncf_caplog.record_tuples == expected_lines
Пример #9
0
 def _run_model_inference(self, data_loader, num_init_steps, device):
     """
     Feed batches from ``data_loader`` through ``self._infer_batch``, showing progress.

     Iteration is wrapped in a ProgressBar labelled with
     ``self.progressbar_description``. When ``num_init_steps`` is not None, at
     most that many batches are processed. Otherwise the whole loader is consumed.

     :param data_loader: Iterable of loaded items. Its ``get_inputs`` method
         converts each item into the argument passed to ``self._infer_batch``.
     :param num_init_steps: Maximum number of batches to process, or None for no limit.
     :param device: Forwarded unchanged to ``self._infer_batch``.
     """
     has_limit = num_init_steps is not None
     batches = ProgressBar(
         enumerate(data_loader),
         total=num_init_steps,
         desc=self.progressbar_description,
     )
     for step, item in batches:
         # Stop before processing batch index ``num_init_steps``,
         # so exactly ``num_init_steps`` batches are inferred.
         if has_limit and step >= num_init_steps:
             break
         batch_inputs = data_loader.get_inputs(item)
         self._infer_batch(batch_inputs, device)
Пример #10
0
def test_can_print_with_another_logger(_nncf_caplog):
    """Check that progress lines go to the logger passed as ``logger``, not to ``nncf``."""
    custom_logger = logging.getLogger("test")
    custom_logger.setLevel(logging.INFO)

    progress = ProgressBar(enumerate(TEST_RANGE), logger=custom_logger, total=2)
    for _pair in progress:
        continue

    expected_lines = [
        ('test', 20, ' ████████          | 1 / 2'),
        ('test', 20, ' ████████████████  | 2 / 2'),
    ]
    assert _nncf_caplog.record_tuples == expected_lines
Пример #11
0
    def run(self, model: tf.keras.Model) -> None:
        """
        Adapt the batch-norm statistics of ``model`` with forward passes in training mode.

        Up to ``self._num_bn_adaptation_steps`` elements are taken from
        ``self._data_loader``. Each element is unpacked as a pair, and only its
        first item is passed to ``model(..., training=True)``. All passes run
        inside a ``BNTrainingStateSwitcher(model)`` context.

        :param model: A model for which the algorithm will be applied.
        :raises ValueError: If ``self._device`` is set. The TF implementation
            does not support switching devices.
        """
        if self._device is not None:
            message = ('TF implementation of batchnorm adaptation algorithm '
                       'does not support switch of devices. Model initial device '
                       'is used by default for batchnorm adaptation.')
            raise ValueError(message)

        with BNTrainingStateSwitcher(model):
            steps = self._num_bn_adaptation_steps
            limited_batches = islice(self._data_loader, steps)
            progress = ProgressBar(limited_batches,
                                   total=steps,
                                   desc='BatchNorm statistics adaptation')
            for inputs, _unused in progress:
                model(inputs, training=True)
Пример #12
0
    def run(self, model: tf.keras.Model) -> None:
        """
        Initialize quantizer ranges of ``model`` from statistics collected on ``self.dataset``.

        Steps:

        1. Register statistic collection for every ``FakeQuantize`` layer and
           every ``NNCFWrapper`` layer in ``model.layers``.
        2. Run up to ``self.num_steps`` batches through the model with
           ``training=False``.
        3. Convert the collected statistics to min/max ranges and apply them to
           the corresponding quantizers, which are then enabled.
        4. Remove the registration handles.
        5. Run one more batch, if the dataset yields any, through the model.

        :param model: Model whose quantizers are range-initialized in place.
        """
        layer_statistics = []
        op_statistics = []
        handles = []

        def _squeeze_unless_vector(values):
            # Rank-1 tensors are kept as-is. Squeezing a shape-[1] vector would
            # collapse it to a scalar; every other rank is squeezed.
            if len(values.shape) != 1:
                return tf.squeeze(values)
            return values

        # The handles are removed in ``finally``, so a failure during
        # registration, collection or range application does not leave them
        # attached to the model. (Presumably these are statistic-collection
        # hooks; confirm against _register_*_statistics.)
        try:
            for layer in model.layers:
                if isinstance(layer, FakeQuantize):
                    self._register_layer_statistics(layer, layer_statistics,
                                                    handles)
                elif isinstance(layer, NNCFWrapper):
                    self._register_op_statistics(layer, op_statistics, handles)

            for (x, _) in ProgressBar(islice(self.dataset, self.num_steps),
                                      total=self.num_steps,
                                      desc='Collecting tensor statistics/data'):
                model(x, training=False)

            for layer, collector in layer_statistics:
                target_stat = collector.get_statistics()
                minmax_stats = tf_convert_stat_to_min_max_tensor_stat(target_stat)
                layer.apply_range_initialization(
                    tf.squeeze(minmax_stats.min_values),
                    tf.squeeze(minmax_stats.max_values))
                layer.enabled = True

            for layer, op_name, op, collector in op_statistics:
                weights = layer.get_operation_weights(op_name)
                target_stat = collector.get_statistics()
                minmax_stats = tf_convert_stat_to_min_max_tensor_stat(target_stat)
                min_values = _squeeze_unless_vector(minmax_stats.min_values)
                max_values = _squeeze_unless_vector(minmax_stats.max_values)
                op.apply_range_initialization(weights, min_values, max_values)
                op.enabled = True
        finally:
            for handle in handles:
                handle.remove()

        # One inference pass on the first available batch after the handles
        # are removed; an empty dataset skips it.
        for x, _ in self.dataset:
            model(x, training=False)
            break
Пример #13
0
def test_type_error_happens_for_iteration_none(_nncf_caplog):
    """Check that passing None as the iterable raises TypeError.

    The error may come from construction or from iteration; both happen
    inside the ``pytest.raises`` context.
    """
    with pytest.raises(TypeError):
        progress = ProgressBar(None)
        for _item in progress:
            continue
Пример #14
0
def test_can_iterate_over_empty_iterable(caplog, _nncf_caplog):
    """Check that iterating a ProgressBar over an empty list logs nothing.

    Every other ProgressBar test here checks NNCF log output through the
    ``_nncf_caplog`` fixture. The plain ``caplog`` fixture presumably does not
    capture the ``nncf`` logger on its own (assumption: confirm against the
    fixture definition). If so, asserting only on ``caplog`` would pass
    vacuously, so both fixtures are checked.
    """
    for _ in ProgressBar([]):
        pass

    assert caplog.record_tuples == []
    assert _nncf_caplog.record_tuples == []