コード例 #1
0
def test_metric_improvement():
    """Verify the events logged when monitoring ``metric`` with ``improvement_on='gt'``.

    An ``ImprovementMonitor`` is attached to an ``ImprovingMetricLoop``, the
    loop is called with ``3``, and the loop's ``event_log`` is compared with
    three ``Improvement`` events on the ``val`` stage, the first of which
    starts from a previous best of negative infinity.
    """
    tracker = ImprovementMonitor(field='metric', improvement_on='gt')
    loop = ImprovingMetricLoop(callbacks=[tracker])

    # (best, last_best) for each expected event, in firing order
    transitions = [
        (0.25, -float('inf')),
        (0.5, 0.25),
        (0.75, 0.5),
    ]
    expected_log = [
        Improvement(
            field='metric',
            stage='val',
            steps=1,
            best=new_best,
            last_best=old_best,
        )
        for new_best, old_best in transitions
    ]

    loop(3)
    assert loop.event_log == expected_log
コード例 #2
0
def test_loss_improvement():
    """Verify the events logged by a default-field monitor on ``ImprovingLossLoop``.

    The monitor is built with ``stagnant_after=2`` only. After calling the
    loop with ``3`` the ``event_log`` must equal three ``Improvement`` events
    for ``loss`` on the ``val`` stage, the first starting from a previous
    best of positive infinity.
    """
    tracker = ImprovementMonitor(stagnant_after=2)
    loop = ImprovingLossLoop(callbacks=[tracker])

    # each event's ``last_best`` is the preceding event's ``best``
    bests = [2.0, 1.9, 1.8]
    previous_bests = [float('inf'), *bests[:-1]]
    expected_log = [
        Improvement(
            field='loss',
            stage='val',
            steps=1,
            best=new_best,
            last_best=old_best,
        )
        for new_best, old_best in zip(bests, previous_bests)
    ]

    loop(3)
    assert loop.event_log == expected_log
コード例 #3
0
 def on_stage_end(self, loop):
     """Fire an ``Improvement`` or ``Stagnation`` event for the monitored stage.

     Does nothing unless ``loop.stage`` equals ``self.stage``. Otherwise the
     current value is read via ``self._get_value(loop)`` and compared with
     ``self._last_best`` using ``self._is_improvement``:

     * on improvement, an ``Improvement`` event is built and the stored best
       step/value are updated to the current epoch/value;
     * otherwise, if the number of epochs since the best step is greater
       than ``self.stagnant_after``, a ``Stagnation`` event is built.

     Any event built is passed to ``loop.fire``.

     Args:
         loop: object exposing ``stage``, ``epoch`` and ``fire``.
     """
     if loop.stage != self.stage:
         return

     current = self._get_value(loop)
     # epochs elapsed since the best value was recorded
     elapsed = loop.epoch - self._best_step

     if self._is_improvement(current, self._last_best):
         outcome = Improvement(
             self.field,
             stage=self.stage,
             steps=elapsed,
             best=current,
             last_best=self._last_best,
         )
         # remember the new best only after the event captured the old one
         self._best_step = loop.epoch
         self._last_best = current
     elif elapsed > self.stagnant_after:
         outcome = Stagnation(
             field=self.field,
             stage=self.stage,
             steps=elapsed,
             best=self._last_best,
         )
     else:
         outcome = None

     if outcome:
         loop.fire(outcome)
コード例 #4
0
)
def test_should_stop(callback, event):
    """Check that ``callback`` flags the loop to stop when it receives ``event``.

    The loop must not be flagged before the event, must be flagged after
    ``on_event``, and after ``on_epoch_end`` its ``_event_log`` must hold a
    single ``EarlyStop`` built from the loop's starting epoch.
    """
    start_epoch = 0
    dummy = DummyLoop(epoch=start_epoch)
    # precondition: a fresh loop is not asked to stop
    assert not dummy.should_stop

    callback.on_event(dummy, event)
    assert dummy.should_stop

    callback.on_epoch_end(dummy)
    expected_events = [EarlyStop(start_epoch)]
    assert dummy._event_log == expected_events


@pytest.mark.parametrize(
    'callback, event',
    [
        (
            EarlyStopping(field='loss', stage='val'),
            Improvement(field='loss', stage='val', steps=6, best=0.1, last_best=0.2),
        ),
        (EarlyStopping(patience=5), Stagnation(field='loss', stage='val', steps=3, best=0.1)),
        (EarlyStopping(patience=100), Stagnation(field='loss', stage='val', steps=5, best=0.1)),
        (EarlyStopping(stage='val'), Stagnation(field='loss', stage='train', steps=6, best=0.1)),
        (EarlyStopping(field='loss'), Stagnation(field='metric', stage='val', steps=6, best=0.1)),
        (
            EarlyStopping(field='loss', stage='val'),
            Stagnation(field='metric', stage='val', steps=6, best=0.1),
        ),
        (
            EarlyStopping(field='loss', stage='val'),
            Stagnation(field='loss', stage='train', steps=6, best=0.1),
        ),
    ],
)
コード例 #5
0
import pytest
from hearth.events import Improvement, Stagnation, CheckpointSaved, EarlyStop


@pytest.mark.parametrize(
    ('event', 'msg'),
    [
        (
            Stagnation(field='loss', stage='val', steps=2, best=0.01),
            'Stagnation[val.loss] stagnant for 2 steps no improvement from 0.0100.',
        ),
        (
            Improvement(
                field='metric',
                stage='val',
                steps=2,
                best=0.9,
                last_best=0.8,
            ),
            'Improvement[val.metric] improved from : 0.8000 to 0.9000 in 2 steps.',
        ),
        (
            CheckpointSaved('some/dir'),
            'CheckpointSaved checkpoint saved to some/dir.',
        ),
        (
            EarlyStop(1),
            'EarlyStop triggered at epoch 1.',
        ),
    ],
)
def test_logmsg(event, msg):
    """Each event's ``logmsg()`` must equal the expected message exactly."""
    rendered = event.logmsg()
    assert rendered == msg
コード例 #6
0
def test_full_save_on_event(tmpdir):
    """Exercise a full ``Checkpoint`` save triggered by an ``Improvement`` event.

    The callback is registered on a ``DummyLoop``, given an event, and then
    told the epoch ended. The test then checks that the model, history and
    optimizer-state files exist under ``model_dir``, that each reloads to a
    value equal to the one held by the loop, that ``_should_save`` has been
    cleared, and that the loop logged a single ``CheckpointSaved`` event.
    """
    # checkpoint target under pytest's temporary directory
    models_root = tmpdir.mkdir('models')
    model_dir = os.path.join(str(models_root), 'dummy')

    def saved(filename):
        # path of a file expected inside the checkpoint directory
        return os.path.join(model_dir, filename)

    train_stats = {'loss': 0.53, 'metric': 0.85}
    val_stats = {'loss': 0.20, 'metric': 0.93}
    history = History(
        {
            'epoch': 0,
            'lrs': {'group0': 0.001},
            'train': train_stats,
            'val': val_stats,
        }
    )
    model = HearthModel()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

    callback = Checkpoint(model_dir=model_dir)
    loop = DummyLoop(model=model, history=history, optimizer=optimizer)

    # simulate registration: nothing should be pending yet
    callback.on_registration(loop)
    assert not callback._should_save

    # simulate an event: a save should now be pending
    improvement = Improvement(
        field='loss',
        stage='val',
        steps=1,
        best=0.1,
        last_best=0.2,
    )
    callback.on_event(loop, improvement)
    assert callback._should_save

    # simulate epoch end: the pending save is expected to happen here
    callback.on_epoch_end(loop)

    # model files exist and the model reloads with equal config and weights
    for filename in ('state.pt', 'config.json'):
        assert os.path.exists(saved(filename))
    reloaded_model = HearthModel.load(model_dir)
    assert reloaded_model.config() == loop.model.config()
    for param_name in ('weight', 'bias'):
        reloaded_param = getattr(reloaded_model.linear, param_name)
        live_param = getattr(loop.model.linear, param_name)
        assert (reloaded_param == live_param).all()

    # history file exists and reloads to an equal history
    assert os.path.exists(saved('history.json'))
    reloaded_history = History.load(model_dir)
    assert reloaded_history == loop.history

    # optimizer state file exists and matches the live optimizer
    optimizer_path = saved('optimizer_state.pt')
    assert os.path.exists(optimizer_path)
    reloaded_opt_state = torch.load(optimizer_path)
    assert reloaded_opt_state == loop.optimizer.state_dict()

    # the pending flag is cleared and the loop saw exactly one event
    assert not callback._should_save
    assert loop._event_log == [CheckpointSaved(model_dir)]
コード例 #7
0
        self._event_log.append(event)


class HearthModel(BaseModule):
    """Small model wrapping a single ``nn.Linear`` layer.

    Args:
        in_feats: input feature size passed to ``nn.Linear``.
        out_feats: output feature size passed to ``nn.Linear``.
    """

    def __init__(self, in_feats: int = 3, out_feats: int = 6):
        super().__init__()
        self.linear = nn.Linear(in_features=in_feats, out_features=out_feats)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return the result."""
        projected = self.linear(x)
        return projected


@pytest.mark.parametrize(
    ('callback', 'event'),
    [
        (
            Checkpoint('fake'),
            Improvement('blah', 'blah', 0, 0, 0),
        ),
        (
            Checkpoint('fake', field='loss'),
            Improvement('loss', 'blah', 0, 0, 0),
        ),
        (
            Checkpoint('fake', field='loss', stage='val'),
            Improvement('loss', 'val', 0, 0, 0),
        ),
        (
            Checkpoint('fake', stage='val'),
            Improvement('blah', 'val', 0, 0, 0),
        ),
    ],
)
def test_should_save(callback, event):
    """A matching event must flip the callback's ``_should_save`` flag on.

    Each case pairs a ``Checkpoint`` with an ``Improvement`` event; the flag
    must be falsy before ``on_event`` and truthy afterwards.
    """
    pending_before = callback._should_save
    assert not pending_before

    callback.on_event(loop=None, event=event)

    pending_after = callback._should_save
    assert pending_after


@pytest.mark.parametrize(
    'callback, event',