def test_predict(self):
    """predict() must leave the model in eval mode and return a grad-free tensor.

    The model starts in training mode on the CPU. After one prediction the
    model must no longer be training. The returned tensor must neither
    require a gradient nor carry one.
    """
    net = SimpleModel().train()
    processor = DataProcessor(model=net)

    # Preconditions: CPU weights, training mode switched on.
    self.assertFalse(net.fc.weight.is_cuda)
    self.assertTrue(net.training)

    output = processor.predict({'data': torch.rand(1, 3)})

    # Postconditions: predict() flipped the model out of training mode and
    # produced a detached result.
    self.assertFalse(net.training)
    self.assertFalse(output.requires_grad)
    self.assertIsNone(output.grad)
def test_prediction_output(self):
    """predict() output type must follow the model's output structure.

    A ``SimpleModel`` yields a plain ``torch.Tensor``. A ``NonStandardIOModel``
    fed a nested input dict yields a dict with tensor values under the
    ``'res1'`` and ``'res2'`` keys.
    """
    # Scenario 1: standard single-tensor model.
    simple_net = SimpleModel()
    processor = DataProcessor(model=simple_net)
    self.assertFalse(simple_net.fc.weight.is_cuda)
    output = processor.predict({'data': torch.rand(1, 3)})
    self.assertIs(type(output), torch.Tensor)

    # Scenario 2: model taking a dict of inputs and returning a dict.
    nonstd_net = NonStandardIOModel()
    processor = DataProcessor(model=nonstd_net)
    self.assertFalse(nonstd_net.fc.weight.is_cuda)
    inputs = {'data1': torch.rand(1, 3), 'data2': torch.rand(1, 3)}
    output = processor.predict({'data': inputs})
    self.assertIs(type(output), dict)
    for key in ('res1', 'res2'):
        self.assertIn(key, output)
        self.assertIs(type(output[key]), torch.Tensor)
def test_continue_from_checkpoint(self):
    """A saved DataProcessor state can be loaded back into a fresh model.

    Both ``save_state()`` and ``load()`` must raise ``Model.ModelException``
    while no CheckpointsManager is attached. Once a manager built on
    ``self.base_dir`` is attached, a save from one ``SimpleModel`` followed
    by a load into a new ``SimpleModel`` must succeed. The two models must
    then pass ``compare_two_models``.
    """
    model = SimpleModel().train()
    dp = DataProcessor(model=model)

    # Without a CheckpointsManager, saving must be rejected.
    with self.assertRaises(Model.ModelException):
        dp.save_state()

    # Catch Exception rather than everything, so Ctrl-C / SystemExit still
    # propagate; any ordinary error is reported as a test failure.
    try:
        fsm = FileStructManager(self.base_dir, is_continue=False)
        dp.set_checkpoints_manager(CheckpointsManager(fsm))
        dp.save_state()
    except Exception:
        self.fail(
            'Fail to DataProcessor save when CheckpointsManager was defined'
        )

    del dp

    model_new = SimpleModel().train()
    dp = DataProcessor(model=model_new)

    # Without a CheckpointsManager, loading must be rejected.
    with self.assertRaises(Model.ModelException):
        dp.load()

    # Attach a manager in continue mode (built once, inside the guarded
    # block) and restore the state saved above.
    try:
        fsm = FileStructManager(self.base_dir, is_continue=True)
        dp.set_checkpoints_manager(CheckpointsManager(fsm))
        dp.load()
    except Exception:
        self.fail(
            'Fail to DataProcessor load when CheckpointsManager was defined'
        )

    compare_two_models(self, model, model_new)
def test_initialisation(self):
    """Constructing a DataProcessor around a SimpleModel must not raise."""
    # Catch Exception rather than using a bare except, so KeyboardInterrupt
    # and SystemExit are not converted into a test failure.
    try:
        DataProcessor(model=SimpleModel())
    except Exception:
        self.fail('DataProcessor initialisation raises exception')