def testNativeLoss(self):
    """Train one epoch with a torch loss class (nn.MSELoss) as loss_creator."""
    # NOTE(review): another `testNativeLoss` (TrainingOperator.from_creators
    # variant) is defined in this file; if both live in one class, only the
    # last definition is collected by unittest — confirm which is intended.
    torch_runner = TorchRunner(
        model_creator,
        single_loader,
        optimizer_creator,
        loss_creator=nn.MSELoss,
    )
    torch_runner.setup()
    torch_runner.train_epoch()
def testSingleLoader(self):
    """With `single_loader`, training works but validate() raises ValueError."""
    # NOTE(review): another `testSingleLoader` (TrainingOperator.from_creators
    # variant) is defined in this file; if both live in one class, only the
    # last definition is collected by unittest — confirm which is intended.
    single_loader_runner = TorchRunner(
        model_creator, single_loader, optimizer_creator, loss_creator
    )
    single_loader_runner.setup()
    single_loader_runner.train_epoch()

    # Presumably single_loader supplies no validation loader — verify.
    with self.assertRaises(ValueError):
        single_loader_runner.validate()
def testNativeLoss(self):
    """An operator built by from_creators trains with nn.MSELoss as its loss."""
    native_operator_cls = TrainingOperator.from_creators(
        model_creator,
        optimizer_creator,
        single_loader,
        loss_creator=nn.MSELoss,
    )
    torch_runner = TorchRunner(training_operator_cls=native_operator_cls)
    torch_runner.setup_operator()
    torch_runner.train_epoch()
def testSingleLoader(self):
    """A from_creators operator using `single_loader` trains but cannot validate."""
    single_operator_cls = TrainingOperator.from_creators(
        model_creator,
        optimizer_creator,
        single_loader,
        loss_creator=loss_creator,
    )
    torch_runner = TorchRunner(training_operator_cls=single_operator_cls)
    torch_runner.setup_operator()
    torch_runner.train_epoch()

    # Presumably single_loader supplies no validation loader — verify.
    with self.assertRaises(ValueError):
        torch_runner.validate()
def testValidate(self):
    """The operator's validate() runs only when runner.validate() is called.

    Also checks that the runner's epoch counter reaches 3 after three
    train_epoch() calls and that the operator's train stats reach the result.
    """

    class MockOperator(self.Operator):
        def setup(self, config):
            super(MockOperator, self).setup(config)
            # `return_value` is the MagicMock keyword that sets what a call
            # returns. The previous `returns=` only created an unused
            # attribute, so these mocks returned bare MagicMock objects.
            self.train_epoch = MagicMock(return_value=dict(mean_accuracy=10))
            self.validate = MagicMock(return_value=dict(mean_accuracy=10))

    runner = TorchRunner(training_operator_cls=MockOperator)
    runner.setup_operator()
    runner.train_epoch()
    runner.train_epoch()
    result = runner.train_epoch()

    # Training alone must not trigger validation.
    self.assertEqual(runner.training_operator.validate.call_count, 0)
    runner.validate()
    self.assertTrue(runner.training_operator.validate.called)

    self.assertEqual(result["epoch"], 3)
    # Stats returned by the operator's train_epoch are surfaced in the result.
    self.assertEqual(result["mean_accuracy"], 10)
def testtrain_epoch(self):
    """Each runner.train_epoch() call delegates to the operator exactly once."""

    class MockOperator(self.Operator):
        def setup(self, config):
            super(MockOperator, self).setup(config)
            self.count = 0

        def train_epoch(self, *args, **kwargs):
            # Count invocations and report the running total as the stats.
            self.count = self.count + 1
            return {"count": self.count}

    runner = TorchRunner(training_operator_cls=MockOperator)
    runner.setup_operator()
    for _ in range(2):
        runner.train_epoch(num_steps=1)
    result = runner.train_epoch()

    operator = runner.training_operator
    self.assertEqual(operator.count, 3)
    self.assertEqual(result["count"], 3)
    self.assertEqual(result["epoch"], 3)
def testValidate(self):
    """The operator's validate() runs only when runner.validate() is called.

    Also checks that the runner's epoch counter reaches 3 after three
    train_epoch() calls and that the operator's train stats reach the result.
    """
    # NOTE(review): another `testValidate` (self.Operator variant) is defined
    # in this file; if both live in one class, only the last definition is
    # collected by unittest — confirm which is intended.

    class MockOperator(TrainingOperator):
        def setup(self, config):
            # `return_value` is the MagicMock keyword that sets what a call
            # returns. The previous `returns=` only created an unused
            # attribute, so these mocks returned bare MagicMock objects.
            self.train_epoch = MagicMock(return_value=dict(mean_accuracy=10))
            self.validate = MagicMock(return_value=dict(mean_accuracy=10))

    runner = TorchRunner(
        model_creator,
        create_dataloaders,
        optimizer_creator,
        loss_creator,
        training_operator_cls=MockOperator,
    )
    runner.setup()
    runner.train_epoch()
    runner.train_epoch()
    result = runner.train_epoch()

    # Training alone must not trigger validation.
    self.assertEqual(runner.training_operator.validate.call_count, 0)
    runner.validate()
    self.assertTrue(runner.training_operator.validate.called)

    self.assertEqual(result["epoch"], 3)
    # Stats returned by the operator's train_epoch are surfaced in the result.
    self.assertEqual(result["mean_accuracy"], 10)
def testtrain_epoch(self):
    """Each runner.train_epoch() call delegates to the operator exactly once."""
    # NOTE(review): another `testtrain_epoch` (self.Operator variant) is
    # defined in this file; if both live in one class, only the last
    # definition is collected by unittest — confirm which is intended.

    class MockOperator(TrainingOperator):
        def setup(self, config):
            self.count = 0

        def train_epoch(self, *args, **kwargs):
            # Count invocations and report the running total as the stats.
            self.count = self.count + 1
            return {"count": self.count}

    runner = TorchRunner(
        model_creator,
        create_dataloaders,
        optimizer_creator,
        loss_creator,
        training_operator_cls=MockOperator,
    )
    runner.setup()
    for _ in range(2):
        runner.train_epoch(num_steps=1)
    result = runner.train_epoch()

    operator = runner.training_operator
    self.assertEqual(operator.count, 3)
    self.assertEqual(result["count"], 3)
    self.assertEqual(result["epoch"], 3)