def testGivens(self):
    """Check the ``given_models``/``given_optimizers`` runner attributes.

    An operator that registers three models and three optimizers must
    produce three entries in each ``given_*`` attribute.  For a runner
    built from ``self.Operator`` the ``given_*`` attributes must not
    compare equal to ``models``/``optimizers``.
    """

    def three_model_creator(config):
        """Return three separate ``nn.Linear(1, 1)`` models."""
        return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)

    def three_optimizer_creator(models, config):
        """Return one SGD optimizer (lr=0.1) per model as a 3-tuple."""
        opts = [
            torch.optim.SGD(model.parameters(), lr=0.1) for model in models
        ]
        return opts[0], opts[1], opts[2]

    class MockOperator(TrainingOperator):
        def setup(self, config):
            """Register three models/optimizers and stub out the loops."""
            models = three_model_creator(config)
            optimizers = three_optimizer_creator(models, config)
            loader = single_loader(config)
            loss = loss_creator(config)

            self.models, self.optimizers, self.criterion = self.register(
                models=models, optimizers=optimizers, criterion=loss)
            self.register_data(train_loader=loader, validation_loader=None)

            # ``return_value`` (not ``returns``) is the Mock keyword that
            # sets what calling the mock yields; any other keyword only
            # creates an attribute of that name on the mock.
            self.train_epoch = MagicMock(
                return_value=dict(mean_accuracy=10))
            self.validate = MagicMock(
                return_value=dict(mean_accuracy=10))

    runner = TorchRunner(training_operator_cls=MockOperator)
    runner.setup_operator()

    self.assertEqual(len(runner.given_models), 3)
    self.assertEqual(len(runner.given_optimizers), 3)

    runner2 = TorchRunner(training_operator_cls=self.Operator)
    runner2.setup_operator()

    self.assertNotEqual(runner2.given_models, runner2.models)
    self.assertNotEqual(runner2.given_optimizers, runner2.optimizers)
def testNativeLoss(self):
    """Run one training epoch with ``nn.MSELoss`` passed as the loss creator.

    The loss class itself (not an instance or a wrapper function) is
    handed to ``from_creators``; the test passes if setup and a single
    epoch complete without raising.
    """
    operator_cls = TrainingOperator.from_creators(
        model_creator,
        optimizer_creator,
        single_loader,
        loss_creator=nn.MSELoss,
    )
    torch_runner = TorchRunner(training_operator_cls=operator_cls)
    torch_runner.setup_operator()
    torch_runner.train_epoch()
def testSingleLoader(self):
    """Train with an operator built from ``single_loader``, then validate.

    Training must succeed, while ``validate()`` is expected to raise
    ``ValueError`` (presumably because ``single_loader`` supplies no
    validation data -- confirm against its definition).
    """
    operator_cls = TrainingOperator.from_creators(
        model_creator,
        optimizer_creator,
        single_loader,
        loss_creator=loss_creator,
    )
    torch_runner = TorchRunner(training_operator_cls=operator_cls)
    torch_runner.setup_operator()
    torch_runner.train_epoch()

    # Validation is the only call expected to fail.
    with self.assertRaises(ValueError):
        torch_runner.validate()
def testValidate(self):
    """Check that ``validate`` is only invoked on an explicit request.

    Three training epochs are run against an operator whose
    ``train_epoch``/``validate`` are mocks.  The operator's ``validate``
    must not have been called by training, must be called by
    ``runner.validate()``, and the third epoch's result must report
    ``epoch == 3``.
    """

    class MockOperator(self.Operator):
        def setup(self, config):
            """Run the base setup, then replace both loops with mocks."""
            super(MockOperator, self).setup(config)
            # ``return_value`` (not ``returns``) is the Mock keyword that
            # sets what calling the mock yields; any other keyword only
            # creates an attribute of that name on the mock.
            self.train_epoch = MagicMock(
                return_value=dict(mean_accuracy=10))
            self.validate = MagicMock(
                return_value=dict(mean_accuracy=10))

    runner = TorchRunner(training_operator_cls=MockOperator)
    runner.setup_operator()

    runner.train_epoch()
    runner.train_epoch()
    result = runner.train_epoch()
    # Training alone must never trigger validation.
    self.assertEqual(runner.training_operator.validate.call_count, 0)

    runner.validate()
    self.assertTrue(runner.training_operator.validate.called)
    self.assertEqual(result["epoch"], 3)
def testtrain_epoch(self):
    """Check that each runner epoch calls the operator's ``train_epoch`` once.

    The operator counts its own ``train_epoch`` invocations.  After three
    runner epochs the counter, the returned ``"count"`` and the returned
    ``"epoch"`` must all equal 3.
    """

    class MockOperator(self.Operator):
        def setup(self, config):
            """Run the base setup and start the call counter at zero."""
            super(MockOperator, self).setup(config)
            self.count = 0

        def train_epoch(self, *args, **kwargs):
            """Ignore all arguments; bump and report the call counter."""
            self.count = self.count + 1
            return dict(count=self.count)

    torch_runner = TorchRunner(training_operator_cls=MockOperator)
    torch_runner.setup_operator()

    # Two epochs limited to one step each, then one unrestricted epoch.
    for _ in range(2):
        torch_runner.train_epoch(num_steps=1)
    last_stats = torch_runner.train_epoch()

    self.assertEqual(torch_runner.training_operator.count, 3)
    self.assertEqual(last_stats["count"], 3)
    self.assertEqual(last_stats["epoch"], 3)
def testMultiLoaders(self):
    """Check that a data creator returning three datasets is rejected.

    ``setup_operator`` must raise ``ValueError`` for an operator whose
    data creator returns a 3-tuple of datasets; the check is performed
    on two freshly constructed runners.
    """

    def three_data_loader(config):
        """Return three ``LinearDataset`` objects as a tuple."""
        first = LinearDataset(2, 5)
        second = LinearDataset(2, 5, size=400)
        third = LinearDataset(2, 5, size=400)
        return (first, second, third)

    ThreeOperator = TrainingOperator.from_creators(
        model_creator,
        optimizer_creator,
        three_data_loader,
        loss_creator=loss_creator,
    )

    # Same construct-then-fail sequence, repeated for two runners.
    for _ in range(2):
        torch_runner = TorchRunner(training_operator_cls=ThreeOperator)
        with self.assertRaises(ValueError):
            torch_runner.setup_operator()