    def test_train(self):
        model = SimpleModel().train()
        train_config = TrainConfig([], torch.nn.Module(),
                                   torch.optim.SGD(model.parameters(), lr=0.1))
        dp = TrainDataProcessor(model=model, train_config=train_config)

        self.assertFalse(model.fc.weight.is_cuda)
        self.assertTrue(model.training)
        res = dp.predict({'data': torch.rand(1, 3)}, is_train=True)
        self.assertTrue(model.training)
        self.assertTrue(res.requires_grad)
        self.assertIsNone(res.grad)

        with self.assertRaises(NotImplementedError):
            dp.process_batch(
                {
                    'data': torch.rand(1, 3),
                    'target': torch.rand(1)
                },
                is_train=True)

        loss = SimpleLoss()
        train_config = TrainConfig([], loss,
                                   torch.optim.SGD(model.parameters(), lr=0.1))
        dp = TrainDataProcessor(model=model, train_config=train_config)
        res = dp.process_batch(
            {
                'data': torch.rand(1, 3),
                'target': torch.rand(1)
            }, is_train=True)
        self.assertTrue(model.training)
        self.assertTrue(loss.module.requires_grad)
        self.assertIsNotNone(loss.module.grad)
        self.assertTrue(np.array_equal(res, loss.res.data.numpy()))

    def test_continue_from_checkpoint(self):
        def on_node(n1, n2):
            self.assertTrue(np.array_equal(n1.numpy(), n2.numpy()))

        model = SimpleModel().train()
        loss = SimpleLoss()

        for optim in [
                torch.optim.SGD(model.parameters(), lr=0.1),
                torch.optim.Adam(model.parameters(), lr=0.1)
        ]:
            train_config = TrainConfig([], loss, optim)

            dp_before = TrainDataProcessor(model=model,
                                           train_config=train_config)
            before_state_dict = model.state_dict().copy()
            dp_before.update_lr(0.023)

            with self.assertRaises(Model.ModelException):
                dp_before.save_state()
            try:
                fsm = FileStructManager(base_dir=self.base_dir,
                                        is_continue=False)
                dp_before.set_checkpoints_manager(CheckpointsManager(fsm))
                dp_before.save_state()
            except:
                self.fail(
                    "Exception on saving state when 'CheckpointsManager' specified"
                )

            fsm = FileStructManager(base_dir=self.base_dir, is_continue=True)
            dp_after = TrainDataProcessor(model=model,
                                          train_config=train_config)
            with self.assertRaises(Model.ModelException):
                dp_after.load()
            try:
                cm = CheckpointsManager(fsm)
                dp_after.set_checkpoints_manager(cm)
                dp_after.load()
            except:
                self.fail('DataProcessor initialisation raises exception')

            after_state_dict = model.state_dict().copy()

            dict_pair_recursive_bypass(before_state_dict, after_state_dict,
                                       on_node)
            self.assertEqual(dp_before.get_lr(), dp_after.get_lr())

            shutil.rmtree(self.base_dir)
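
The tests above rely on SimpleModel and SimpleLoss fixtures that are not shown on this page. Below is a minimal sketch, assuming definitions consistent with how the tests use them (an fc linear layer on the model, a learnable module parameter and a cached res tensor on the loss); the original test suite's exact fixtures may differ.

import torch


class SimpleModel(torch.nn.Module):
    # Assumed fixture: one fully connected layer named `fc`, sized for the
    # 3-feature inputs the tests feed in via torch.rand(1, 3).
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(3, 1)

    def forward(self, data):
        return self.fc(data)


class SimpleLoss(torch.nn.Module):
    # Assumed fixture: holds a learnable parameter `module` (whose gradient
    # test_train inspects) and caches the last computed loss in `self.res`.
    def __init__(self):
        super().__init__()
        self.module = torch.nn.Parameter(torch.ones(1))
        self.res = None

    def forward(self, output, target):
        self.res = self.module * (output - target)
        return self.res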
Example #3
    def test_initialisation(self):
        model = SimpleModel()
        train_config = TrainConfig(model, [], torch.nn.Module(),
                                   torch.optim.SGD(model.parameters(), lr=0.1))
        try:
            TrainDataProcessor(train_config=train_config)
        except:
            self.fail('DataProcessor initialisation raises exception')
Example #4
    def test_predict(self):
        model = SimpleModel().train()
        train_config = TrainConfig(model, [], torch.nn.Module(),
                                   torch.optim.SGD(model.parameters(), lr=0.1))
        dp = TrainDataProcessor(train_config=train_config)
        self.assertFalse(model.fc.weight.is_cuda)
        self.assertTrue(model.training)
        res = dp.predict({'data': torch.rand(1, 3)})
        self.assertFalse(model.training)
        self.assertFalse(res.requires_grad)
        self.assertIsNone(res.grad)

    def test_prediction_output(self):
        model = SimpleModel()
        train_config = TrainConfig([], torch.nn.Module(),
                                   torch.optim.SGD(model.parameters(), lr=0.1))
        dp = TrainDataProcessor(model=model, train_config=train_config)
        self.assertFalse(model.fc.weight.is_cuda)
        res = dp.predict({'data': torch.rand(1, 3)}, is_train=False)
        self.assertIs(type(res), torch.Tensor)

        model = NonStandardIOModel()
        dp = TrainDataProcessor(model=model, train_config=train_config)
        self.assertFalse(model.fc.weight.is_cuda)
        res = dp.predict(
            {'data': {
                'data1': torch.rand(1, 3),
                'data2': torch.rand(1, 3)
            }},
            is_train=False)
        self.assertIs(type(res), dict)
        self.assertIn('res1', res)
        self.assertIs(type(res['res1']), torch.Tensor)
        self.assertIn('res2', res)
        self.assertIs(type(res['res2']), torch.Tensor)
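
test_prediction_output also uses a NonStandardIOModel fixture that is not shown here. The sketch below is a plausible assumption, inferred only from the dict-shaped inputs ('data1'/'data2') and outputs ('res1'/'res2') the test asserts on; the real fixture may differ.

class NonStandardIOModel(torch.nn.Module):
    # Assumed fixture: keeps an fc layer (the test checks model.fc.weight.is_cuda)
    # and maps a dict of two input tensors to a dict of two output tensors.
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(3, 1)

    def forward(self, data: dict) -> dict:
        return {'res1': self.fc(data['data1']), 'res2': self.fc(data['data2'])}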
Example #6
    def __init__(self, train_config: TrainConfig, fsm: FileStructManager,
                 device: torch.device = None):
        self._fsm = fsm
        self.monitor_hub = MonitorHub()

        # checkpoints are written through a manager built from the file structure manager
        self._checkpoint_manager = CheckpointsManager(self._fsm)

        self.__epoch_num = 100  # default number of training epochs
        self._resume_from = None
        self._on_epoch_end = []
        self._best_state_rule = None

        self._train_config = train_config
        # the data processor drives training and shares the checkpoints manager
        self._data_processor = TrainDataProcessor(self._train_config, device) \
            .set_checkpoints_manager(self._checkpoint_manager)
        self._lr = LearningRate(self._data_processor.get_lr())

        self._stop_rules = []
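
The constructor above wires FileStructManager, CheckpointsManager and TrainDataProcessor together. Below is a hypothetical construction sketch, assuming the enclosing class is a Trainer-style class (its name is not shown in this snippet) and the TrainConfig(model, stages, loss, optimizer) signature used in the examples above; the actual class and argument names may differ.

model = SimpleModel()
train_config = TrainConfig(model, [], SimpleLoss(),
                           torch.optim.SGD(model.parameters(), lr=0.1))
fsm = FileStructManager(base_dir='checkpoints', is_continue=False)
trainer = Trainer(train_config, fsm, device=torch.device('cpu'))  # 'Trainer' is a hypothetical class name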