コード例 #1
0
    def test_history_no_epochs_batches(self, has_epoch, epoch_slice):
        """Selecting 'batches' from a history that recorded no batches.

        Optionally starts a single epoch, then checks that indexing
        'batches' over ``epoch_slice``, both without and with a trailing
        batch index of -1, compares equal to an empty list.
        """
        history = History()
        if has_epoch:
            history.new_epoch()

        # 'batches' always exists, so each lookup yields a list of zero
        # epochs instead of raising; checked in the original order.
        lookups = [(epoch_slice, 'batches'), (epoch_slice, 'batches', -1)]
        for key in lookups:
            assert history[key] == []
コード例 #2
0
 def test_history_jagged_batches(self):
     """Epochs may contain different numbers of batches.

     The first epoch receives one batch and the second receives two; the
     second batch of the last epoch must be reachable and still empty.
     """
     history = History()
     for epoch_idx in range(2):
         history.new_epoch()
         # epoch 0 gets a single batch, epoch 1 gets two
         for _ in range(epoch_idx + 1):
             history.new_batch()

     second_batch_of_last_epoch = history[-1, 'batches', 1]
     assert second_batch_of_last_epoch == {}
コード例 #3
0
    def test_history_no_epochs_key(self, has_epoch, epoch_slice):
        """Looking up keys that were never recorded raises KeyError.

        Both a single unknown key and a list of unknown keys are tried
        over ``epoch_slice``, with or without one (empty) epoch present.
        """
        history = History()
        if has_epoch:
            history.new_epoch()

        for missing in ('foo', ['foo', 'bar']):
            # the key(s) are absent from every epoch, so lookup must fail
            with pytest.raises(KeyError):
                # pylint: disable=pointless-statement
                history[epoch_slice, missing]
コード例 #4
0
    def test_history_key_in_other_epoch(self):
        """A batch key recorded only in an earlier epoch is not found later.

        Epoch 0 records a 'train_loss' batch followed by a 'valid_loss'
        batch; epoch 1 records only a 'train_loss' batch. Reading
        'valid_loss' from the last batch of the last epoch must raise
        KeyError.
        """
        history = History()

        # epoch 0: a training batch, then a validation batch
        history.new_epoch()
        history.new_batch()
        history.record_batch('train_loss', 1)
        history.new_batch()
        history.record_batch('valid_loss', 2)

        # epoch 1: a training batch only
        history.new_epoch()
        history.new_batch()
        history.record_batch('train_loss', 1)

        with pytest.raises(KeyError):
            # pylint: disable=pointless-statement
            history[-1, 'batches', -1, 'valid_loss']
コード例 #5
0
    def test_history_retrieve_empty_list(self, value, check_warn, recwarn):
        """Recorded values are returned by identity, without warnings.

        ``value`` is stored under 'foo' at epoch level and under
        'batch_foo' at batch level; both reads must return the very same
        object. When ``check_warn`` is true, ``recwarn`` must have captured
        no warnings.
        """
        history = History()
        history.new_epoch()
        history.record('foo', value)
        history.new_batch()
        history.record_batch('batch_foo', value)

        epoch_value = history[-1, 'foo']
        assert epoch_value is value
        batch_value = history[-1, 'batches', -1, 'batch_foo']
        assert batch_value is value

        if not check_warn:
            return
        # retrieval must not warn about comparison to an empty ndarray
        assert not recwarn.list
コード例 #6
0
ファイル: test_history.py プロジェクト: cheungsingyi/skorch
 def history(self):
     """Return a history filled with epoch and batch data.

     Creates ``self.test_epochs`` epochs. Each epoch records 'duration'
     (always 1) and 'total_loss' (epoch index + ``self.test_batches``);
     the epoch with index 2 also records 'extra' = 42. Each epoch holds
     ``self.test_batches`` batches whose 'loss' is epoch index + batch
     index; batches with an even index also record 'extra_batch' = 23,
     except in the final epoch.
     """
     hist = History()
     final_epoch = self.test_epochs - 1
     for epoch_idx in range(self.test_epochs):
         hist.new_epoch()
         hist.record('duration', 1)
         hist.record('total_loss', epoch_idx + self.test_batches)
         if epoch_idx == 2:
             hist.record('extra', 42)

         in_final_epoch = epoch_idx == final_epoch
         for batch_idx in range(self.test_batches):
             hist.new_batch()
             hist.record_batch('loss', epoch_idx + batch_idx)
             if batch_idx % 2 == 0 and not in_final_epoch:
                 hist.record_batch('extra_batch', 23)
     return hist
コード例 #7
0
ファイル: test_scoring.py プロジェクト: magnumw/skorch
    def test_average_honors_weights(self, train_loss, history):
        """The batches may have different batch sizes, which is why it is
        necessary to honor the batch sizes. Here we use different
        batch sizes to verify this.

        A 'train_loss' of 10 over a batch of size 1 and a 'train_loss' of
        40 over a batch of size 2 must average to
        (10 * 1 + 40 * 2) / (1 + 2) == 30, not the unweighted mean 25.

        """
        from skorch.history import History

        # Build a dedicated history rather than rebinding the requested
        # ``history`` fixture argument, which would silently shadow it.
        weighted_history = History()
        weighted_history.new_epoch()
        weighted_history.new_batch()
        weighted_history.record_batch('train_loss', 10)
        weighted_history.record_batch('train_batch_size', 1)
        weighted_history.new_batch()
        weighted_history.record_batch('train_loss', 40)
        weighted_history.record_batch('train_batch_size', 2)

        net = Mock(history=weighted_history)
        train_loss.on_epoch_end(net)

        assert weighted_history[0, 'train_loss'] == 30
コード例 #8
0
ファイル: test_scoring.py プロジェクト: YangHaha11514/skorch
    def test_average_honors_weights(self, train_loss, history):
        """The batches may have different batch sizes, which is why it is
        necessary to honor the batch sizes. Here we use different
        batch sizes to verify this.

        A 'train_loss' of 10 over a batch of size 1 and a 'train_loss' of
        40 over a batch of size 2 must average to
        (10 * 1 + 40 * 2) / (1 + 2) == 30, not the unweighted mean 25.

        """
        from skorch.history import History

        # Build a dedicated history rather than rebinding the requested
        # ``history`` fixture argument, which would silently shadow it.
        weighted_history = History()
        weighted_history.new_epoch()
        weighted_history.new_batch()
        weighted_history.record_batch('train_loss', 10)
        weighted_history.record_batch('train_batch_size', 1)
        weighted_history.new_batch()
        weighted_history.record_batch('train_loss', 40)
        weighted_history.record_batch('train_batch_size', 2)

        net = Mock(history=weighted_history)
        train_loss.on_epoch_end(net)

        assert weighted_history[0, 'train_loss'] == 30