    def test_negative_batch_index(self):
        epoch_idx = 3
        batch_idx = -13
        batches_per_epoch = 20

        with self.assertRaises(ValueError):
            training_utils.epoch_cursor(epoch_idx=epoch_idx,
                                        batch_idx=batch_idx,
                                        batches_per_epoch=batches_per_epoch)

    def test_batch_index_exceeds_batches_per_epoch(self):
        epoch_idx = 3
        batch_idx = 130
        batches_per_epoch = 20

        with self.assertRaises(ValueError):
            training_utils.epoch_cursor(epoch_idx=epoch_idx,
                                        batch_idx=batch_idx,
                                        batches_per_epoch=batches_per_epoch)

    def test_invalid_number_of_batches_per_epoch(self):
        epoch_idx = 3
        batch_idx = 13
        batches_per_epoch = -20

        with self.assertRaises(ValueError):
            training_utils.epoch_cursor(epoch_idx=epoch_idx,
                                        batch_idx=batch_idx,
                                        batches_per_epoch=batches_per_epoch)

    def test_batches_per_epoch_larger_than_multiplier(self,
                                                      cursor_multiplier_mock):
        cursor_multiplier_mock.EPOCH_CURSOR_MULTIPLIER = 1000
        epoch_idx = 3
        batch_idx = 13
        batches_per_epoch = 2000

        with self.assertRaises(ValueError):
            training_utils.epoch_cursor(epoch_idx=epoch_idx,
                                        batch_idx=batch_idx,
                                        batches_per_epoch=batches_per_epoch)
    def test_middle_batch_success(self, cursor_multiplier_mock):
        cursor_multiplier_mock.EPOCH_CURSOR_MULTIPLIER = 1000
        epoch_idx = 3
        batch_idx = 13
        batches_per_epoch = 20
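        # 3700 matches epoch_idx * EPOCH_CURSOR_MULTIPLIER
        # + (batch_idx + 1) * (EPOCH_CURSOR_MULTIPLIER // batches_per_epoch),
        # i.e. 3 * 1000 + 14 * 50; this formula is inferred from the test
        # values and may differ from the actual implementation.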
        expected_cursor = 3700

        cursor = training_utils.epoch_cursor(
            epoch_idx=epoch_idx,
            batch_idx=batch_idx,
            batches_per_epoch=batches_per_epoch)

        self.assertEqual(cursor, expected_cursor)
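
    # Taken together, the tests above pin down the validation behaviour that
    # epoch_cursor is expected to have: a ValueError for a negative batch_idx,
    # for a batch_idx that is not smaller than batches_per_epoch, for a
    # non-positive batches_per_epoch, and for a batches_per_epoch larger than
    # EPOCH_CURSOR_MULTIPLIER (the last case relies on the mocked multiplier
    # value).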
    def _run_epoch(self, epoch_idx):
        pbar_desc = '{} Training epoch {}/{}'.format(
            PROGRESS_BAR_PREFIX,
            epoch_idx + 1,
            config.ExperimentConfig.NUM_EPOCHS)
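        # Wrap the epoch's batch iterator in a tqdm progress bar so training
        # progress is visible on the console (unless explicitly disabled).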
        pbar = tqdm(self._training_engine.train_epoch(),
                    total=self._training_engine.batches_per_epoch,
                    desc=pbar_desc,
                    disable=DISABLE_PROGRESS_BAR)
        for batch_idx, loss, train_summary in pbar:
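            # Convert the (epoch, batch) position into a single cursor value
            # used as the step for the logged summary.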
            cursor = train_utils.epoch_cursor(
                epoch_idx=epoch_idx,
                batch_idx=batch_idx,
                batches_per_epoch=self._training_engine.batches_per_epoch)
            self._logging_engine.log_summary(summary=train_summary,
                                             epoch_cursor=cursor)

        # Save the checkpoint on disk.
        self._training_engine.save()

        # Update the training status.
        self._training_status[_CNST.LATEST_TRAINED_KEY] = epoch_idx
        self._save_training_status()