Exemplo n.º 1
0
def test_invalid_num_lines_leads_to_disabling_progress_bar(num_lines, caplog):
    """Check that an invalid ``num_lines`` yields exactly one WARNING record.

    ``num_lines`` is supplied from outside this function (presumably a
    parametrization of invalid values -- confirm against the decorator).
    """
    progress_bar = ProgressBar(TEST_RANGE, num_lines=num_lines)
    # Drain the iterable; only the logging side effect is inspected.
    for _item in progress_bar:
        pass

    captured = caplog.records
    assert len(captured) == 1
    assert captured[0].levelno == logging.WARNING
Exemplo n.º 2
0
def test_invalid_total_leads_to_disabling_progress_bar(total, caplog):
    """Check that an invalid ``total`` yields exactly one WARNING record.

    ``total`` is supplied from outside this function (presumably a
    parametrization of invalid values -- confirm against the decorator).
    """
    progress_bar = ProgressBar(TEST_RANGE, total=total)
    # Drain the iterable; only the logging side effect is inspected.
    for _item in progress_bar:
        pass

    captured = caplog.records
    assert len(captured) == 1
    assert captured[0].levelno == logging.WARNING
Exemplo n.º 3
0
def test_can_print_by_default__with_enumerate_and_total(caplog):
    """Check the log output for an ``enumerate`` wrapper with ``total=3.0``.

    Each expected tuple is ``(logger name, level, message)``.
    """
    # Level 20 is logging.INFO.
    expected_records = [
        ('nncf', 20, ' █████             | 1 / 3'),
        ('nncf', 20, ' ██████████        | 2 / 3'),
        ('nncf', 20, ' ████████████████  | 3 / 3'),
    ]

    progress_bar = ProgressBar(enumerate(TEST_RANGE), total=3.0)
    for _item in progress_bar:
        pass

    assert caplog.record_tuples == expected_records
Exemplo n.º 4
0
def test_can_print_by_default(caplog):
    """Check the log output of a ``ProgressBar`` built with default options.

    Each expected tuple is ``(logger name, level, message)``.
    """
    # Level 20 is logging.INFO.
    expected_records = [
        ('nncf', 20, ' █████             | 1 / 3'),
        ('nncf', 20, ' ██████████        | 2 / 3'),
        ('nncf', 20, ' ████████████████  | 3 / 3'),
    ]

    progress_bar = ProgressBar(TEST_RANGE)
    for _item in progress_bar:
        pass

    assert caplog.record_tuples == expected_records
Exemplo n.º 5
0
def test_can_iterate_with_warning_for_iterable_without_len(caplog):
    """Check that wrapping an ``enumerate`` object without ``total`` still
    iterates and leaves exactly one WARNING record in the log."""
    progress_bar = ProgressBar(enumerate(TEST_RANGE))
    # Iteration must complete without raising; the items are irrelevant.
    for _item in progress_bar:
        pass

    captured = caplog.records
    assert len(captured) == 1
    assert captured[0].levelno == logging.WARNING
Exemplo n.º 6
0
def test_can_print_collections_bigger_than_num_lines(caplog):
    """Check the log output for 11 items with ``num_lines=3``.

    The expected log has records at 5, 10 and 11 processed items; each
    tuple is ``(logger name, level, message)``.
    """
    # Level 20 is logging.INFO.
    expected_records = [
        ('nncf', 20, ' ███████           | 5 / 11'),
        ('nncf', 20, ' ██████████████    | 10 / 11'),
        ('nncf', 20, ' ████████████████  | 11 / 11'),
    ]

    progress_bar = ProgressBar(range(11), num_lines=3)
    for _item in progress_bar:
        pass

    assert caplog.record_tuples == expected_records
Exemplo n.º 7
0
def test_can_print_with_another_logger(caplog):
    """Check that records go to a logger passed via ``logger=``.

    A logger named ``"test"`` is set to INFO and handed to the progress bar;
    the captured tuples are expected to carry that logger's name.
    """
    custom_logger = logging.getLogger("test")
    custom_logger.setLevel(logging.INFO)

    # Level 20 is logging.INFO.
    expected_records = [
        ('test', 20, ' ████████          | 1 / 2'),
        ('test', 20, ' ████████████████  | 2 / 2'),
    ]

    progress_bar = ProgressBar(enumerate(TEST_RANGE), logger=custom_logger, total=2)
    for _item in progress_bar:
        pass

    assert caplog.record_tuples == expected_records
Exemplo n.º 8
0
def test_can_print_collections_less_than_num_lines(caplog):
    """Check the log output when ``num_lines=4`` is used with ``TEST_RANGE``.

    The expected log has three records, each message prefixed with the
    ``desc`` text; every tuple is ``(logger name, level, message)``.
    """
    # Level 20 is logging.INFO.
    expected_records = [
        ('nncf', 20, 'desc █████             | 1 / 3'),
        ('nncf', 20, 'desc ██████████        | 2 / 3'),
        ('nncf', 20, 'desc ████████████████  | 3 / 3'),
    ]

    progress_bar = ProgressBar(TEST_RANGE, desc='desc', num_lines=4)
    for _item in progress_bar:
        pass

    assert caplog.record_tuples == expected_records
Exemplo n.º 9
0
 def _run_model_inference(self, data_loader, num_init_steps, device):
     """Run batches from ``data_loader`` through ``self._infer_batch``.

     Progress is reported through a ``ProgressBar`` labelled with
     ``self.progressbar_description``.

     :param data_loader: iterable of loaded items; must also provide
         ``get_inputs(loaded_item)``.
     :param num_init_steps: maximum number of batches to process, or
         ``None`` for no explicit limit.
     :param device: forwarded unchanged to ``self._infer_batch``.
     """
     indexed_batches = ProgressBar(
         enumerate(data_loader),
         total=num_init_steps,
         desc=self.progressbar_description,
     )
     for step, batch in indexed_batches:
         # Explicit cut-off once the requested number of steps is reached.
         if num_init_steps is not None and step >= num_init_steps:
             break
         self._infer_batch(data_loader.get_inputs(batch), device)
Exemplo n.º 10
0
    def _run_model_inference(self, data_loader, num_init_steps, device):
        """Run a momentum-override pass, then the main inference pass.

        Inside ``training_mode_switcher(self.model, is_training=True)``:

        1. the current ``momentum`` of every module selected by
           ``self._apply_to_batchnorms`` is stored in
           ``self.original_momenta_values``;
        2. those modules get ``momentum = self.momentum_bn_forget``;
        3. up to ``self.num_bn_forget_steps`` batches are inferred;
        4. the stored momenta are written back;
        5. up to ``num_init_steps`` batches are inferred with a
           ``ProgressBar`` labelled ``self.progressbar_description``.

        :param data_loader: iterable of loaded items; must also provide
            ``get_inputs(loaded_item)``.
        :param num_init_steps: batch limit for the final pass, or ``None``
            for no explicit limit.
        :param device: forwarded unchanged to ``self._infer_batch``.
        """
        forget_steps_limit = self.num_bn_forget_steps

        def assign_momentum(module, momentum_value):
            module.momentum = momentum_value

        def remember_momentum(module):
            self.original_momenta_values[module] = module.momentum

        def recall_momentum(module):
            module.momentum = self.original_momenta_values[module]

        def infer_batches(indexed_batches, step_limit):
            # Shared loop body for both passes: stop at the limit (if any),
            # otherwise unpack the inputs and run one inference step.
            for step, batch in indexed_batches:
                if step_limit is not None and step >= step_limit:
                    break
                self._infer_batch(data_loader.get_inputs(batch), device)

        with training_mode_switcher(self.model, is_training=True):
            self.model.apply(self._apply_to_batchnorms(remember_momentum))
            override_momentum = partial(
                assign_momentum, momentum_value=self.momentum_bn_forget)
            self.model.apply(self._apply_to_batchnorms(override_momentum))

            # Pass with overridden momentum, no progress reporting.
            infer_batches(enumerate(data_loader), forget_steps_limit)

            self.model.apply(self._apply_to_batchnorms(recall_momentum))

            # Main pass with the original momenta restored.
            infer_batches(
                ProgressBar(
                    enumerate(data_loader),
                    total=num_init_steps,
                    desc=self.progressbar_description),
                num_init_steps)
Exemplo n.º 11
0
def test_type_error_happens_for_iteration_none(caplog):
    """Check that building and iterating ``ProgressBar(None)`` raises
    ``TypeError``."""
    with pytest.raises(TypeError):
        # Construction and iteration both stay inside the context so the
        # error is caught wherever it is raised.
        for _item in ProgressBar(None):
            pass
Exemplo n.º 12
0
def test_can_iterate_over_empty_iterable(caplog):
    """Check that iterating an empty list completes and logs nothing."""
    expected_records = []

    progress_bar = ProgressBar([])
    for _item in progress_bar:
        pass

    assert caplog.record_tuples == expected_records
Exemplo n.º 13
0
def test_can_limit_number_of_iterations(caplog):
    """Check that ``total=2`` yields log records for two steps only.

    Each expected tuple is ``(logger name, level, message)``.
    """
    # Level 20 is logging.INFO.
    expected_records = [
        ('nncf', 20, ' ████████          | 1 / 2'),
        ('nncf', 20, ' ████████████████  | 2 / 2'),
    ]

    progress_bar = ProgressBar(TEST_RANGE, total=2)
    for _item in progress_bar:
        pass

    assert caplog.record_tuples == expected_records