def testNativeLoss(self):
    """Train one epoch with a torch loss *class* passed as the loss creator."""
    native_runner = TorchRunner(
        model_creator,
        single_loader,
        optimizer_creator,
        loss_creator=nn.MSELoss,
    )
    native_runner.setup()
    native_runner.train_epoch()
def testGivens(self):
    """Check the ``given_models`` / ``given_optimizers`` attributes of a runner.

    With three model/optimizer creators and a mock operator, both ``given_*``
    collections must have three entries. With the single-model creators,
    the ``given_*`` attributes must not compare equal to ``models`` /
    ``optimizers``.
    """

    class MockOperator(TrainingOperator):
        def setup(self, config):
            # ``return_value`` (not ``returns``) is the Mock keyword that makes
            # the mock return this dict when called; ``returns=`` would only
            # set an unrelated attribute on the mock.
            self.train_epoch = MagicMock(return_value=dict(mean_accuracy=10))
            self.validate = MagicMock(return_value=dict(mean_accuracy=10))

    def three_model_creator(config):
        return nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)

    def three_optimizer_creator(models, config):
        opts = [
            torch.optim.SGD(model.parameters(), lr=0.1) for model in models
        ]
        return opts[0], opts[1], opts[2]

    runner = TorchRunner(
        three_model_creator,
        single_loader,
        three_optimizer_creator,
        loss_creator,
        training_operator_cls=MockOperator,
    )
    runner.setup()
    self.assertEqual(len(runner.given_models), 3)
    self.assertEqual(len(runner.given_optimizers), 3)

    runner2 = TorchRunner(
        model_creator, single_loader, optimizer_creator, loss_creator)
    runner2.setup()
    self.assertNotEqual(runner2.given_models, runner2.models)
    self.assertNotEqual(runner2.given_optimizers, runner2.optimizers)
def testSingleLoader(self):
    """Train with only a training loader; validating must raise ValueError."""
    runner = TorchRunner(
        model_creator, single_loader, optimizer_creator, loss_creator)
    runner.setup()
    runner.train_epoch()
    # No validation data was supplied, so validate() is expected to fail.
    self.assertRaises(ValueError, runner.validate)
def testMultiLoaders(self):
    """setup() must raise ValueError when the data creator yields three sets."""

    def three_data_loader(config):
        first = LinearDataset(2, 5)
        second = LinearDataset(2, 5, size=400)
        third = LinearDataset(2, 5, size=400)
        return (first, second, third)

    # The same check is performed twice, each time on a freshly built runner.
    for _attempt in range(2):
        candidate = TorchRunner(
            model_creator, three_data_loader, optimizer_creator, loss_creator)
        with self.assertRaises(ValueError):
            candidate.setup()
def testMultiModel(self):
    """setup() must raise ValueError for three models without a custom operator."""

    def multi_model_creator(config):
        return tuple(nn.Linear(1, 1) for _ in range(3))

    def multi_optimizer_creator(models, config):
        optimizers = []
        for net in models:
            optimizers.append(torch.optim.SGD(net.parameters(), lr=0.1))
        return optimizers[0], optimizers[1], optimizers[2]

    runner = TorchRunner(
        multi_model_creator,
        single_loader,
        multi_optimizer_creator,
        loss_creator,
    )
    self.assertRaises(ValueError, runner.setup)
def testValidate(self):
    """The operator's validate() runs only via runner.validate(), not training.

    Three training epochs must not invoke the operator's ``validate`` mock.
    A single ``runner.validate()`` must invoke it. The last training result
    must report ``epoch == 3``.
    """

    class MockOperator(TrainingOperator):
        def setup(self, config):
            # ``return_value`` (not ``returns``) makes each mock return the
            # stats dict when called; ``returns=`` would only set an
            # unrelated attribute and the call would yield another MagicMock.
            self.train_epoch = MagicMock(return_value=dict(mean_accuracy=10))
            self.validate = MagicMock(return_value=dict(mean_accuracy=10))

    runner = TorchRunner(
        model_creator,
        create_dataloaders,
        optimizer_creator,
        loss_creator,
        training_operator_cls=MockOperator,
    )
    runner.setup()
    runner.train_epoch()
    runner.train_epoch()
    result = runner.train_epoch()
    # Training alone must never trigger validation.
    runner.training_operator.validate.assert_not_called()
    runner.validate()
    runner.training_operator.validate.assert_called()
    self.assertEqual(result["epoch"], 3)
def testtrain_epoch(self):
    """Each runner.train_epoch() call drives the operator's train_epoch once.

    The two ``num_steps=1`` calls plus one plain call must leave the
    operator's counter at 3. The returned stats must carry both the
    operator's ``count`` and the runner's ``epoch`` equal to 3.
    """

    class MockOperator(TrainingOperator):
        def setup(self, config):
            self.count = 0

        def train_epoch(self, *args, **kwargs):
            self.count = self.count + 1
            return {"count": self.count}

    runner = TorchRunner(
        model_creator,
        create_dataloaders,
        optimizer_creator,
        loss_creator,
        training_operator_cls=MockOperator,
    )
    runner.setup()
    for _ in range(2):
        runner.train_epoch(num_steps=1)
    result = runner.train_epoch()

    operator = runner.training_operator
    self.assertEqual(operator.count, 3)
    self.assertEqual(result["count"], 3)
    self.assertEqual(result["epoch"], 3)