def test_history_no_epochs_batches(self, has_epoch, epoch_slice):
    """Slicing 'batches' over zero epochs yields an empty list.

    ``epoch_slice`` is expected to select no epochs (NOTE(review):
    presumably guaranteed by the parametrization, which is not visible
    here). Because 'batches' always exists, the lookup returns ``[]``
    instead of raising KeyError.

    """
    hist = History()
    if has_epoch:
        hist.new_epoch()

    assert hist[epoch_slice, 'batches'] == []
    assert hist[epoch_slice, 'batches', -1] == []
def test_history_jagged_batches(self):
    """Epochs holding different numbers of batches can still be indexed.

    """
    hist = History()
    # First epoch receives one batch, second epoch receives two.
    for batch_count in (1, 2):
        hist.new_epoch()
        for _ in range(batch_count):
            hist.new_batch()

    # Batch index 1 exists only in the last epoch; it must be reachable
    # and, having recorded nothing, be an empty dict.
    assert hist[-1, 'batches', 1] == {}
def test_history_no_epochs_key(self, has_epoch, epoch_slice):
    """Looking up a key that no selected epoch recorded raises KeyError.

    Both a single missing key and a list of missing keys are checked.

    """
    hist = History()
    if has_epoch:
        hist.new_epoch()

    for missing in ('foo', ['foo', 'bar']):
        with pytest.raises(KeyError):
            # pylint: disable=pointless-statement
            hist[epoch_slice, missing]
def test_history_key_in_other_epoch(self):
    """A batch key recorded only in an earlier epoch is absent in the last.

    The first epoch gets an extra batch carrying 'valid_loss'; the second
    epoch does not, so looking it up in the last epoch's last batch must
    raise KeyError.

    """
    hist = History()
    for with_valid in (True, False):
        hist.new_epoch()
        hist.new_batch()
        hist.record_batch('train_loss', 1)
        if not with_valid:
            continue
        hist.new_batch()
        hist.record_batch('valid_loss', 2)

    with pytest.raises(KeyError):
        # pylint: disable=pointless-statement
        hist[-1, 'batches', -1, 'valid_loss']
def test_history_retrieve_empty_list(self, value, check_warn, recwarn):
    """Recorded values come back as the very same object, warning-free.

    ``value`` is stored both at epoch level and at batch level, then read
    back; identity (``is``) is required, not just equality. When
    ``check_warn`` is set, no warning may have been emitted.

    """
    hist = History()
    hist.new_epoch()
    hist.record('foo', value)
    hist.new_batch()
    hist.record_batch('batch_foo', value)

    epoch_value = hist[-1, 'foo']
    batch_value = hist[-1, 'batches', -1, 'batch_foo']
    assert epoch_value is value
    assert batch_value is value

    # There should be no warning about comparison to an empty ndarray.
    if check_warn:
        assert not recwarn.list
def history(self):
    """Return a History populated with epoch-level and batch-level data.

    Every epoch records 'duration' (always 1) and 'total_loss'
    (epoch index + ``self.test_batches``); the epoch with index 2 also
    records 'extra'. Every batch records 'loss' (epoch index + batch
    index); even-indexed batches additionally record 'extra_batch',
    except in the final epoch.

    """
    hist = History()
    last_epoch = self.test_epochs - 1
    for epoch_idx in range(self.test_epochs):
        hist.new_epoch()
        hist.record('duration', 1)
        hist.record('total_loss', epoch_idx + self.test_batches)
        if epoch_idx == 2:
            hist.record('extra', 42)

        for batch_idx in range(self.test_batches):
            hist.new_batch()
            hist.record_batch('loss', epoch_idx + batch_idx)
            # 'extra_batch' is only present on even batches outside the
            # final epoch, so it is missing from some batches.
            if batch_idx % 2 == 0 and epoch_idx != last_epoch:
                hist.record_batch('extra_batch', 23)
    return hist
def test_average_honors_weights(self, train_loss, history):
    """The batches may have different batch sizes, which is why it is
    necessary to honor the batch sizes. Here we use different batch
    sizes to verify this.

    The requested ``history`` fixture is not used; a fresh History with
    hand-picked batch sizes is built instead, under a separate name so
    the fixture argument is not silently rebound.

    """
    from skorch.history import History

    hist = History()
    hist.new_epoch()

    # Batch 1: loss 10 with batch size 1.
    hist.new_batch()
    hist.record_batch('train_loss', 10)
    hist.record_batch('train_batch_size', 1)

    # Batch 2: loss 40 with batch size 2.
    hist.new_batch()
    hist.record_batch('train_loss', 40)
    hist.record_batch('train_batch_size', 2)

    net = Mock(history=hist)
    train_loss.on_epoch_end(net)

    # Weighted mean: (10 * 1 + 40 * 2) / (1 + 2) == 30, whereas the
    # unweighted mean would be 25.
    assert hist[0, 'train_loss'] == 30