Ejemplo n.º 1
0
    def test_epoch_statistics(self):
        """Batches added to an epoch, singly or as a list, are all kept in order."""
        # Each batch uses the same value for its number, accuracy and loss, so
        # the batch at 1-based position i is expected to report i everywhere.
        batches = [tpm_ts.BatchStatistics(i, i, i) for i in range(1, 5)]

        epoch_stats = tpm_ts.EpochStatistics(1)
        # add_batch is exercised with a single batch object...
        epoch_stats.add_batch(batches[0])
        epoch_stats.add_batch(batches[1])
        # ...and with a list of batch objects.
        epoch_stats.add_batch([batches[2], batches[3]])

        batch_stats_list = list(epoch_stats.get_batch_stats())
        # Without this check an empty (or truncated) result would skip the
        # loop below and the test would pass vacuously.
        self.assertEqual(len(batch_stats_list), len(batches))
        for expected_val, batch_stats in enumerate(batch_stats_list, start=1):
            self.assertEqual(batch_stats.get_batch_num(), expected_val)
            self.assertEqual(batch_stats.get_batch_train_acc(), expected_val)
            self.assertEqual(batch_stats.get_batch_train_loss(), expected_val)
        self.assertEqual(epoch_stats.get_epoch_num(), 1)
Ejemplo n.º 2
0
    def test_batch_statistics(self):
        """Getters echo constructor values; out-of-range accuracy is rejected."""
        num, acc, loss = 1, 50, -4
        stats = tpm_ts.BatchStatistics(num, acc, loss)

        self.assertEqual(stats.get_batch_num(), num)
        self.assertEqual(stats.get_batch_train_acc(), acc)
        self.assertEqual(stats.get_batch_train_loss(), loss)

        # Accuracy values above and below the accepted range must raise.
        for invalid_acc in (150, -50):
            self.assertRaises(ValueError, stats.set_batch_train_acc,
                              invalid_acc)
Ejemplo n.º 3
0
    def test_training_statistics(self):
        """Check final-metric setters, their validation, and epoch/batch retention."""
        # Six batches, each using the same value for its number, accuracy and
        # loss, distributed two per epoch over three epochs.
        batches = [tpm_ts.BatchStatistics(i, i, i) for i in range(1, 7)]

        epoch1_stats = tpm_ts.EpochStatistics(1)
        epoch2_stats = tpm_ts.EpochStatistics(2)
        epoch3_stats = tpm_ts.EpochStatistics(3)
        epoch1_stats.add_batch([batches[0], batches[1]])
        epoch2_stats.add_batch([batches[2], batches[3]])
        epoch3_stats.add_batch([batches[4], batches[5]])
        all_epochs = [epoch1_stats, epoch2_stats, epoch3_stats]

        training_stats = tpm_ts.TrainingRunStatistics()
        # add_epoch is exercised with both a single epoch and a list of epochs.
        training_stats.add_epoch(epoch1_stats)
        training_stats.add_epoch([epoch2_stats, epoch3_stats])
        training_stats.set_final_train_acc(1)
        training_stats.set_final_train_loss(1)
        training_stats.set_final_val_combined_acc(1)
        training_stats.set_final_val_combined_loss(1)
        training_stats.set_final_clean_data_test_acc(1)
        training_stats.set_final_triggered_data_test_acc(1)

        # NOTE: the summary keys say "combined_val" while the setters say
        # "val_combined"; the assertions below pin the current key names.
        summary_dict = training_stats.get_summary()
        self.assertEqual(summary_dict['final_train_acc'], 1)
        self.assertEqual(summary_dict['final_train_loss'], 1)
        self.assertEqual(summary_dict['final_combined_val_acc'], 1)
        self.assertEqual(summary_dict['final_combined_val_loss'], 1)
        self.assertEqual(summary_dict['final_clean_data_test_acc'], 1)
        self.assertEqual(summary_dict['final_triggered_data_test_acc'], 1)

        # Every accuracy setter must reject values above and below the
        # accepted range (checked in the same order as before).
        accuracy_setters = (
            training_stats.set_final_train_acc,
            training_stats.set_final_val_combined_acc,
            training_stats.set_final_clean_data_test_acc,
            training_stats.set_final_triggered_data_test_acc,
        )
        for setter in accuracy_setters:
            for invalid_acc in (150, -50):
                self.assertRaises(ValueError, setter, invalid_acc)

        # Ensure data is maintained over epochs.
        epoch_stats = list(training_stats.get_epochs_stats())
        # Guard against a vacuous pass if no epochs were retained.
        self.assertEqual(len(epoch_stats), len(all_epochs))
        expected_val = 1
        for expected_epoch_num, epoch in enumerate(epoch_stats, start=1):
            self.assertEqual(epoch.get_epoch_num(), expected_epoch_num)
            for batch_stats in epoch.get_batch_stats():
                self.assertEqual(batch_stats.get_batch_num(), expected_val)
                self.assertEqual(batch_stats.get_batch_train_acc(),
                                 expected_val)
                self.assertEqual(batch_stats.get_batch_train_loss(),
                                 expected_val)
                expected_val += 1
        # Every batch must have been visited; epochs that silently lost their
        # batches would otherwise go unnoticed.
        self.assertEqual(expected_val - 1, len(batches))