def test_sorting(self):
    """Entries added out of iteration order must come back sorted by step."""
    log = TestMetricLogger.create_logger()  # pre-seeded with steps 0 and 1
    # Deliberately insert steps 5, 3, 4 in scrambled order.
    for iteration, accuracy in [(5, 0.9), (3, 0.7), (4, 0.8)]:
        log.add('train_accuracy', Step.from_iteration(iteration, 400), accuracy)
    expected = [(0, 0.5), (1, 0.6), (3, 0.7), (4, 0.8), (5, 0.9)]
    self.assertEqual(log.get_data('train_accuracy'), expected)
def branch_function(self, retrain_d: hparams.DatasetHparams, retrain_t: hparams.TrainingHparams, start_at_step_zero: bool = False, transfer_learn: bool = False):
    """Retrain the pruned model on (possibly different) data and hyperparameters.

    Args:
        retrain_d: dataset hyperparameters for the retraining run.
        retrain_t: training hyperparameters for the retraining run.
        start_at_step_zero: if True, begin the step counter from iteration 0
            instead of the lottery run's rewind iteration.
        transfer_learn: if True, seed weights from the end-of-training
            checkpoint rather than the pre-training (rewind) checkpoint.
    """
    # Pick which checkpoint supplies the initial weights.
    source_step = (self.lottery_desc.train_end_step if transfer_learn
                   else self.lottery_desc.train_start_step)
    base_model = models.registry.load(self.level_root, source_step,
                                      self.lottery_desc.model_hparams)
    pruned = PrunedModel(base_model, Mask.load(self.level_root))

    # Convert the chosen starting iteration into a Step for this dataset.
    start_iteration = 0 if start_at_step_zero else self.lottery_desc.train_start_step.iteration
    start_step = Step.from_iteration(
        start_iteration, datasets.registry.iterations_per_epoch(retrain_d))

    train.standard_train(pruned, self.branch_root, retrain_d, retrain_t,
                         start_step=start_step, verbose=self.verbose)
def test_train_zero_steps(self):
    """Training to an end step of 0 must leave the model and counters untouched."""
    snapshot = TestTrain.get_state(self.model)
    end = Step.from_iteration(0, len(self.train_loader))
    train.train(self.hparams.training_hparams, self.model, self.train_loader,
                self.root, callbacks=[self.callback], end_step=end)
    current = TestTrain.get_state(self.model)
    # Every tensor in the state must be bitwise unchanged.
    for name in snapshot:
        self.assertTrue(np.array_equal(snapshot[name], current[name]))
    self.assertEqual(self.step_counter, 0)
    self.assertEqual(self.ep, 0)
    self.assertEqual(self.it, 0)
def test_train_two_steps(self):
    """Two optimization steps should change every parameter and advance counters."""
    snapshot = TestTrain.get_state(self.model)
    end = Step.from_iteration(2, len(self.train_loader))
    train.train(self.hparams.training_hparams, self.model, self.train_loader,
                self.root, callbacks=[self.callback], end_step=end)
    current = TestTrain.get_state(self.model)
    # All parameters must have moved after two gradient steps.
    for name in snapshot:
        with self.subTest(k=name):
            self.assertFalse(np.array_equal(snapshot[name], current[name]), name)
    # NOTE(review): 3 callback invocations for 2 steps — presumably the
    # callback also fires once at step 0 before training; confirm in train().
    self.assertEqual(self.step_counter, 3)
    self.assertEqual(self.ep, 0)
    self.assertEqual(self.it, 2)
    self.assertEqual(self.lr, 0.02)
def test_save_load_exists(self):
    """Round-trip a model through save / exists / load in the registry."""
    hp = registry.get_default_hparams('cifar_resnet_20')
    model = registry.get(hp.model_hparams)
    step = Step.from_iteration(27, 17)
    location = paths.model(self.root, step)
    original_state = TestSaveLoadExists.get_state(model)

    # Nothing should be on disk or registered before the first save.
    self.assertFalse(registry.exists(self.root, step))
    self.assertFalse(os.path.exists(location))

    # Saving must both create the checkpoint file and register it.
    model.save(self.root, step)
    self.assertTrue(registry.exists(self.root, step))
    self.assertTrue(os.path.exists(location))

    # Loading the raw state dict into a fresh model restores the weights.
    model = registry.get(hp.model_hparams)
    model.load_state_dict(torch.load(location))
    self.assertStateEqual(original_state, TestSaveLoadExists.get_state(model))

    # registry.load must be equivalent to the manual round trip above.
    model = registry.load(self.root, step, hp.model_hparams)
    self.assertStateEqual(original_state, TestSaveLoadExists.get_state(model))
def test_from_iteration(self):
    """from_iteration splits a flat iteration count into (epoch, iteration)."""
    # (iteration, iterations_per_epoch, expected_total, expected_ep, expected_it)
    cases = [
        (0, 1, 0, 0, 0),
        (0, 100, 0, 0, 0),
        (10, 100, 10, 0, 10),
        (110, 100, 110, 1, 10),
        (11010, 100, 11010, 110, 10),
    ]
    for iteration, per_epoch, total, ep, it in cases:
        self.assertStepEquals(Step.from_iteration(iteration, per_epoch), total, ep, it)
def test_overwrite(self):
    """Adding a value at an already-logged step replaces the old value."""
    log = TestMetricLogger.create_logger()  # seeds step 0 with 0.5, step 1 with 0.6
    log.add('train_accuracy', Step.from_iteration(0, 400), 1.0)
    # Step 0 now carries the new value; step 1 is untouched.
    self.assertEqual(log.get_data('train_accuracy'), [(0, 1.0), (1, 0.6)])
def create_logger():
    """Build a MetricLogger pre-populated with a few accuracy entries."""
    log = MetricLogger()
    seed = [
        ('train_accuracy', 0, 0.5),
        ('train_accuracy', 1, 0.6),
        ('test_accuracy', 0, 0.4),
    ]
    for metric, iteration, value in seed:
        log.add(metric, Step.from_iteration(iteration, 400), value)
    return log