def setUp(self):
    """Attach a freshly constructed ``IIterationState`` to *self* as ``_default``.

    The state is built with ``num_time_steps=3`` and ``num_states=4``.

    NOTE(review): this module-level function duplicates
    ``IIterationStateTest.setUp`` and looks like a leftover; confirm whether
    anything actually calls it before keeping it.
    """
    fixture = IIterationState(num_time_steps=3, num_states=4)
    self._default = fixture
class IIterationStateTest(NumpyAwareTestCase):
    """Tests for ``IIterationState`` built with 3 time steps and 4 states each."""

    def setUp(self):
        self._default = IIterationState(num_time_steps=3, num_states=4)

    def test_has_same_features_as_state_iterator(self):
        """The iteration state and each of its time steps behave like state iterators."""
        _accessor_checks = (
            has_current_accessor,
            has_previous_accessor,
            has_next_accessor,
            has_first_and_last_accessor,
        )

        # The checks presumably move the iteration cursor, so every check
        # (including the nested iteration check) starts from a fresh state.
        self.setUp()
        is_iterable(self._default, ITimeStepState, 3)
        for _check in _accessor_checks:
            self.setUp()
            _check(self._default)

        self.setUp()
        for _time in self._default:
            is_iterable(_time, IStepState, 4)
        for _check in _accessor_checks:
            self.setUp()
            for _time in self._default:
                _check(_time)

    def test_on_finalize_collect_solutions_and_finalize(self):
        """finalize() collects every step's solution value and may only run once."""
        _count = 0.0
        for _time in self._default:
            for _step in _time:
                _step.solution.value = numpy.array([1.0])
                _step.solution.time_point = _count
                _count += 1

        self.assertFalse(self._default.solution.finalized)
        self.assertNumpyArrayEqual(self._default.solution.values, numpy.array([]))

        self._default.finalize()
        self.assertTrue(self._default.solution.finalized)
        # 3 time steps x 4 steps each = 12 collected values.
        self.assertNumpyArrayEqual(self._default.solution.values, numpy.array([[1.0]] * 12))
        # finalize() should reset the current_index counter.
        self.assertEqual(self._default.current_index, 0)
        # A second finalize() is rejected.
        self.assertRaises(RuntimeError, self._default.finalize)

    def test_on_proceed_set_initial_step_of_time_step(self):
        """proceed() makes the new time step's initial step the previous time step's last step."""
        self._default.current_time_step.last_step.solution.value = numpy.array([1.0])
        self.assertIsNone(self._default.current_time_step.initial.solution.value)
        self.assertIsNone(self._default.next_time_step.initial.solution.value)
        self.assertNotEqual(self._default.current_time_step.initial,
                            self._default.next_time_step.initial)

        self._default.proceed()
        self.assertIs(self._default.previous_time_step.last_step,
                      self._default.current_time_step.initial)
        self.assertNumpyArrayEqual(self._default.current_time_step.initial.solution.value,
                                   numpy.array([1.0]))

    def test_has_proxies_for_steps(self):
        """Step-level accessors on the iteration state forward to the right time step."""
        self._default.proceed()
        self.assertIs(self._default.current_step, self._default.current_time_step.current_step)
        # Indices are ints: compare by value, not identity.
        self.assertEqual(self._default.current_step_index,
                         self._default.current_time_step.current_step_index)
        self.assertIs(self._default.previous_step, self._default.previous_time_step.last)
        self.assertIs(self._default.next_step, self._default.current_time_step.next_step)

        # Move the current time step's cursor onto its last step; the next step
        # must then come from the following time step.
        self._default.current_time_step._current_index = \
            self._default.current_time_step.last_step_index
        self.assertIs(self._default.next_step, self._default.next_time_step.first)
        self.assertIs(self._default.final_step, self._default.last_time_step.last)
        self.assertIs(self._default.first_step, self._default.first_time_step.first)

    def test_has_aliases_for_state_accessors(self):
        """The *_time_step accessors alias the generic state-iterator accessors."""
        self.assertIs(self._default.first_time_step, self._default.first)
        self.assertEqual(self._default.first_time_step_index, self._default.first_index)
        self.assertIs(self._default.last_time_step, self._default.last)
        self.assertEqual(self._default.last_time_step_index, self._default.last_index)

        while self._default.current_index != self._default.last_index:
            self.assertIs(self._default.current_time_step, self._default.current)
            self.assertEqual(self._default.current_time_step_index, self._default.current_index)
            self.assertIs(self._default.previous_time_step, self._default.previous)
            self.assertEqual(self._default.previous_time_step_index, self._default.previous_index)
            self.assertIs(self._default.next_time_step, self._default.next)
            self.assertEqual(self._default.next_time_step_index, self._default.next_index)
            self._default.proceed()