def testValidate(self):
    """Check that validation runs only when explicitly requested.

    Both operator hooks are replaced by mocks, so the runner's calls land on
    them instead of the real implementations. Three training epochs must not
    call ``validate``; an explicit ``runner.validate()`` must call it. The
    runner's epoch counter must reflect the three training epochs.
    """

    class MockOperator(TrainingOperator):
        def setup(self, config):
            # ``return_value`` is the MagicMock keyword that sets the result
            # of calling the mock. The previous ``returns=`` keyword is not
            # recognised: it only set an unused ``.returns`` attribute, and
            # the mocks returned bare MagicMock objects instead of the dict.
            self.train_epoch = MagicMock(return_value=dict(mean_accuracy=10))
            self.validate = MagicMock(return_value=dict(mean_accuracy=10))

    runner = TorchRunner(
        model_creator,
        create_dataloaders,
        optimizer_creator,
        loss_creator,
        training_operator_cls=MockOperator,
    )
    runner.setup()
    runner.train_epoch()
    runner.train_epoch()
    runner.train_epoch()
    # Training must go through the mocked hook (once per epoch) and must not
    # trigger validation as a side effect.
    self.assertEqual(runner.training_operator.train_epoch.call_count, 3)
    self.assertEqual(runner.training_operator.validate.call_count, 0)
    runner.validate()
    self.assertTrue(runner.training_operator.validate.called)
    self.assertEqual(runner.stats()["epoch"], 3)
def testtrain_epoch(self):
    """Check that each runner epoch invokes the operator's train_epoch once.

    The operator's ``train_epoch`` is overridden to count its invocations and
    report the running total under the "count" key. The test then checks that
    count, the value returned by the last runner epoch, and the runner's own
    epoch counter.
    """

    class CountingOperator(TrainingOperator):
        def setup(self, config):
            self.count = 0

        def train_epoch(self, *args, **kwargs):
            self.count += 1
            return {"count": self.count}

    runner = TorchRunner(
        model_creator,
        create_dataloaders,
        optimizer_creator,
        loss_creator,
        training_operator_cls=CountingOperator,
    )
    runner.setup()

    # Two epochs with an explicit step bound, then one with default arguments.
    for _ in range(2):
        runner.train_epoch(num_steps=1)
    final_result = runner.train_epoch()

    trained_op = runner.training_operator
    self.assertEqual(trained_op.count, 3)
    self.assertEqual(final_result["count"], 3)
    self.assertEqual(runner.stats()["epoch"], 3)