Exemplo n.º 1
0
def test_single_grad_check():
    """`Trainer.test_run` passes for a healthy model and rejects a model
    whose gradients are all zero."""
    if sys.platform.startswith('win'):
        pytest.skip(
            'this doctest does not work on Windows, '
            'training is not possible on Windows due to symlinks being unavailable'
        )
    tr_dataset, dt_dataset = get_datasets()

    # Sanity check: a regular model survives the test run.
    with tempfile.TemporaryDirectory() as storage:
        trainer = pt.Trainer(
            Model(),
            optimizer=pt.optimizer.Adam(),
            storage_dir=Path(storage),
            stop_trigger=(2, 'epoch'),
        )
        trainer.test_run(tr_dataset, dt_dataset)

    # A model that produces all-zero gradients must fail the check.
    with tempfile.TemporaryDirectory() as storage:
        trainer = pt.Trainer(
            ZeroGradModel(),
            optimizer=pt.optimizer.Adam(),
            storage_dir=Path(storage),
            stop_trigger=(2, 'epoch'),
        )
        # AssertionError: The loss of the model did not change between two validations.
        with pytest.raises(AssertionError):
            trainer.test_run(tr_dataset, dt_dataset)
Exemplo n.º 2
0
def test_backoff():
    """One learning-rate back-off restores the best checkpoint (dropping the
    newer ones), continues with a 10x smaller LR and finally early-stops."""
    if sys.platform.startswith('win'):
        pytest.skip('this doctest does not work on Windows, '
                    'training is not possible on Windows due to symlinks being unavailable')
    ds = [0]
    with tempfile.TemporaryDirectory() as tmp_dir:
        optimizer = pt.optimizer.Adam()
        # Validation losses: improving for four epochs, then stagnating.
        model = DummyModel([3, 2, 1, 0, 1, 1, 1, 1, 1, 1], tmp_dir, optimizer)
        trainer = pt.Trainer(
            model, tmp_dir, optimizer, stop_trigger=(10, 'epoch')
        )
        trainer.register_validation_hook(
            ds, max_checkpoints=None,
            n_back_off=1, back_off_patience=2, early_stopping_patience=2
        )
        trainer.train(ds)

    def ckpts(n):
        # Expected directory content with the first n numbered checkpoints.
        return (
            [f'ckpt_{i}.pth' for i in range(n)]
            + ['ckpt_best_loss.pth', 'ckpt_latest.pth']
        )

    assert model.ckpt_log == [
        [],
        ckpts(1),
        ckpts(2),
        ckpts(3),
        ckpts(4),
        ckpts(5),
        ckpts(6),
        # The back-off dropped ckpt_4 and ckpt_5 when restoring the best state.
        ckpts(4),
        ckpts(5),
        ckpts(6),
    ]
    # Seven epochs at the initial LR, three at the backed-off LR.
    assert model.lr_log == [0.001] * 7 + [0.0001] * 3
Exemplo n.º 3
0
def test_validation_hook_modify_summary_training_flag():
    """`modify_summary` must run with the same `self.training` flag that was
    active while the summary's scalars were collected."""

    class Model(DummyModel):
        def review(self, example, output):
            summary = super().review(example, output)
            # Record the current mode so modify_summary can verify it later.
            summary.setdefault('scalars', {})['training'] = self.training
            return summary

        def modify_summary(self, summary):
            collected = summary['scalars']['training']
            if len(collected) == 0:
                # The first validation triggers to write the training summary,
                # which is empty.
                assert self.training is True
            else:
                assert set(collected) == {self.training}
            return super().modify_summary(summary)

    ds_train = [0., 1., 2.]
    ds_valid = [0., 1.]
    with tempfile.TemporaryDirectory() as exp_dir:
        optimizer = pt.optimizer.Adam()
        model = Model([1, 2, 3, 4, 5] * 10, exp_dir, optimizer)
        trainer = pt.Trainer(
            model, exp_dir, optimizer, stop_trigger=(10, 'epoch'),
            summary_trigger=(1, 'epoch')
        )
        trainer.register_validation_hook(ds_valid)
        trainer.train(ds_train)
Exemplo n.º 4
0
def test_validation_hook_create_snapshot_flag():
    """`model.create_snapshot` must be True exactly for the first example of
    each summary interval (training) and of each validation run."""
    if sys.platform.startswith('win'):
        pytest.skip('this doctest does not work on Windows, '
                    'training is not possible on Windows due to symlinks being unavailable')

    class Model(DummyModel):

        def __init__(self, validation_losses, exp_dir, optimizer):
            super().__init__(validation_losses, exp_dir, optimizer)
            # One entry per review() call, split by train/validation mode.
            self.train_create_snapshot_log = []
            self.validation_create_snapshot_log = []

        def review(self, example, output):
            log = (
                self.train_create_snapshot_log
                if self.training
                else self.validation_create_snapshot_log
            )
            log.append(self.create_snapshot)
            return super().review(example, output)

    ds_train = [0., 1., 2.]
    ds_valid = [0., 1.]
    with tempfile.TemporaryDirectory() as exp_dir:
        optimizer = pt.optimizer.Adam()
        model = Model([1, 2, 3, 4, 5] * 10, exp_dir, optimizer)
        trainer = pt.Trainer(
            model, exp_dir, optimizer, stop_trigger=(10, 'epoch'),
            summary_trigger=(1, 'epoch')
        )
        trainer.register_validation_hook(ds_valid)
        trainer.train(ds_train)
    # Ten epochs of three train examples; only the first one per epoch snapshots.
    assert model.train_create_snapshot_log == [True, False, False] * 10
    # The validation is executed num_epochs + 1 times, so 11
    assert model.validation_create_snapshot_log == [True, False] * 11
Exemplo n.º 5
0
def test_validation_hook_create_snapshot_flag():
    """Snapshot flag bookkeeping: True only for the first example of every
    summary interval (train mode) and of every validation run."""

    class Model(DummyModel):

        def __init__(self, validation_losses, exp_dir, optimizer):
            super().__init__(validation_losses, exp_dir, optimizer)
            # Collected per review() call, separately for each mode.
            self.train_create_snapshot_log = []
            self.validation_create_snapshot_log = []

        def review(self, example, output):
            flag = self.create_snapshot
            if self.training:
                self.train_create_snapshot_log.append(flag)
            else:
                self.validation_create_snapshot_log.append(flag)
            return super().review(example, output)

    ds_train = [0., 1., 2.]
    ds_valid = [0., 1.]
    with tempfile.TemporaryDirectory() as exp_dir:
        optimizer = pt.optimizer.Adam()
        model = Model([1, 2, 3, 4, 5] * 10, exp_dir, optimizer)
        trainer = pt.Trainer(
            model, exp_dir, optimizer, stop_trigger=(10, 'epoch'),
            summary_trigger=(1, 'epoch')
        )
        trainer.register_validation_hook(ds_valid)
        trainer.train(ds_train)
    # Three train examples per epoch, first one snapshots; ten epochs.
    assert model.train_create_snapshot_log == [True, False, False] * 10
    # The validation is executed num_epochs + 1 times, so 11
    assert model.validation_create_snapshot_log == [True, False] * 11
Exemplo n.º 6
0
def test_backoff():
    """After one back-off the best checkpoint is restored (newer checkpoints
    removed), the LR drops by 10x, and early stopping ends the training."""
    ds = [0]
    with tempfile.TemporaryDirectory() as tmp_dir:
        optimizer = pt.optimizer.Adam()
        # Losses improve for four epochs, then stagnate at 1.
        model = DummyModel([3, 2, 1, 0, 1, 1, 1, 1, 1, 1], tmp_dir, optimizer)
        trainer = pt.Trainer(
            model, tmp_dir, optimizer, stop_trigger=(10, 'epoch')
        )
        trainer.register_validation_hook(
            ds, max_checkpoints=None,
            n_back_off=1, back_off_patience=2, early_stopping_patience=2
        )
        trainer.train(ds)

    def expected(n):
        # Directory listing with the first n numbered checkpoints present.
        return (
            [f'ckpt_{i}.pth' for i in range(n)]
            + ['ckpt_best_loss.pth', 'ckpt_latest.pth']
        )

    assert model.ckpt_log == [
        [],
        expected(1),
        expected(2),
        expected(3),
        expected(4),
        expected(5),
        expected(6),
        # Back-off restored the best state and deleted ckpt_4 / ckpt_5.
        expected(4),
        expected(5),
        expected(6),
    ]
    # Seven epochs before the back-off, three after it at LR / 10.
    assert model.lr_log == [0.001] * 7 + [0.0001] * 3
Exemplo n.º 7
0
def test_multiple_optimizers():
    """A dict of per-submodule optimizers works with test_run and leaves the
    storage directory untouched."""
    tr_dataset, dataset_dt = get_datasets()

    autoencoder = AE()
    opts = {'enc': pt.optimizer.Adam(), 'dec': pt.optimizer.Adam()}
    with tempfile.TemporaryDirectory() as tmp_dir:
        trainer = pt.Trainer(autoencoder, storage_dir=tmp_dir, optimizer=opts)

        with assert_dir_unchanged_after_context(tmp_dir):
            trainer.test_run(tr_dataset, dataset_dt)
Exemplo n.º 8
0
def test_single_grad_check():
    """test_run passes for a normal model and rejects one whose gradients
    are all zero."""
    tr_dataset, dt_dataset = get_datasets()

    # A regular model has to survive the test run.
    with tempfile.TemporaryDirectory() as storage:
        trainer = pt.Trainer(
            Model(),
            optimizer=pt.optimizer.Adam(),
            storage_dir=Path(storage),
            stop_trigger=(2, 'epoch'),
        )
        trainer.test_run(tr_dataset, dt_dataset)

    # Zero gradients must be detected and rejected.
    with tempfile.TemporaryDirectory() as storage:
        trainer = pt.Trainer(
            ZeroGradModel(),
            optimizer=pt.optimizer.Adam(),
            storage_dir=Path(storage),
            stop_trigger=(2, 'epoch'),
        )
        # AssertionError: The loss of the model did not change between two validations.
        with pytest.raises(AssertionError):
            trainer.test_run(tr_dataset, dt_dataset)
Exemplo n.º 9
0
def train():
    """Run a simple mask-estimator training on CHiME-3 for 1e5 iterations."""
    model = SimpleMaskEstimator(513)
    print(f'Simple training for the following model: {model}')
    database = Chime3()
    trainer = pt.Trainer(
        model,
        STORAGE_ROOT / 'simple_mask_estimator',
        optimizer=pt.train.optimizer.Adam(),
        stop_trigger=(int(1e5), 'iteration'),
    )
    train_ds = get_train_ds(database)
    validation_ds = get_validation_ds(database)
    # Fail fast on model/data problems before the long training starts.
    trainer.test_run(train_ds, validation_ds)
    trainer.register_validation_hook(validation_ds)
    trainer.train(train_ds)
Exemplo n.º 10
0
def test_single_model_dir_unchanged():
    """`test_run` must not leave any files behind in the storage dir."""
    tr_dataset, dt_dataset = get_datasets()
    model = Model()

    with tempfile.TemporaryDirectory() as tmp_dir:
        storage_dir = Path(tmp_dir)
        trainer = pt.Trainer(
            model,
            optimizer=pt.optimizer.Adam(),
            storage_dir=storage_dir,
            stop_trigger=(2, 'epoch'),
        )
        with assert_dir_unchanged_after_context(storage_dir):
            trainer.test_run(tr_dataset, dt_dataset)
Exemplo n.º 11
0
def test_single_model_with_back_off_validation():
    """test_run with a back-off validation hook leaves the storage dir clean."""
    tr_dataset, dt_dataset = get_datasets()
    model = Model()

    with tempfile.TemporaryDirectory() as tmp_dir:
        storage_dir = Path(tmp_dir)
        trainer = pt.Trainer(
            model, optimizer=pt.optimizer.Adam(),
            storage_dir=storage_dir, stop_trigger=(2, 'epoch')
        )
        trainer.register_validation_hook(
            dt_dataset, n_back_off=4, back_off_patience=5
        )
        with assert_dir_unchanged_after_context(storage_dir):
            trainer.test_run(tr_dataset, dt_dataset)
Exemplo n.º 12
0
def test_multiple_optimizers():
    """test_run with one optimizer per submodule must not write any files."""
    if sys.platform.startswith('win'):
        pytest.skip(
            'this doctest does not work on Windows, '
            'training is not possible on Windows due to symlinks being unavailable'
        )
    tr_dataset, dataset_dt = get_datasets()

    autoencoder = AE()
    opts = {'enc': pt.optimizer.Adam(), 'dec': pt.optimizer.Adam()}
    with tempfile.TemporaryDirectory() as tmp_dir:
        trainer = pt.Trainer(autoencoder, storage_dir=tmp_dir, optimizer=opts)

        with assert_dir_unchanged_after_context(tmp_dir):
            trainer.test_run(tr_dataset, dataset_dt)
Exemplo n.º 13
0
def test_single_model_dir_unchanged():
    """test_run must leave the storage directory untouched."""
    if sys.platform.startswith('win'):
        pytest.skip(
            'this doctest does not work on Windows, '
            'training is not possible on Windows due to symlinks being unavailable'
        )
    tr_dataset, dt_dataset = get_datasets()
    model = Model()

    with tempfile.TemporaryDirectory() as tmp_dir:
        storage_dir = Path(tmp_dir)
        trainer = pt.Trainer(
            model,
            optimizer=pt.optimizer.Adam(),
            storage_dir=storage_dir,
            stop_trigger=(2, 'epoch'),
        )
        with assert_dir_unchanged_after_context(storage_dir):
            trainer.test_run(tr_dataset, dt_dataset)
Exemplo n.º 14
0
def train(storage_dir, database_json):
    """Train a SimpleMaskEstimator on the database described by a json file.

    Args:
        storage_dir: Directory for checkpoints, logs and summaries.
        database_json: Path to the database description json.
    """
    model = SimpleMaskEstimator(513)
    print(f'Simple training for the following model: {model}')
    database = JsonDatabase(database_json)
    train_dataset = get_train_dataset(database)
    validation_dataset = get_validation_dataset(database)
    trainer = pt.Trainer(
        model,
        storage_dir,
        optimizer=pt.train.optimizer.Adam(),
        stop_trigger=(int(1e5), 'iteration'),
    )
    # Quick sanity check before committing to the long training run.
    trainer.test_run(train_dataset, validation_dataset)
    # Reduce the LR by 10x (up to five times) on stagnating validation loss;
    # never stop early.
    trainer.register_validation_hook(
        validation_dataset,
        n_back_off=5,
        lr_update_factor=1 / 10,
        back_off_patience=1,
        early_stopping_patience=None,
    )
    trainer.train(train_dataset)
Exemplo n.º 15
0
def test_summary_hook_create_snapshot_flag():
    """Without a validation hook, `create_snapshot` is True only for the
    first example of each summary interval."""

    class Model(DummyModel):

        def __init__(self, validation_losses, exp_dir, optimizer):
            super().__init__(validation_losses, exp_dir, optimizer)
            # Records the flag for every review() call.
            self.create_snapshot_log = []

        def review(self, example, output):
            self.create_snapshot_log.append(self.create_snapshot)
            return super().review(example, output)

    ds = [0., 1., 2.]
    with tempfile.TemporaryDirectory() as exp_dir:
        optimizer = pt.optimizer.Adam()
        model = Model([1, 2, 3] * 10, exp_dir, optimizer)
        trainer = pt.Trainer(
            model, exp_dir, optimizer, stop_trigger=(10, 'epoch'),
            summary_trigger=(1, 'epoch')
        )
        trainer.train(ds)
    # Ten epochs of three examples; only the first per epoch snapshots.
    assert model.create_snapshot_log == [True, False, False] * 10
Exemplo n.º 16
0
def test_released_tensors():
    """Train for one epoch and, via a hook, verify that no intermediate
    tensors survive between steps: before each step only the model
    parameters and the optimizer state tensors may be alive, and no live
    tensor may carry a grad_fn."""
    import gc
    gc.collect()

    tr_dataset, dt_dataset = get_dataset()
    # Two examples keep the test fast while still exercising the hook.
    tr_dataset = tr_dataset[:2]
    dt_dataset = dt_dataset[:2]

    class ReleaseTestHook(pt.train.hooks.Hook):
        """Inspects the GC-tracked tensors before and after each train step."""

        def get_all_tensors(self):
            # Every torch.Tensor currently tracked by the garbage collector.
            import gc
            tensors = []
            for obj in gc.get_objects():
                if isinstance(obj, torch.Tensor):
                    tensors.append(obj)
            return tensors

        def get_all_parameters(self, trainer):
            return list(trainer.model.parameters())

        def get_all_optimizer_tensors(self, trainer):
            # Recursively collect all tensors nested in the optimizer state
            # (dicts/lists/tuples of arbitrary depth).
            def get_tensors(obj):
                if isinstance(obj, (dict, tuple, list)):
                    if isinstance(obj, dict):
                        obj = obj.values()
                    return list(itertools.chain(*[get_tensors(o)
                                                  for o in obj]))
                else:
                    if isinstance(obj, torch.Tensor):
                        return [obj]
                    else:
                        return []

            return get_tensors(trainer.optimizer.optimizer.state)

        @classmethod
        def show_referrers_type(cls, obj, depth, ignore=list()):
            # Debug function to get all references to an object and the
            # references to the references up to a depth of `depth`.
            import gc
            import textwrap
            import inspect
            l = []
            if depth > 0:
                referrers = gc.get_referrers(obj)
                for o in referrers:
                    if not any({o is i for i in ignore}):
                        for s in cls.show_referrers_type(o,
                                                         depth - 1,
                                                         ignore=ignore +
                                                         [referrers, o, obj]):
                            l.append(textwrap.indent(s, '  '))

            if inspect.isframe(obj):
                frame_info = inspect.getframeinfo(obj, context=1)
                if frame_info.function == 'show_referrers_type':
                    # Skip frames of this debug helper itself.
                    pass
                else:
                    info = f' {frame_info.function}, {frame_info.filename}:{frame_info.lineno}'
                    l.append(str(type(obj)) + str(info))
            else:
                l.append(str(type(obj)) + str(obj)[:200].replace('\n', ' '))
            return l

        def pre_step(self, trainer: 'pt.Trainer'):
            # Before a step only parameters and optimizer state tensors may
            # exist; anything else indicates a leak from the previous step.
            # NOTE(review): since torch 1.10 parameter `.grad` tensors are
            # also visible to the GC and would be counted by get_all_tensors
            # — confirm the torch version this example targets.
            all_tensors = self.get_all_tensors()
            parameters = self.get_all_parameters(trainer)
            optimizer_tensors = self.get_all_optimizer_tensors(trainer)

            for p in all_tensors:
                if 'grad_fn' in repr(p) or 'grad_fn' in str(p):
                    txt = "\n".join(self.show_referrers_type(p, 2))
                    raise AssertionError(
                        'Found a tensor that has a grad_fn\n\n' + txt)

            # Shapes of the live tensors that are parameters; only used in
            # the assertion message below.
            summary = [
                t.shape for t in all_tensors
                if any([t is p for p in parameters])
            ]

            assert len(
                all_tensors) == len(parameters) + len(optimizer_tensors), (
                    f'pre_step\n'
                    f'{summary}\n'
                    f'all_tensors: {len(all_tensors)}\n'
                    f'{all_tensors}\n'
                    f'parameters: {len(parameters)}\n'
                    f'{parameters}'
                    f'optimizer_tensors: {len(optimizer_tensors)}\n'
                    f'{optimizer_tensors}\n')

        def post_step(self, trainer: 'pt.Trainer', example, model_output,
                      review):
            # During a step the forward/backward tensors are alive, so
            # strictly more tensors than parameters must be visible.
            all_tensors = self.get_all_tensors()
            parameters = list(trainer.model.parameters())
            assert len(all_tensors) > len(parameters), ('post_step',
                                                        all_tensors,
                                                        parameters)

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_dir = Path(tmp_dir)

        t = pt.Trainer(
            Model(),
            optimizer=pt.optimizer.Adam(),
            storage_dir=str(tmp_dir),
            stop_trigger=(1, 'epoch'),
            summary_trigger=(1, 'epoch'),
            checkpoint_trigger=(1, 'epoch'),
        )
        t.register_validation_hook(validation_iterator=dt_dataset,
                                   max_checkpoints=None)
        t.register_hook(ReleaseTestHook())  # This hook will do the tests
        t.train(tr_dataset)
Exemplo n.º 17
0
def test_log_error_state():
    """`Trainer.log_error_state` dumps each given object to
    `<storage_dir>/log/error_state_<name>.pth`, replacing unpicklable values
    by their repr string. Check the returned path template, the console
    output and the written files, for both the zipfile and the legacy torch
    serialization formats."""
    if sys.platform.startswith('win'):
        pytest.skip(
            'this doctest does not work on Windows, '
            'training is not possible on Windows due to symlinks being unavailable'
        )

    with tempfile.TemporaryDirectory() as tmp_dir:
        t = pt.Trainer(
            Model(),
            optimizer=pt.optimizer.Adam(),
            storage_dir=str(tmp_dir),
            stop_trigger=(1, 'epoch'),
            summary_trigger=(1, 'epoch'),
            checkpoint_trigger=(1, 'epoch'),
        )

        # Working example
        stdout = io.StringIO()
        r = t.log_error_state(
            {
                'file_name': 'simple data',
            },
            file=stdout,
        )
        stdout = stdout.getvalue()

        # Picklable data: path template is returned, nothing is printed.
        assert r == f'{tmp_dir}/log/error_state_{{file_name}}.pth', (r, stdout)
        assert stdout == '', stdout

        # torch >= 1.6 writes the new zipfile serialization format by default.
        with torch.serialization._open_file_like(
                f'{tmp_dir}/log/error_state_file_name.pth',
                'rb') as opened_file:
            if LooseVersion(torch.__version__) >= '1.6':
                assert torch.serialization._is_zipfile(opened_file)
            else:
                assert not torch.serialization._is_zipfile(opened_file)

        # Broken example
        stdout = io.StringIO()
        func = lambda x: x  # local lambdas cannot be pickled
        r = t.log_error_state(
            {
                'file_name': 'simple data',
                'broken_data': {
                    'working': 'works',
                    'broken': func
                },
            },
            file=stdout,
        )
        stdout = stdout.getvalue()

        assert r == f'{tmp_dir}/log/error_state_{{file_name,broken_data}}.pth', r
        # Exactly one warning line about the unpicklable lambda is printed.
        stdout = stdout.splitlines()
        assert len(stdout) == 1, stdout
        assert stdout[
            0] == f'Cannot pickle <function test_log_error_state.<locals>.<lambda> at 0x{id(func):x}>, replace it with a str.'

        # The unpicklable lambda was replaced by its repr string on disk.
        assert torch.load(
            f'{tmp_dir}/log/error_state_file_name.pth') == 'simple data'
        assert torch.load(f'{tmp_dir}/log/error_state_broken_data.pth') == {
            'working':
            'works',
            'broken':
            f'<function test_log_error_state.<locals>.<lambda> at 0x{id(func):x}>'
        }

        # TCL reported that his code used `_legacy_save`.
        # Hence test that the `_legacy_save` works.
        torch_save = functools.partial(torch.save,
                                       _use_new_zipfile_serialization=False)
        with mock.patch('torch.save', torch_save):
            # Working example
            stdout = io.StringIO()
            r = t.log_error_state(
                {
                    'file_name': 'simple data',
                },
                file=stdout,
            )
            stdout = stdout.getvalue()

            assert r == f'{tmp_dir}/log/error_state_{{file_name}}.pth', (
                r, stdout)
            assert stdout == '', stdout

            # The legacy serialization must not produce a zipfile.
            with torch.serialization._open_file_like(
                    f'{tmp_dir}/log/error_state_file_name.pth',
                    'rb') as opened_file:
                assert not torch.serialization._is_zipfile(opened_file)
Exemplo n.º 18
0
def test_released_tensors():
    """Train for one epoch and, via a hook, verify that no intermediate
    tensors leak between steps: before each step only the model parameters,
    their `.grad` tensors and the optimizer state tensors may be alive, and
    no live tensor may carry a grad_fn.

    Fixes: removed a leftover debug `print` in `pre_step` and corrected the
    assertion-message label for the grads section ('parameters' -> 'grads').
    """
    if sys.platform.startswith('win'):
        pytest.skip(
            'this doctest does not work on Windows, '
            'training is not possible on Windows due to symlinks being unavailable'
        )

    import gc
    gc.collect()

    tr_dataset, dt_dataset = get_dataset()
    # Two examples keep the test fast while still exercising the hook.
    tr_dataset = tr_dataset[:2]
    dt_dataset = dt_dataset[:2]

    class ReleaseTestHook(pt.train.hooks.Hook):
        """Inspects the GC-tracked tensors before and after each train step."""

        def get_all_tensors(self):
            # Every torch.Tensor currently tracked by the garbage collector.
            import gc
            tensors = []
            for obj in gc.get_objects():
                if isinstance(obj, torch.Tensor):
                    tensors.append(obj)
            return tensors

        def get_all_parameters(self, trainer):
            return list(trainer.model.parameters())

        def get_all_parameter_grads(self, trainer):
            return [
                p.grad for p in trainer.model.parameters()
                if p.grad is not None
            ]

        def get_all_optimizer_tensors(self, trainer):
            # Recursively collect all tensors nested in the optimizer state
            # (dicts/lists/tuples of arbitrary depth).
            def get_tensors(obj):
                if isinstance(obj, (dict, tuple, list)):
                    if isinstance(obj, dict):
                        obj = obj.values()
                    return list(itertools.chain(*[get_tensors(o)
                                                  for o in obj]))
                else:
                    if isinstance(obj, torch.Tensor):
                        return [obj]
                    else:
                        return []

            return get_tensors(trainer.optimizer.optimizer.state)

        @classmethod
        def show_referrers_type(cls, obj, depth, ignore=list()):
            # Debug function to get all references to an object and the
            # references to the references up to a depth of `depth`.
            import gc
            import textwrap
            import inspect
            l = []
            if depth > 0:
                referrers = gc.get_referrers(obj)
                for o in referrers:
                    if not any({o is i for i in ignore}):
                        for s in cls.show_referrers_type(o,
                                                         depth - 1,
                                                         ignore=ignore +
                                                         [referrers, o, obj]):
                            l.append(textwrap.indent(s, ' ' * 4))

            if inspect.isframe(obj):
                frame_info = inspect.getframeinfo(obj, context=1)
                if frame_info.function == 'show_referrers_type':
                    # Skip frames of this debug helper itself.
                    pass
                else:
                    info = f' {frame_info.function}, {frame_info.filename}:{frame_info.lineno}'
                    l.append(f'Frame: {type(obj)} {info}')
            else:
                l.append(str(type(obj)) + str(obj)[:80].replace('\n', ' '))
            return l

        def pre_step(self, trainer: 'pt.Trainer'):
            # Before a step only parameters, their grads and optimizer state
            # tensors may exist; anything else is a leak from the last step.
            all_tensors = self.get_all_tensors()
            parameters = self.get_all_parameters(trainer)

            # In torch 1.10 https://github.com/pytorch/pytorch/pull/56017 was
            # merged and grads are now visible for the garbage collector.
            grads = self.get_all_parameter_grads(trainer)

            optimizer_tensors = self.get_all_optimizer_tensors(trainer)

            for p in all_tensors:
                if 'grad_fn' in repr(p) or 'grad_fn' in str(p):
                    txt = "\n".join(self.show_referrers_type(p, 2))
                    raise AssertionError(
                        'Found a tensor that has a grad_fn\n\n' + txt)

            # Shapes of the live tensors that are parameters; only used in
            # the assertion message below.
            summary = [
                t.shape for t in all_tensors
                if any([t is p for p in parameters])
            ]

            import textwrap

            assert len(all_tensors) == len(parameters) + len(
                optimizer_tensors) + len(grads), (
                    f'pre_step\n'
                    f'{summary}\n'
                    f'all_tensors: {len(all_tensors)}\n' + textwrap.indent(
                        "\n".join(map(str, all_tensors)), " " * 8) + f'\n'
                    f'parameters: {len(parameters)}\n' +
                    textwrap.indent("\n".join(map(str, parameters)), " " * 8) +
                    f'\n'
                    f'grads: {len(grads)}\n' +
                    textwrap.indent("\n".join(map(str, grads)), " " * 8) +
                    f'\n'
                    f'optimizer_tensors: {len(optimizer_tensors)}\n' +
                    textwrap.indent("\n".join(map(str, optimizer_tensors)),
                                    " " * 8) + f'\n')

        def post_step(self, trainer: 'pt.Trainer', example, model_output,
                      review):
            # During a step the forward/backward tensors are alive, so
            # strictly more tensors than parameters must be visible.
            all_tensors = self.get_all_tensors()
            parameters = list(trainer.model.parameters())
            assert len(all_tensors) > len(parameters), ('post_step',
                                                        all_tensors,
                                                        parameters)

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_dir = Path(tmp_dir)

        t = pt.Trainer(
            Model(),
            optimizer=pt.optimizer.Adam(),
            storage_dir=str(tmp_dir),
            stop_trigger=(1, 'epoch'),
            summary_trigger=(1, 'epoch'),
            checkpoint_trigger=(1, 'epoch'),
        )
        t.register_validation_hook(validation_iterator=dt_dataset,
                                   max_checkpoints=None)
        t.register_hook(ReleaseTestHook())  # This hook will do the tests
        t.train(tr_dataset)