def test_predict(self):
     """Check that ``predict`` leaves the model out of training mode and returns a gradient-free result."""
     net = SimpleModel().train()
     processor = DataProcessor(model=net)

     # Preconditions: the weights are not on CUDA and the model starts in training mode.
     self.assertFalse(net.fc.weight.is_cuda)
     self.assertTrue(net.training)

     batch = {'data': torch.rand(1, 3)}
     prediction = processor.predict(batch)

     # After predicting, the training flag must be off and the output must
     # neither require gradients nor carry one.
     self.assertFalse(net.training)
     self.assertFalse(prediction.requires_grad)
     self.assertIsNone(prediction.grad)
    def test_prediction_output(self):
        """Check the type of the ``predict`` result for tensor-based and dict-based models."""
        # Scenario 1: SimpleModel is fed one tensor and must return exactly a torch.Tensor.
        plain_net = SimpleModel()
        plain_dp = DataProcessor(model=plain_net)
        self.assertFalse(plain_net.fc.weight.is_cuda)
        plain_out = plain_dp.predict({'data': torch.rand(1, 3)})
        self.assertIs(type(plain_out), torch.Tensor)

        # Scenario 2: NonStandardIOModel is fed a nested dict of tensors and
        # must return a dict whose 'res1' and 'res2' entries are tensors.
        dict_net = NonStandardIOModel()
        dict_dp = DataProcessor(model=dict_net)
        self.assertFalse(dict_net.fc.weight.is_cuda)
        nested_input = {'data1': torch.rand(1, 3), 'data2': torch.rand(1, 3)}
        dict_out = dict_dp.predict({'data': nested_input})
        self.assertIs(type(dict_out), dict)
        for key in ('res1', 'res2'):
            self.assertIn(key, dict_out)
            self.assertIs(type(dict_out[key]), torch.Tensor)
    def test_continue_from_checkpoint(self):
        """Check that a saved state can be restored into a fresh ``DataProcessor``.

        Saving and loading without a checkpoints manager must raise
        ``Model.ModelException``. Once a ``CheckpointsManager`` is attached,
        both operations must succeed, after which the original and the
        restored models are handed to ``compare_two_models``.
        """
        model = SimpleModel().train()
        dp = DataProcessor(model=model)

        # No checkpoints manager is attached yet, so saving must be rejected.
        with self.assertRaises(Model.ModelException):
            dp.save_state()

        try:
            fsm = FileStructManager(self.base_dir, is_continue=False)
            dp.set_checkpoints_manager(CheckpointsManager(fsm))
            dp.save_state()
        except Exception:
            # Catch ``Exception`` (not a bare ``except``) so that
            # KeyboardInterrupt / SystemExit still propagate; the original
            # error stays visible through implicit exception chaining.
            self.fail(
                'Fail to DataProcessor save state when CheckpointsManager was defined'
            )

        del dp

        model_new = SimpleModel().train()
        dp = DataProcessor(model=model_new)

        # The fresh processor has no checkpoints manager, so loading must be rejected.
        with self.assertRaises(Model.ModelException):
            dp.load()

        try:
            # Re-open the same directory in "continue" mode to read the saved state.
            fsm = FileStructManager(self.base_dir, is_continue=True)
            dp.set_checkpoints_manager(CheckpointsManager(fsm))
            dp.load()
        except Exception:
            self.fail(
                'Fail to DataProcessor load when CheckpointsManager was defined'
            )

        compare_two_models(self, model, model_new)
 def test_initialisation(self):
     """Check that a ``DataProcessor`` can be constructed around a ``SimpleModel`` without raising."""
     try:
         DataProcessor(model=SimpleModel())
     except Exception:
         # Catch ``Exception`` (not a bare ``except``) so that
         # KeyboardInterrupt / SystemExit are not reported as test failures.
         self.fail('DataProcessor initialisation raises exception')