예제 #1
0
    def test_last_step(self):
        """Train over a short final span and verify the artifacts left behind.

        Training runs from ``Step.from_epoch(2, 11, ...)`` up to
        ``Step.from_epoch(3, 0, ...)``. Afterwards the test expects:

        * a saved model at the end step whose weights equal the in-memory
          model's weights,
        * a logger file holding exactly one entry for each tracked metric,
        * a checkpoint file.
        """
        iterations_per_epoch = len(self.train_loader)

        train.train(
            self.hparams.training_hparams,
            self.model,
            self.train_loader,
            self.root,
            callbacks=self.callbacks,
            start_step=Step.from_epoch(2, 11, iterations_per_epoch),
            end_step=Step.from_epoch(3, 0, iterations_per_epoch),
        )

        # Snapshot the in-memory weights before anything is reloaded from disk.
        state_after_training = TestStandardCallbacks.get_state(self.model)

        # The final model must have been written out.
        final_model_path = paths.model(
            self.root, Step.from_epoch(3, 0, iterations_per_epoch))
        self.assertTrue(os.path.exists(final_model_path))

        # Reloading the saved file must reproduce the post-training weights.
        self.model.load_state_dict(torch.load(final_model_path))
        state_from_disk = TestStandardCallbacks.get_state(self.model)
        self.assertStateEqual(state_after_training, state_from_disk)

        # The logger must exist and contain a single entry per metric.
        self.assertTrue(os.path.exists(paths.logger(self.root)))
        metrics = MetricLogger.create_from_file(self.root)
        for metric_name in ('train_loss', 'test_loss',
                            'train_accuracy', 'test_accuracy'):
            self.assertEqual(len(metrics.get_data(metric_name)), 1)

        # A checkpoint file must be present as well.
        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))
예제 #2
0
    def test_file_operations(self):
        """Save a logger to disk, reload it, and check the copy matches.

        The reloaded logger must return the same ``train_accuracy`` and
        ``test_accuracy`` data as the original and have the same string form.
        """
        original = TestMetricLogger.create_logger()
        target_dir = os.path.join(self.root, 'temp_logger')

        # Round-trip through the filesystem.
        original.save(target_dir)
        reloaded = MetricLogger.create_from_file(target_dir)

        for metric_name in ('train_accuracy', 'test_accuracy'):
            self.assertEqual(original.get_data(metric_name),
                             reloaded.get_data(metric_name))
        self.assertEqual(str(original), str(reloaded))
예제 #3
0
    def test_end_to_end(self):
        """Train from the zero step to the end step and verify saved artifacts.

        Training runs from ``Step.from_epoch(0, 0, ...)`` up to
        ``Step.from_epoch(3, 0, ...)``. Afterwards the test expects:

        * saved models at both the initial and the final step, each matching
          the in-memory weights captured at that point,
        * a checkpoint file,
        * a logger file holding exactly four entries for each tracked metric.
        """
        iterations_per_epoch = len(self.train_loader)

        # Where the initial and final models are expected to be written.
        initial_model_path = paths.model(
            self.root, Step.zero(iterations_per_epoch))
        final_model_path = paths.model(
            self.root, Step.from_epoch(3, 0, iterations_per_epoch))

        state_before_training = TestStandardCallbacks.get_state(self.model)

        train.train(
            self.hparams.training_hparams,
            self.model,
            self.train_loader,
            self.root,
            callbacks=self.callbacks,
            start_step=Step.from_epoch(0, 0, iterations_per_epoch),
            end_step=Step.from_epoch(3, 0, iterations_per_epoch),
        )

        state_after_training = TestStandardCallbacks.get_state(self.model)

        # Both the initial and the final model must have been written out.
        self.assertTrue(os.path.exists(initial_model_path))
        self.assertTrue(os.path.exists(final_model_path))

        # The checkpoint file must still be present.
        self.assertTrue(os.path.exists(paths.checkpoint(self.root)))

        # Each saved file must reproduce the weights captured in memory.
        expectations = (
            (initial_model_path, state_before_training),
            (final_model_path, state_after_training),
        )
        for model_path, expected_state in expectations:
            self.model.load_state_dict(torch.load(model_path))
            state_from_disk = TestStandardCallbacks.get_state(self.model)
            self.assertStateEqual(expected_state, state_from_disk)

        # The logger must exist and contain four entries per metric.
        self.assertTrue(os.path.exists(paths.logger(self.root)))
        metrics = MetricLogger.create_from_file(self.root)
        for metric_name in ('train_loss', 'test_loss',
                            'train_accuracy', 'test_accuracy'):
            self.assertEqual(len(metrics.get_data(metric_name)), 4)