Ejemplo n.º 1
0
    def test_store_load_of_step_with_train_if(self, cloudpickle_mock, open_mock):
        """Round-trip a Step carrying a ``train_if`` object through get_json/load.

        ``cloudpickle_mock`` and ``open_mock`` stand in for the pickling library
        and ``open``. They are presumably injected by patch decorators that are
        not visible here (TODO confirm). The test checks that the ``train_if``
        object is dumped to, and loaded from, the path handed out by the file
        manager.
        """
        pickle_path = os.path.join("folder", "test_train_if.pickle")
        train_if_mock = MagicMock()
        step = Step(self.module_mock, self.step_mock, None, train_if=train_if_mock)
        fm_mock = MagicMock()
        fm_mock.get_path.return_value = pickle_path
        # Named step_json rather than json so the stdlib json module is not shadowed.
        step_json = step.get_json(fm_mock)
        reloaded_step = Step.load(step_json, [self.step_mock], targets=None, module=self.module_mock,
                                  file_manager=MagicMock())

        # One call in load ("rb") and one in save ("wb").
        open_mock.assert_has_calls(
            [call(pickle_path, "wb"),
             call(pickle_path, "rb")],
            any_order=True)
        # No trailing comma here: the original built a throw-away tuple around this call.
        self.assertEqual(step_json, {
            "target_ids": {},
            "input_ids": {},
            "id": -1,
            'batch_size': None,
            'default_run_setting': {'computation_mode': 4},
            "train_if": pickle_path,
            "module": "pywatts.core.step",
            "class": "Step",
            "name": "test",
            'callbacks': [],
            "last": True,
            'condition': None})

        self.assertEqual(reloaded_step.module, self.module_mock)
        self.assertEqual(reloaded_step.input_steps, [self.step_mock])
        # Read the handle via return_value instead of calling open_mock(), so the
        # assertions do not record extra calls on the mock.
        file_handle = open_mock.return_value.__enter__.return_value
        cloudpickle_mock.load.assert_called_once_with(file_handle)
        cloudpickle_mock.dump.assert_called_once_with(train_if_mock, file_handle)
Ejemplo n.º 2
0
    def test_load(self):
        """Load a Step from a config dict and check that it round-trips.

        The loaded step must expose the configured name, id and module.
        Serialising it again with ``get_json`` must reproduce the exact
        config it was loaded from.
        """
        expected_config = {
            "batch_size": None,
            "target_ids": {},
            "input_ids": {2: "x"},
            "default_run_setting": {"computation_mode": 3},
            "id": -1,
            "module": "pywatts.core.step",
            "class": "Step",
            "condition": None,
            "train_if": None,
            "callbacks": [],
            "name": "test",
            "last": False,
        }
        # Key under which the input step is registered; matches input_ids above.
        available_inputs = {"x": self.step_mock}
        loaded = Step.load(expected_config, available_inputs, None, self.module_mock, None)

        self.assertEqual(loaded.name, "test")
        self.assertEqual(loaded.id, -1)
        self.assertEqual(loaded.get_json("file"), expected_config)
        self.assertEqual(loaded.module, self.module_mock)