Example #1
0
    def test_history_no_epochs_batches(self, has_epoch, epoch_slice):
        """'batches' lookups over an empty epoch selection return ``[]``.

        Because 'batches' always exists, indexing it does not raise; when
        ``epoch_slice`` selects no epochs (the expectation for these
        parameters), both the whole batch list and the last batch come back
        as an empty list.
        """
        hist = History()
        if has_epoch:
            hist.new_epoch()

        # h[a, b] is h[(a, b)], so each tuple below is one lookup.
        for lookup in ((epoch_slice, 'batches'), (epoch_slice, 'batches', -1)):
            assert hist[lookup] == []
Example #2
0
 def test_history_jagged_batches(self):
     """Epochs holding different numbers of batches stay indexable.

     The first epoch gets one batch and the second gets two; the second
     batch of the last epoch must be reachable and still empty.
     """
     hist = History()
     for batch_count in range(1, 3):
         hist.new_epoch()
         for _ in range(batch_count):
             hist.new_batch()
     # Index 1 only exists in the final (two-batch) epoch.
     second_batch_of_last_epoch = hist[-1, 'batches', 1]
     assert second_batch_of_last_epoch == {}
Example #3
0
    def test_history_no_epochs_key(self, has_epoch, epoch_slice):
        """A key missing from every selected epoch raises ``KeyError``.

        Checked once for a single key and once for a list of keys.
        """
        hist = History()
        if has_epoch:
            hist.new_epoch()

        for missing in ('foo', ['foo', 'bar']):
            with pytest.raises(KeyError):
                # pylint: disable=pointless-statement
                hist[epoch_slice, missing]
Example #4
0
    def test_history_key_in_other_epoch(self):
        """A batch key seen only in an earlier epoch is absent from the last one.

        The first epoch has an extra batch carrying 'valid_loss'. The second
        epoch only records 'train_loss' in a single batch. So asking the last
        batch of the last epoch for 'valid_loss' must raise ``KeyError``.
        """
        hist = History()
        for with_valid in (True, False):
            hist.new_epoch()
            hist.new_batch()
            hist.record_batch('train_loss', 1)
            if not with_valid:
                continue
            # Only the first epoch gets this additional validation batch.
            hist.new_batch()
            hist.record_batch('valid_loss', 2)

        with pytest.raises(KeyError):
            # pylint: disable=pointless-statement
            hist[-1, 'batches', -1, 'valid_loss']
Example #5
0
    def test_history_retrieve_empty_list(self, value, check_warn, recwarn):
        """Recorded values come back as the very same object.

        Covers both an epoch-level and a batch-level record. When
        ``check_warn`` is truthy, retrieval must also not emit any warning,
        such as the one about comparing to an empty ndarray. Presumably
        ``value`` is parametrized with empty containers/arrays; confirm
        against the parametrization.
        """
        hist = History()
        hist.new_epoch()
        hist.record('foo', value)
        hist.new_batch()
        hist.record_batch('batch_foo', value)

        # Identity, not equality: the stored object must be returned as-is.
        epoch_item = hist[-1, 'foo']
        assert epoch_item is value
        batch_item = hist[-1, 'batches', -1, 'batch_foo']
        assert batch_item is value

        # Captured warnings are only inspected when check_warn is set.
        assert not check_warn or not recwarn.list
Example #6
0
 def history(self):
     """Build a History pre-populated with epoch and batch records.

     There are ``self.test_epochs`` epochs, each with ``self.test_batches``
     batches.

     Per epoch:
     - 'duration' is 1.
     - 'total_loss' is the epoch index plus the batch count.
     - Only epoch index 2 also gets 'extra' = 42.

     Per batch:
     - 'loss' is the epoch index plus the batch index.
     - Even-indexed batches get 'extra_batch' = 23, except in the final epoch.
     """
     hist = History()
     n_epochs, n_batches = self.test_epochs, self.test_batches
     for epoch_idx in range(n_epochs):
         is_last_epoch = epoch_idx == n_epochs - 1
         hist.new_epoch()
         hist.record('duration', 1)
         hist.record('total_loss', epoch_idx + n_batches)
         if epoch_idx == 2:
             hist.record('extra', 42)
         for batch_idx in range(n_batches):
             hist.new_batch()
             hist.record_batch('loss', epoch_idx + batch_idx)
             if batch_idx % 2 == 0 and not is_last_epoch:
                 hist.record_batch('extra_batch', 23)
     return hist
Example #7
0
    def test_average_honors_weights(self, train_loss, history):
        """Check that the epoch train loss is a batch-size-weighted mean.

        The batches may have different batch sizes, which is why it is
        necessary to honor the batch sizes. Here we use different batch
        sizes to verify this. A plain mean of the two losses would be
        (10 + 40) / 2 = 25, whereas the weighted mean is
        (10 * 1 + 40 * 2) / (1 + 2) = 30.

        The ``history`` fixture argument is intentionally not used; a fresh
        ``History`` with hand-picked batch sizes is built instead.

        """
        from skorch.history import History

        # Distinct name so the requested ``history`` fixture is not silently
        # rebound to a different object inside the test.
        weighted_history = History()
        weighted_history.new_epoch()
        weighted_history.new_batch()
        weighted_history.record_batch('train_loss', 10)
        weighted_history.record_batch('train_batch_size', 1)
        weighted_history.new_batch()
        weighted_history.record_batch('train_loss', 40)
        weighted_history.record_batch('train_batch_size', 2)

        net = Mock(history=weighted_history)
        train_loss.on_epoch_end(net)

        assert weighted_history[0, 'train_loss'] == 30
Example #8
0
    def test_average_honors_weights(self, train_loss, history):
        """Check that the epoch train loss is a batch-size-weighted mean.

        The batches may have different batch sizes, which is why it is
        necessary to honor the batch sizes. Here we use different batch
        sizes to verify this. A plain mean of the two losses would be
        (10 + 40) / 2 = 25, whereas the weighted mean is
        (10 * 1 + 40 * 2) / (1 + 2) = 30.

        The ``history`` fixture argument is intentionally not used; a fresh
        ``History`` with hand-picked batch sizes is built instead.

        """
        from skorch.history import History

        # Distinct name so the requested ``history`` fixture is not silently
        # rebound to a different object inside the test.
        weighted_history = History()
        weighted_history.new_epoch()
        weighted_history.new_batch()
        weighted_history.record_batch('train_loss', 10)
        weighted_history.record_batch('train_batch_size', 1)
        weighted_history.new_batch()
        weighted_history.record_batch('train_loss', 40)
        weighted_history.record_batch('train_batch_size', 2)

        net = Mock(history=weighted_history)
        train_loss.on_epoch_end(net)

        assert weighted_history[0, 'train_loss'] == 30