def test_negative_batch_index(self):
    """Check that epoch_cursor raises ValueError for a negative batch_idx."""
    cursor_kwargs = {
        'epoch_idx': 3,
        'batch_idx': -13,
        'batches_per_epoch': 20,
    }
    with self.assertRaises(ValueError):
        training_utils.epoch_cursor(**cursor_kwargs)
def test_batch_index_exceeds_batches_per_epoch(self):
    """Check that epoch_cursor raises ValueError when batch_idx (130) is
    larger than batches_per_epoch (20)."""
    cursor_kwargs = {
        'epoch_idx': 3,
        'batch_idx': 130,
        'batches_per_epoch': 20,
    }
    with self.assertRaises(ValueError):
        training_utils.epoch_cursor(**cursor_kwargs)
def test_invalid_number_of_batches_per_epoch(self):
    """Check that epoch_cursor raises ValueError for a negative
    batches_per_epoch."""
    cursor_kwargs = {
        'epoch_idx': 3,
        'batch_idx': 13,
        'batches_per_epoch': -20,
    }
    with self.assertRaises(ValueError):
        training_utils.epoch_cursor(**cursor_kwargs)
def test_batches_per_epoch_larger_than_multiplier(self, cursor_multiplier_mock):
    """Check that epoch_cursor raises ValueError when batches_per_epoch (2000)
    is larger than the EPOCH_CURSOR_MULTIPLIER set on the mock (1000).

    NOTE(review): cursor_multiplier_mock is presumably injected by a patch
    decorator that is not visible here -- confirm against the full file.
    """
    # The multiplier must be configured before epoch_cursor is invoked.
    cursor_multiplier_mock.EPOCH_CURSOR_MULTIPLIER = 1000
    cursor_kwargs = {
        'epoch_idx': 3,
        'batch_idx': 13,
        'batches_per_epoch': 2000,
    }
    with self.assertRaises(ValueError):
        training_utils.epoch_cursor(**cursor_kwargs)
def test_middle_batch_success(self, cursor_multiplier_mock):
    """Check the cursor value returned for a batch in the middle of an epoch.

    With EPOCH_CURSOR_MULTIPLIER set to 1000 on the mock, epoch 3 and batch 13
    out of 20 are expected to yield a cursor of 3700.

    NOTE(review): cursor_multiplier_mock is presumably injected by a patch
    decorator that is not visible here -- confirm against the full file.
    """
    # The multiplier must be configured before epoch_cursor is invoked.
    cursor_multiplier_mock.EPOCH_CURSOR_MULTIPLIER = 1000
    cursor_kwargs = {
        'epoch_idx': 3,
        'batch_idx': 13,
        'batches_per_epoch': 20,
    }
    actual_cursor = training_utils.epoch_cursor(**cursor_kwargs)
    self.assertEqual(actual_cursor, 3700)
def _run_epoch(self, epoch_idx):
    """Run one training epoch, log each batch summary, then checkpoint.

    Iterates over ``self._training_engine.train_epoch()`` under a progress
    bar. For every yielded ``(batch_idx, loss, train_summary)`` triple, the
    batch position is converted to an epoch cursor and the summary is passed
    to the logging engine. After the loop, the training engine is saved and
    the training status is updated and saved.

    Args:
        epoch_idx: Index of the epoch to run. It is shown as
            ``epoch_idx + 1`` in the progress bar description and stored
            as the latest trained epoch in the training status.
    """
    pbar_desc = '{} Training epoch {}/{}'.format(
        PROGRESS_BAR_PREFIX, epoch_idx + 1, config.ExperimentConfig.NUM_EPOCHS)

    # Use the progress bar as a context manager so that it is closed even if
    # cursor computation or logging raises in the middle of the epoch.
    with tqdm(self._training_engine.train_epoch(),
              total=self._training_engine.batches_per_epoch,
              desc=pbar_desc,
              disable=DISABLE_PROGRESS_BAR) as pbar:
        # The per-batch loss is yielded by the engine but not used here.
        for batch_idx, _loss, train_summary in pbar:
            cursor = train_utils.epoch_cursor(
                epoch_idx=epoch_idx,
                batch_idx=batch_idx,
                batches_per_epoch=self._training_engine.batches_per_epoch)
            self._logging_engine.log_summary(summary=train_summary,
                                             epoch_cursor=cursor)

    # Save the checkpoint on disk.
    self._training_engine.save()

    # Update the training status.
    self._training_status[_CNST.LATEST_TRAINED_KEY] = epoch_idx
    self._save_training_status()