Пример #1
0
    def test_store_load_of_step_with_train_if(self, cloudpickle_mock, open_mock):
        """Round-trip a Step that carries a ``train_if`` callable through get_json and load.

        The test checks that:

        * the JSON produced by ``get_json`` stores the pickle path returned by
          the file manager under ``"train_if"``,
        * the pickle file is opened for writing (``"wb"``) and for reading (``"rb"``),
        * ``cloudpickle.dump`` receives the ``train_if`` callable and
          ``cloudpickle.load`` is called once with the opened file handle,
        * the reloaded step keeps the module and input steps passed to ``Step.load``.

        NOTE(review): ``cloudpickle_mock`` and ``open_mock`` are presumably injected
        by ``patch`` decorators that are not visible in this block -- confirm.
        """
        train_if_mock = MagicMock()
        step = Step(self.module_mock, self.step_mock, None, train_if=train_if_mock)
        # Single source of truth for the path used by both the save and the load side.
        pickle_path = os.path.join("folder", "test_train_if.pickle")
        fm_mock = MagicMock()
        fm_mock.get_path.return_value = pickle_path
        step_json = step.get_json(fm_mock)
        reloaded_step = Step.load(step_json, [self.step_mock], targets=None, module=self.module_mock,
                                  file_manager=MagicMock())

        # One call in load and one in save
        open_mock.assert_has_calls(
            [call(pickle_path, "wb"),
             call(pickle_path, "rb")],
            any_order=True)
        # No custom ``msg`` argument: unittest's default diff of the two dicts is
        # more useful than dumping the whole JSON again as the failure message.
        self.assertEqual(step_json, {
            "target_ids": {},
            # Same as for test_load.
            "input_ids": {},
            "id": -1,
            'batch_size': None,
            'default_run_setting': {'computation_mode': 4},
            "train_if": pickle_path,
            "module": "pywatts.core.step",
            "class": "Step",
            "name": "test",
            'callbacks': [],
            "last": True,
            'condition': None})

        self.assertEqual(reloaded_step.module, self.module_mock)
        self.assertEqual(reloaded_step.input_steps, [self.step_mock])
        # ``open_mock()`` returns the same mock as the call made inside the code
        # under test, so its ``__enter__`` value is the file handle that was used.
        cloudpickle_mock.load.assert_called_once_with(open_mock().__enter__.return_value)
        cloudpickle_mock.dump.assert_called_once_with(train_if_mock, open_mock().__enter__.return_value)
Пример #2
0
 def test_get_json(self):
     """Check the dictionary returned by get_json for a Step built without optional arguments."""
     expected_json = {
         'batch_size': None,
         'callbacks': [],
         'class': 'Step',
         'default_run_setting': {'computation_mode': 4},
         'condition': None,
         'id': -1,
         'input_ids': {},
         'last': True,
         'module': 'pywatts.core.step',
         'name': 'test',
         'target_ids': {},
         'train_if': None,
     }
     # "file" stands in for the file manager argument of get_json.
     actual_json = Step(self.module_mock, self.step_mock, None).get_json("file")
     self.assertEqual(expected_json, actual_json)