def test_train(self):
    """Exercise TrainDataProcessor in training mode.

    First, a processor whose loss is a bare ``torch.nn.Module`` is built. Its
    ``predict(..., is_train=True)`` must keep the model in training mode and
    return an output that requires grad. Its ``process_batch`` must raise
    ``NotImplementedError``. Then a processor using ``SimpleLoss`` is built
    over the same model. Its ``process_batch`` must leave the model in
    training mode, leave ``SimpleLoss.module`` requiring grad with a populated
    ``.grad``, and return a value equal to ``SimpleLoss.res``.
    """
    net = SimpleModel().train()

    def make_processor(criterion):
        # Each config gets its own SGD optimizer over the shared model's parameters.
        cfg = TrainConfig([], criterion, torch.optim.SGD(net.parameters(), lr=0.1))
        return TrainDataProcessor(model=net, train_config=cfg)

    processor = make_processor(torch.nn.Module())
    self.assertFalse(net.fc.weight.is_cuda)
    self.assertTrue(net.training)

    prediction = processor.predict({'data': torch.rand(1, 3)}, is_train=True)
    self.assertTrue(net.training)
    self.assertTrue(prediction.requires_grad)
    self.assertIsNone(prediction.grad)

    # Presumably fails because the bare torch.nn.Module used as loss has no
    # forward() implementation.
    with self.assertRaises(NotImplementedError):
        processor.process_batch({'data': torch.rand(1, 3), 'target': torch.rand(1)}, is_train=True)

    criterion = SimpleLoss()
    processor = make_processor(criterion)
    batch_result = processor.process_batch({'data': torch.rand(1, 3), 'target': torch.rand(1)}, is_train=True)
    self.assertTrue(net.training)
    # `module` and `res` are attributes recorded by SimpleLoss (defined elsewhere).
    # They are presumably the loss input and the loss output; verify against SimpleLoss.
    self.assertTrue(criterion.module.requires_grad)
    self.assertIsNotNone(criterion.module.grad)
    self.assertTrue(np.array_equal(batch_result, criterion.res.data.numpy()))
def test_predict(self):
    """Check that predict() called without ``is_train`` runs the model in eval mode.

    The model starts in training mode. After a default ``predict`` call, the
    test expects the model to be switched out of training mode. It also
    expects the returned tensor to neither require grad nor carry a ``.grad``.
    """
    model = SimpleModel().train()
    # Uses the same construction convention as test_train and
    # test_prediction_output: TrainConfig(stages, loss, optimizer), with the
    # model passed to TrainDataProcessor explicitly.
    # NOTE(review): previously TrainConfig(model, ...) / TrainDataProcessor(train_config=...);
    # confirm which signature the current TrainConfig actually has.
    train_config = TrainConfig([], torch.nn.Module(), torch.optim.SGD(model.parameters(), lr=0.1))
    dp = TrainDataProcessor(model=model, train_config=train_config)
    self.assertFalse(model.fc.weight.is_cuda)
    self.assertTrue(model.training)

    res = dp.predict({'data': torch.rand(1, 3)})
    self.assertFalse(model.training)
    self.assertFalse(res.requires_grad)
    self.assertIsNone(res.grad)
def test_prediction_output(self):
    """Check the container type returned by predict() for different model outputs.

    A model producing a single tensor (``SimpleModel``) must yield exactly a
    ``torch.Tensor``. A model taking nested dict input (``NonStandardIOModel``)
    must yield exactly a ``dict`` whose ``'res1'`` and ``'res2'`` entries are
    ``torch.Tensor`` instances.
    """
    plain_model = SimpleModel()
    cfg = TrainConfig([], torch.nn.Module(), torch.optim.SGD(plain_model.parameters(), lr=0.1))

    processor = TrainDataProcessor(model=plain_model, train_config=cfg)
    self.assertFalse(plain_model.fc.weight.is_cuda)
    single = processor.predict({'data': torch.rand(1, 3)}, is_train=False)
    self.assertIs(type(single), torch.Tensor)

    # NOTE(review): cfg is reused here, so its optimizer still targets
    # plain_model's parameters. Presumably harmless because predict(is_train=False)
    # should not step the optimizer; confirm.
    dict_model = NonStandardIOModel()
    processor = TrainDataProcessor(model=dict_model, train_config=cfg)
    self.assertFalse(dict_model.fc.weight.is_cuda)
    nested_input = {'data': {'data1': torch.rand(1, 3), 'data2': torch.rand(1, 3)}}
    multi = processor.predict(nested_input, is_train=False)
    self.assertIs(type(multi), dict)
    for key in ('res1', 'res2'):
        self.assertIn(key, multi)
        self.assertIs(type(multi[key]), torch.Tensor)