Example #1
def test_extensions_manager_state_dict_future_ppe_version():
    model_state_dict = object()
    optimizer_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4
    passed_iteration = 11

    manager = training.ExtensionsManager(
        {'model_name': _StateDictModel(state_dict=model_state_dict)},
        {'optimizer_name': _StateDictObj(state_dict=optimizer_state_dict)},
        max_epochs,
        iters_per_epoch=iters_per_epoch,
    )

    for _ in range(passed_iteration):
        with manager.run_iteration():
            pass

    new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict)
    new_optimizer = _StateDictObj(state_dict_to_be_loaded=optimizer_state_dict)
    manager_2 = training.ExtensionsManager(
        {'model_name': new_model},
        {'optimizer_name': new_optimizer},
        max_epochs,
        iters_per_epoch=iters_per_epoch,
    )

    state_dict = manager.state_dict()
    state_dict['ppe_version'] = '23.0.0'
    with pytest.warns(UserWarning, match='version'):
        manager_2.load_state_dict(state_dict)
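
The _StateDictObj and _StateDictModel helpers used throughout these examples are defined in the test module and not shown in this listing. A minimal sketch consistent with how they are called (the keyword names and the called_load_state_dict counter come from the examples; everything else is an assumption):

import torch

class _StateDictObj:
    def __init__(self, *, state_dict=None, state_dict_to_be_loaded=None):
        super().__init__()  # cooperative init so nn.Module subclasses work
        self.called_load_state_dict = 0
        self._state_dict = state_dict
        self._state_dict_to_be_loaded = state_dict_to_be_loaded

    def state_dict(self):
        return self._state_dict

    def load_state_dict(self, state_dict):
        self.called_load_state_dict += 1
        assert state_dict == self._state_dict_to_be_loaded


class _StateDictModel(_StateDictObj, torch.nn.Module):
    # Same stub, made an nn.Module so it can be registered as a model.
    def forward(self, *args):
        pass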
Example #2
def test_model_transformations_in_state_dict():
    model_state_dict = object()
    optimizer_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4
    model = Wrapper(_StateDictModel(state_dict=model_state_dict))
    manager = training.ExtensionsManager(
        model,
        _StateDictObj(state_dict=optimizer_state_dict),
        max_epochs,
        iters_per_epoch=iters_per_epoch,
        transform_model=lambda n, x: x.wrapper_module(),
    )

    state_dict = manager.state_dict()
    assert model.accessed

    new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict)
    new_optimizer = _StateDictObj(state_dict_to_be_loaded=optimizer_state_dict)
    new_manager = training.ExtensionsManager(
        Wrapper(new_model),
        new_optimizer,
        max_epochs,
        iters_per_epoch=iters_per_epoch,
        transform_model=lambda n, x: x.wrapper_module(),
    )
    new_manager.load_state_dict(state_dict)
    assert isinstance(new_manager.models['main'], _StateDictModel)
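
Wrapper is another helper from the test module; a plausible sketch inferred from its usage, where transform_model calls wrapper_module() to unwrap the real model and the accessed flag records that this happened:

import torch

class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.accessed = False

    def wrapper_module(self):
        self.accessed = True
        return self.model

    def forward(self, *args):
        return self.model(*args)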
Example #3
def test_extensions_manager_state_dict():
    model_state_dict = object()
    optimizer_state_dict = object()
    extension_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4
    passed_iteration = 11

    manager = training.ExtensionsManager(
        {'model_name': _StateDictModel(state_dict=model_state_dict)},
        {'optimizer_name': _StateDictObj(state_dict=optimizer_state_dict)},
        max_epochs,
        iters_per_epoch=iters_per_epoch,
    )

    manager.extend(_StateDictExtension(state_dict=extension_state_dict),
                   name='extension_name')

    for it in range(passed_iteration):
        with manager.run_iteration():
            pass

    state_dict = manager.state_dict()

    assert state_dict == {
        '_start_iteration': passed_iteration,
        'models': {
            'model_name': model_state_dict
        },
        'optimizers': {
            'optimizer_name': optimizer_state_dict
        },
        'extensions': {
            'extension_name': {
                'extension': extension_state_dict,
                'trigger': {
                    '_previous_iteration': passed_iteration,
                    '_previous_epoch_detail':
                    passed_iteration / iters_per_epoch
                },
            }
        },
    }

    new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict)
    new_optimizer = _StateDictObj(state_dict_to_be_loaded=optimizer_state_dict)
    new_extension = _StateDictExtension(
        state_dict_to_be_loaded=extension_state_dict)
    new_manager = training.ExtensionsManager(
        {'model_name': new_model},
        {'optimizer_name': new_optimizer},
        max_epochs,
        iters_per_epoch=iters_per_epoch,
    )
    new_manager.extend(new_extension, name='extension_name')
    new_manager.load_state_dict(state_dict)
    assert new_model.called_load_state_dict == 1
    assert new_optimizer.called_load_state_dict == 1
    assert new_extension.called_load_state_dict == 1
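
_StateDictExtension is presumably the same state-dict stub exposed as an extension, so the manager can snapshot and restore it alongside models and optimizers. A hedged sketch, reusing the _StateDictObj sketch from Example #1:

from pytorch_pfn_extras import training

class _StateDictExtension(_StateDictObj, training.extension.Extension):
    def __call__(self, manager):
        pass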
Example #4
def test_trigger(iters_per_epoch, schedule, expected, finished, resume):
    trainer = training.ExtensionsManager({}, [],
                                         100,
                                         iters_per_epoch=iters_per_epoch)
    trigger = triggers.ManualScheduleTrigger(*schedule)

    _test_trigger(trainer, trigger, expected, finished)
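
_test_trigger is a helper shared by the schedule-trigger tests. A hedged reconstruction from its call sites, assuming the trigger exposes a finished flag as the finished argument suggests:

def _test_trigger(trainer, trigger, expected, finished):
    for e, f in zip(expected, finished):
        with trainer.run_iteration():
            pass
        assert trigger(trainer) == e
        assert trigger.finished == f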
Example #5
def test_extensions_manager_extensions():
    model = nn.Module()
    optimizer = object()
    max_epochs = 5
    iters_per_epoch = 4
    manager = training.ExtensionsManager(
        {'model_name': model},
        {'optimizer_name': optimizer},
        max_epochs,
        iters_per_epoch=iters_per_epoch,
    )

    call_record = []
    init_record = []

    dummy5 = _DummyExtension(5, call_record, init_record)
    exts = [
        _DummyExtension(0, call_record, init_record),
        _DummyExtensionInitialize(1, call_record, init_record),
        _DummyExtension(2, call_record, init_record),
        _DummyExtensionInitialize(3, call_record, init_record),
        _DummyExtensionInitialize(4, call_record, init_record),
        lambda manager: dummy5(manager),
    ]

    manager.extend(exts[0], 'ext0', priority=2, call_before_training=True)
    manager.extend(exts[1], 'ext1', priority=1, call_before_training=False)
    manager.extend(exts[2], 'ext2', priority=3, call_before_training=False)
    manager.extend(exts[3], 'ext3', priority=0, call_before_training=True)
    manager.extend(exts[4], 'ext4', priority=4, call_before_training=True)
    manager.extend(exts[5], 'ext5', priority=-1, call_before_training=True)

    assert manager.get_extension('ext0') is exts[0]
    assert manager.get_extension('ext1') is exts[1]
    assert manager.get_extension('ext2') is exts[2]
    assert manager.get_extension('ext3') is exts[3]
    assert manager.get_extension('ext4') is exts[4]

    with pytest.raises(ValueError):
        manager.get_extension('ext10')

    for it in range(max_epochs * iters_per_epoch):
        call_record.clear()
        init_record.clear()

        with manager.run_iteration():
            assert manager.iteration == it

            if it == 0:
                assert call_record == [4, 0, 3, 5]
                assert init_record == [4, 1, 3]
            else:
                assert call_record == []
                assert init_record == []

            call_record.clear()
            init_record.clear()

        assert call_record == [4, 2, 0, 1, 3, 5]
        assert init_record == []
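
The _DummyExtension helpers record the order in which the manager invokes them; the [4, 2, 0, 1, 3, 5] record above shows that higher priority runs first, and only the Initialize variant reports to init_record. A plausible sketch; the optional fourth argument, which a later example uses to make the extension touch manager.models, is an assumed name:

from pytorch_pfn_extras import training

class _DummyExtension(training.extension.Extension):
    def __init__(self, numbering, call_record, init_record,
                 access_model=False):
        self.numbering = numbering
        self.call_record = call_record
        self.init_record = init_record
        self.access_model = access_model

    def __call__(self, manager):
        self.call_record.append(self.numbering)
        if self.access_model:
            manager.models  # rejected when needs_model_state is False


class _DummyExtensionInitialize(_DummyExtension):
    def initialize(self, manager):
        self.init_record.append(self.numbering)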
Example #6
def test_model_transformations(path):
    model_state_dict = object()
    optimizer_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4
    model = Wrapper(_StateDictModel(state_dict=model_state_dict))
    manager = training.ExtensionsManager(
        model,
        _StateDictObj(state_dict=optimizer_state_dict),
        max_epochs,
        iters_per_epoch=iters_per_epoch,
        out_dir=path,
    )

    snapshot = extensions.snapshot(
        filename='test', transform_models=lambda n, x: x.wrapper_module())
    snapshot(manager)

    assert model.accessed

    # Verify that autoload applies the transformation
    to_load = torch.load(os.path.join(path, 'test'))
    trainer = get_trainer(out_dir=path, state_to_load=to_load)
    snapshot = extensions.snapshot(
        filename='test',
        autoload=True,
        autoload_transform_models=lambda n, x: Wrapper(x))
    snapshot.initialize(trainer)
    assert isinstance(trainer.models['main'], Wrapper)
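
get_trainer here is not shown in this listing; given its out_dir and state_to_load keywords, it is presumably a fixture along the lines of get_trainer_with_mock_updater in Example #11.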
Example #7
def test_extensions_manager_with_plain_model_and_optimizer():
    model_state_dict = object()
    optimizer_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4
    manager = training.ExtensionsManager(
        _StateDictModel(state_dict=model_state_dict),
        _StateDictObj(state_dict=optimizer_state_dict),
        max_epochs,
        iters_per_epoch=iters_per_epoch,
    )

    state_dict = manager.state_dict()

    assert state_dict == {
        '_start_execution': 0,
        '_start_iteration': 0,
        'models': {
            'main': model_state_dict
        },
        'optimizers': {
            'main': optimizer_state_dict
        },
        'extensions': {}
    }
Example #8
def get_trainer_with_mock_updater(path):
    epochs = 10  # FIXME
    model = _create_distributed_model()
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    optimizers = {'main': optimizer}
    models = {'main': model}
    return training.ExtensionsManager(
        models, optimizers, epochs, iters_per_epoch=1, out_dir=path)
Example #9
def test_trigger(
        trigger_type, trigger_args, iters_per_epoch, accuracies, expected,
        resume):
    key = 'main/accuracy'
    manager = training.ExtensionsManager(
        {}, [], 100, iters_per_epoch=iters_per_epoch)
    trigger = trigger_type(key, *trigger_args)
    _test_trigger(
        manager, trigger, key, accuracies, expected)
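
This _test_trigger variant feeds one reported value per iteration and checks the trigger decision against it. A hedged sketch of its likely shape, using ppe.reporting.report to publish the value inside run_iteration:

import pytorch_pfn_extras as ppe

def _test_trigger(manager, trigger, key, accuracies, expected):
    for accuracy, e in zip(accuracies, expected):
        with manager.run_iteration():
            ppe.reporting.report({key: accuracy})
        assert trigger(manager) == e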
Example #10
def test_trigger(iters_per_epoch, interval, expected, resume):
    trainer = training.ExtensionsManager(
        {}, [], 100, iters_per_epoch=iters_per_epoch)
    trigger = triggers.IntervalTrigger(*interval)

    for e in expected:
        with trainer.run_iteration():
            pass
        assert trigger.may_fire(trainer.iteration, iters_per_epoch) == e
        assert trigger(trainer) == e
Example #11
def get_trainer_with_mock_updater(*, out_dir, state_to_load=None):
    model_state_dict = {}
    optimizer_state_dict = {}
    models = {'main': _StateDictModel(state_dict=model_state_dict)}
    optimizers = {'main': _StateDictObj(state_dict=optimizer_state_dict)}
    epochs = 10  # FIXME
    return training.ExtensionsManager(
        models, optimizers, epochs,
        iters_per_epoch=10,
        out_dir=out_dir)
Example #12
def test_get_trigger(iters_per_epoch, trigger_args, expected):
    trainer = training.ExtensionsManager({}, [],
                                         100,
                                         iters_per_epoch=iters_per_epoch)
    trigger = trigger_util.get_trigger(trigger_args)

    # before the first iteration, trigger should be False
    for it, e in enumerate([False] + expected):
        with trainer.run_iteration():
            assert trigger(trainer) == e
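
trigger_args in the parametrization is typically either a ready-made trigger or a (period, unit) tuple; get_trigger normalizes both into a trigger object. A minimal illustration, keeping the test file's trigger_util alias (the alias itself is an assumption about the imports):

t1 = trigger_util.get_trigger((2, 'iteration'))  # tuple -> IntervalTrigger
t2 = trigger_util.get_trigger(triggers.IntervalTrigger(1, 'epoch'))  # passed through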
Example #13
def get_manager_model_optimizer():
    epochs = 3
    model = torch.nn.Linear(1, 3)
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    optimizers = {'main': optimizer}
    models = {'main': model}
    manager = training.ExtensionsManager(
        models, optimizers, epochs,
        iters_per_epoch=4)
    manager.extend(FailOnNonNumber())
    return manager, model, optimizer
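
A hedged sketch of how this fixture would be driven: FailOnNonNumber aborts training if the model's parameters go non-finite, so a normal loop simply runs iterations and lets the manager step the optimizer (the loop shape below is assumed, not from the original file):

import torch

manager, model, optimizer = get_manager_model_optimizer()
while not manager.stop_trigger:
    with manager.run_iteration(step_optimizers=['main']):
        optimizer.zero_grad()
        loss = model(torch.ones(1, 1)).sum()
        loss.backward()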
Example #14
def test_call_optimizers():
    m = torch.nn.Linear(5, 5)
    a = torch.ones(1, requires_grad=True)
    optimizer = torch.optim.SGD(lr=1.0, params=[a])
    manager = training.ExtensionsManager(
        m,
        optimizer,
        1,
        iters_per_epoch=1,
    )
    with manager.run_iteration(step_optimizers=['main']):
        a.grad = torch.tensor([2.0])
    assert torch.equal(a.detach(), torch.tensor([-1.]))
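
The final assertion is one step of plain SGD: run_iteration(step_optimizers=['main']) steps the optimizer on exit, so a becomes a - lr * grad = 1 - 1.0 * 2.0 = -1.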
Example #15
def test_manager_status_info():
    manager = training.ExtensionsManager(nn.Module(),
                                         object(),
                                         10,
                                         iters_per_epoch=4)
    manager.iteration = 9
    assert manager.iteration == 9
    assert manager.epoch == 2
    assert manager.epoch_detail == 2.25

    manager.iteration = 15
    assert manager.epoch == 3
    assert manager.epoch_detail == 3.75
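
Both properties follow from iters_per_epoch=4: epoch is iteration // 4 and epoch_detail is iteration / 4, hence 9 gives (2, 2.25) and 15 gives (3, 3.75).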
Example #16
def test_resumed_trigger(iters_per_epoch, schedule, expected, finished,
                         resume):
    trainer = training.ExtensionsManager({}, [],
                                         100,
                                         iters_per_epoch=iters_per_epoch)
    trigger = triggers.ManualScheduleTrigger(*schedule)

    _test_trigger(trainer, trigger, expected[:resume], finished[:resume])

    state = trigger.state_dict()
    new_trigger = triggers.ManualScheduleTrigger(*schedule)
    new_trigger.load_state_dict(state)

    _test_trigger(trainer, new_trigger, expected[resume:], finished[resume:])
Example #17
def test_model_transformations():
    model_state_dict = object()
    optimizer_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4
    model = Wrapper(_StateDictModel(state_dict=model_state_dict))
    manager = training.ExtensionsManager(
        model,
        _StateDictObj(state_dict=optimizer_state_dict),
        max_epochs,
        iters_per_epoch=iters_per_epoch,
        transform_model=lambda n, x: x.wrapper_module(),
    )

    assert not isinstance(manager.models['main'], Wrapper)
    assert model.accessed
Example #18
def test_on_error():
    # Will fail when accessing the dummy optimizer
    optimizers = {'main': object()}
    trainer = training.ExtensionsManager(
        {}, optimizers, 1,
        iters_per_epoch=1,
        out_dir='.')
    filename = 'myfile-deadbeef.dat'

    snapshot = extensions.snapshot_object(trainer, filename,
                                          snapshot_on_error=True)
    trainer.extend(snapshot)
    assert not os.path.exists(filename)
    with pytest.raises(AttributeError):
        with trainer.run_iteration():
            pass
    assert not os.path.exists(filename)
Example #19
def test_resumed_trigger(
        trigger_type, trigger_args, iters_per_epoch, accuracies, expected,
        resume):
    key = 'main/accuracy'
    manager = training.ExtensionsManager(
        {}, [], 100, iters_per_epoch=iters_per_epoch)

    trigger = trigger_type(key, *trigger_args)
    _test_trigger(
        manager, trigger, key, accuracies[:resume],
        expected[:resume])

    state = trigger.state_dict()

    new_trigger = trigger_type(key, *trigger_args)
    new_trigger.load_state_dict(state)
    _test_trigger(
        manager, new_trigger, key, accuracies[resume:], expected[resume:])
Example #20
def test_extensions_accessing_models_without_flag(priority):
    m = torch.nn.Linear(5, 5)
    a = torch.ones(1, requires_grad=True)
    optimizer = torch.optim.SGD(lr=1.0, params=[a])
    extension = _DummyExtension(0, [], [], True)
    extension.name = 'Dummy'
    extension.needs_model_state = False
    extension.trigger = (1, 'iteration')
    if priority is not None:
        extension.priority = priority
    manager = training.ExtensionsManager(m,
                                         optimizer,
                                         1,
                                         iters_per_epoch=5,
                                         extensions=[extension])
    while not manager.stop_trigger:
        with pytest.raises(RuntimeError):
            with manager.run_iteration():
                pass
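
The RuntimeError is the point of the test: the extension touches manager.models (via the model-access flag sketched under Example #5) while declaring needs_model_state = False, and the manager rejects that access regardless of the extension's priority.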
Example #21
def test_model_transformations(path):
    model_state_dict = object()
    optimizer_state_dict = object()
    max_epochs = 5
    iters_per_epoch = 4
    model = Wrapper(_StateDictModel(state_dict=model_state_dict))
    manager = training.ExtensionsManager(
        model,
        _StateDictObj(state_dict=optimizer_state_dict),
        max_epochs,
        iters_per_epoch=iters_per_epoch,
        out_dir=path,
        transform_model=lambda n, x: x.wrapper_module(),
    )

    snapshot = extensions.snapshot(filename='test')
    snapshot(manager)

    assert model.accessed
Example #22
def test_resumed_trigger(iters_per_epoch, interval, expected, resume):
    trainer = training.ExtensionsManager({}, [],
                                         100,
                                         iters_per_epoch=iters_per_epoch)
    trigger = triggers.IntervalTrigger(*interval)

    for e in expected[:resume]:
        with trainer.run_iteration():
            pass
        assert trigger(trainer) == e

    state = trigger.state_dict()
    new_trigger = triggers.IntervalTrigger(*interval)
    new_trigger.load_state_dict(state)

    for e in expected[resume:]:
        with trainer.run_iteration():
            pass
        assert new_trigger(trainer) == e
Example #23
def test_needs_state_this_iteration():
    m = torch.nn.Linear(5, 5)
    a = torch.ones(1, requires_grad=True)
    optimizer = torch.optim.SGD(lr=1.0, params=[a])
    extension = _DummyExtension(0, [], [], True)
    extension.name = 'Dummy'
    extension.needs_model_state = True
    extension.trigger = (50, 'iteration')
    manager = training.ExtensionsManager(m,
                                         optimizer,
                                         1,
                                         iters_per_epoch=100,
                                         extensions=[extension])
    while not manager.stop_trigger:
        with manager.run_iteration():
            # the iteration counter is incremented by 1 before the
            # extensions run, hence the checks at 49 and 99
            if manager.iteration in (49, 99):
                assert manager.needs_state_this_iteration()
            else:
                assert not manager.needs_state_this_iteration()
Example #24
def test_deferred_iteration():
    m = torch.nn.Linear(5, 5)
    a = torch.ones(1, requires_grad=True)
    optimizer = torch.optim.SGD(lr=1.0, params=[a])
    call_record = []
    extension = _DummyExtension(0, call_record, [])
    extension.name = 'Dummy 0'
    extension.trigger = (1, 'iteration')
    extension.is_async = True
    extension2 = _DummyExtension(1, call_record, [])
    extension2.name = 'Dummy 1'
    extension2.trigger = (1, 'iteration')
    manager = training.ExtensionsManager(
        m,
        optimizer,
        1,
        iters_per_epoch=100,
        extensions=[extension, extension2]
    )
    for _ in range(5):
        with manager.run_iteration() as iter_handler:
            # the iteration counter is incremented by 1 before the
            # extensions run
            iter_handler.defer()
    with manager.run_iteration():
        pass

    assert manager.iteration == 1
    assert manager.execution == 6
    assert call_record == [0] * 6 + [1]

    for _ in range(5):
        with manager.complete_iteration():
            pass

    assert manager.iteration == 6
    assert manager.execution == 6
    assert call_record == [0] * 6 + [1] * 6
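
The counts line up as follows: five deferred runs plus one completed run give execution == 6 but iteration == 1; the async 'Dummy 0' fires on every execution ([0] * 6) while the synchronous 'Dummy 1' fires only on the completed iteration. The five complete_iteration() calls then retire the deferred runs, raising iteration to 6 and appending the remaining [1] * 5.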